Skip to content
Snippets Groups Projects
Unverified Commit 2de0badc authored by Yannic Spreen-Ledebur's avatar Yannic Spreen-Ledebur Committed by GitHub
Browse files

Add use_jsonb flag for metadata data type in PostgresVectorStore (#8910)

parent 9fbaeed2
No related branches found
No related tags found
No related merge requests found
......@@ -31,13 +31,14 @@ def get_data_model(
text_search_config: str,
cache_okay: bool,
embed_dim: int = 1536,
use_jsonb: bool = False,
) -> Any:
"""
This part create a dynamic sqlalchemy model with a new table.
"""
from pgvector.sqlalchemy import Vector
from sqlalchemy import Column, Computed
from sqlalchemy.dialects.postgresql import BIGINT, JSON, TSVECTOR, VARCHAR
from sqlalchemy.dialects.postgresql import BIGINT, JSON, JSONB, TSVECTOR, VARCHAR
from sqlalchemy.schema import Index
from sqlalchemy.types import TypeDecorator
......@@ -49,13 +50,15 @@ def get_data_model(
class_name = "Data%s" % index_name # dynamic class name
indexname = "%s_idx" % index_name # dynamic class name
metadata_dtype = JSONB if use_jsonb else JSON
if hybrid_search:
class HybridAbstractData(base): # type: ignore
__abstract__ = True # this line is necessary
id = Column(BIGINT, primary_key=True, autoincrement=True)
text = Column(VARCHAR, nullable=False)
metadata_ = Column(JSON)
metadata_ = Column(metadata_dtype)
node_id = Column(VARCHAR)
embedding = Column(Vector(embed_dim)) # type: ignore
text_search_tsv = Column( # type: ignore
......@@ -82,7 +85,7 @@ def get_data_model(
__abstract__ = True # this line is necessary
id = Column(BIGINT, primary_key=True, autoincrement=True)
text = Column(VARCHAR, nullable=False)
metadata_ = Column(JSON)
metadata_ = Column(metadata_dtype)
node_id = Column(VARCHAR)
embedding = Column(Vector(embed_dim)) # type: ignore
......@@ -111,6 +114,7 @@ class PGVectorStore(BasePydanticVectorStore):
cache_ok: bool
perform_setup: bool
debug: bool
use_jsonb: bool
_base: Any = PrivateAttr()
_table_class: Any = PrivateAttr()
......@@ -132,6 +136,7 @@ class PGVectorStore(BasePydanticVectorStore):
cache_ok: bool = False,
perform_setup: bool = True,
debug: bool = False,
use_jsonb: bool = False,
) -> None:
try:
import asyncpg # noqa
......@@ -166,6 +171,7 @@ class PGVectorStore(BasePydanticVectorStore):
text_search_config,
cache_ok,
embed_dim=embed_dim,
use_jsonb=use_jsonb,
)
super().__init__(
......@@ -212,6 +218,7 @@ class PGVectorStore(BasePydanticVectorStore):
cache_ok: bool = False,
perform_setup: bool = True,
debug: bool = False,
use_jsonb: bool = False,
) -> "PGVectorStore":
"""Return connection string from database parameters."""
conn_str = (
......@@ -232,6 +239,7 @@ class PGVectorStore(BasePydanticVectorStore):
cache_ok=cache_ok,
perform_setup=perform_setup,
debug=debug,
use_jsonb=use_jsonb,
)
@property
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment