Skip to content
Snippets Groups Projects
Unverified Commit a0d52ac6 authored by Daniel Thomas's avatar Daniel Thomas Committed by GitHub
Browse files

Make SupabaseVectorStore a subclass of BasePydanticVectorStore (#11476)

parent 8623877b
No related branches found
No related tags found
No related merge requests found
import logging
import math
from collections import defaultdict
from typing import Any, List
from typing import Any, List, Optional
import vecs
from vecs.collection import Collection
from llama_index.core.constants import DEFAULT_EMBEDDING_DIM
from llama_index.core.schema import BaseNode, TextNode
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.vector_stores.types import (
MetadataFilters,
VectorStore,
BasePydanticVectorStore,
VectorStoreQuery,
VectorStoreQueryResult,
)
......@@ -22,7 +24,7 @@ from vecs.collection import CollectionNotFound
logger = logging.getLogger(__name__)
class SupabaseVectorStore(VectorStore):
class SupabaseVectorStore(BasePydanticVectorStore):
"""Supbabase Vector.
In this vector store, embeddings are stored in Postgres table using pgvector.
......@@ -41,6 +43,8 @@ class SupabaseVectorStore(VectorStore):
stores_text = True
flat_metadata = False
_client: Optional[Any] = PrivateAttr()
_collection: Optional[Collection] = PrivateAttr()
def __init__(
self,
......@@ -49,17 +53,17 @@ class SupabaseVectorStore(VectorStore):
dimension: int = DEFAULT_EMBEDDING_DIM,
**kwargs: Any,
) -> None:
"""Init params."""
client = vecs.create_client(postgres_connection_string)
super().__init__()
self._client = vecs.create_client(postgres_connection_string)
try:
self._collection = client.get_collection(name=collection_name)
self._collection = self._client.get_collection(name=collection_name)
except CollectionNotFound:
logger.info(
f"Collection {collection_name} does not exist, "
f"try creating one with dimension={dimension}"
)
self._collection = client.create_collection(
self._collection = self._client.create_collection(
name=collection_name, dimension=dimension
)
......
from llama_index.core.vector_stores.types import VectorStore
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.supabase import SupabaseVectorStore
def test_class():
names_of_base_classes = [b.__name__ for b in SupabaseVectorStore.__mro__]
assert VectorStore.__name__ in names_of_base_classes
assert BasePydanticVectorStore.__name__ in names_of_base_classes
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment