Skip to content
Snippets Groups Projects
Unverified Commit 2de0badc authored by Yannic Spreen-Ledebur, committed by GitHub
Browse files

Add use_jsonb flag for metadata data type in PostgresVectorStore (#8910)

parent 9fbaeed2
No related branches found
No related tags found
No related merge requests found
...@@ -31,13 +31,14 @@ def get_data_model( ...@@ -31,13 +31,14 @@ def get_data_model(
text_search_config: str, text_search_config: str,
cache_okay: bool, cache_okay: bool,
embed_dim: int = 1536, embed_dim: int = 1536,
use_jsonb: bool = False,
) -> Any: ) -> Any:
""" """
This part create a dynamic sqlalchemy model with a new table. This part create a dynamic sqlalchemy model with a new table.
""" """
from pgvector.sqlalchemy import Vector from pgvector.sqlalchemy import Vector
from sqlalchemy import Column, Computed from sqlalchemy import Column, Computed
from sqlalchemy.dialects.postgresql import BIGINT, JSON, TSVECTOR, VARCHAR from sqlalchemy.dialects.postgresql import BIGINT, JSON, JSONB, TSVECTOR, VARCHAR
from sqlalchemy.schema import Index from sqlalchemy.schema import Index
from sqlalchemy.types import TypeDecorator from sqlalchemy.types import TypeDecorator
...@@ -49,13 +50,15 @@ def get_data_model( ...@@ -49,13 +50,15 @@ def get_data_model(
class_name = "Data%s" % index_name # dynamic class name class_name = "Data%s" % index_name # dynamic class name
indexname = "%s_idx" % index_name # dynamic class name indexname = "%s_idx" % index_name # dynamic class name
metadata_dtype = JSONB if use_jsonb else JSON
if hybrid_search: if hybrid_search:
class HybridAbstractData(base): # type: ignore class HybridAbstractData(base): # type: ignore
__abstract__ = True # this line is necessary __abstract__ = True # this line is necessary
id = Column(BIGINT, primary_key=True, autoincrement=True) id = Column(BIGINT, primary_key=True, autoincrement=True)
text = Column(VARCHAR, nullable=False) text = Column(VARCHAR, nullable=False)
metadata_ = Column(JSON) metadata_ = Column(metadata_dtype)
node_id = Column(VARCHAR) node_id = Column(VARCHAR)
embedding = Column(Vector(embed_dim)) # type: ignore embedding = Column(Vector(embed_dim)) # type: ignore
text_search_tsv = Column( # type: ignore text_search_tsv = Column( # type: ignore
...@@ -82,7 +85,7 @@ def get_data_model( ...@@ -82,7 +85,7 @@ def get_data_model(
__abstract__ = True # this line is necessary __abstract__ = True # this line is necessary
id = Column(BIGINT, primary_key=True, autoincrement=True) id = Column(BIGINT, primary_key=True, autoincrement=True)
text = Column(VARCHAR, nullable=False) text = Column(VARCHAR, nullable=False)
metadata_ = Column(JSON) metadata_ = Column(metadata_dtype)
node_id = Column(VARCHAR) node_id = Column(VARCHAR)
embedding = Column(Vector(embed_dim)) # type: ignore embedding = Column(Vector(embed_dim)) # type: ignore
...@@ -111,6 +114,7 @@ class PGVectorStore(BasePydanticVectorStore): ...@@ -111,6 +114,7 @@ class PGVectorStore(BasePydanticVectorStore):
cache_ok: bool cache_ok: bool
perform_setup: bool perform_setup: bool
debug: bool debug: bool
use_jsonb: bool
_base: Any = PrivateAttr() _base: Any = PrivateAttr()
_table_class: Any = PrivateAttr() _table_class: Any = PrivateAttr()
...@@ -132,6 +136,7 @@ class PGVectorStore(BasePydanticVectorStore): ...@@ -132,6 +136,7 @@ class PGVectorStore(BasePydanticVectorStore):
cache_ok: bool = False, cache_ok: bool = False,
perform_setup: bool = True, perform_setup: bool = True,
debug: bool = False, debug: bool = False,
use_jsonb: bool = False,
) -> None: ) -> None:
try: try:
import asyncpg # noqa import asyncpg # noqa
...@@ -166,6 +171,7 @@ class PGVectorStore(BasePydanticVectorStore): ...@@ -166,6 +171,7 @@ class PGVectorStore(BasePydanticVectorStore):
text_search_config, text_search_config,
cache_ok, cache_ok,
embed_dim=embed_dim, embed_dim=embed_dim,
use_jsonb=use_jsonb,
) )
super().__init__( super().__init__(
...@@ -212,6 +218,7 @@ class PGVectorStore(BasePydanticVectorStore): ...@@ -212,6 +218,7 @@ class PGVectorStore(BasePydanticVectorStore):
cache_ok: bool = False, cache_ok: bool = False,
perform_setup: bool = True, perform_setup: bool = True,
debug: bool = False, debug: bool = False,
use_jsonb: bool = False,
) -> "PGVectorStore": ) -> "PGVectorStore":
"""Return connection string from database parameters.""" """Return connection string from database parameters."""
conn_str = ( conn_str = (
...@@ -232,6 +239,7 @@ class PGVectorStore(BasePydanticVectorStore): ...@@ -232,6 +239,7 @@ class PGVectorStore(BasePydanticVectorStore):
cache_ok=cache_ok, cache_ok=cache_ok,
perform_setup=perform_setup, perform_setup=perform_setup,
debug=debug, debug=debug,
use_jsonb=use_jsonb,
) )
@property @property
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment