From 2de0badc7e009a8eabe587c8f48a58fae8cd58e9 Mon Sep 17 00:00:00 2001 From: Yannic Spreen-Ledebur <35889034+spreeni@users.noreply.github.com> Date: Tue, 14 Nov 2023 23:33:21 +0100 Subject: [PATCH] Add use_jsonb flag for metadata data type in PostgresVectorStore (#8910) --- llama_index/vector_stores/postgres.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/llama_index/vector_stores/postgres.py b/llama_index/vector_stores/postgres.py index 7d25103287..0365ff7c59 100644 --- a/llama_index/vector_stores/postgres.py +++ b/llama_index/vector_stores/postgres.py @@ -31,13 +31,14 @@ def get_data_model( text_search_config: str, cache_okay: bool, embed_dim: int = 1536, + use_jsonb: bool = False, ) -> Any: """ This part create a dynamic sqlalchemy model with a new table. """ from pgvector.sqlalchemy import Vector from sqlalchemy import Column, Computed - from sqlalchemy.dialects.postgresql import BIGINT, JSON, TSVECTOR, VARCHAR + from sqlalchemy.dialects.postgresql import BIGINT, JSON, JSONB, TSVECTOR, VARCHAR from sqlalchemy.schema import Index from sqlalchemy.types import TypeDecorator @@ -49,13 +50,15 @@ def get_data_model( class_name = "Data%s" % index_name # dynamic class name indexname = "%s_idx" % index_name # dynamic class name + metadata_dtype = JSONB if use_jsonb else JSON + if hybrid_search: class HybridAbstractData(base): # type: ignore __abstract__ = True # this line is necessary id = Column(BIGINT, primary_key=True, autoincrement=True) text = Column(VARCHAR, nullable=False) - metadata_ = Column(JSON) + metadata_ = Column(metadata_dtype) node_id = Column(VARCHAR) embedding = Column(Vector(embed_dim)) # type: ignore text_search_tsv = Column( # type: ignore @@ -82,7 +85,7 @@ def get_data_model( __abstract__ = True # this line is necessary id = Column(BIGINT, primary_key=True, autoincrement=True) text = Column(VARCHAR, nullable=False) - metadata_ = Column(JSON) + metadata_ = Column(metadata_dtype) node_id = Column(VARCHAR) embedding = Column(Vector(embed_dim)) # type: ignore @@ -111,6 +114,7 @@ class PGVectorStore(BasePydanticVectorStore): cache_ok: bool perform_setup: bool debug: bool + use_jsonb: bool _base: Any = PrivateAttr() _table_class: Any = PrivateAttr() @@ -132,6 +136,7 @@ class PGVectorStore(BasePydanticVectorStore): cache_ok: bool = False, perform_setup: bool = True, debug: bool = False, + use_jsonb: bool = False, ) -> None: try: import asyncpg # noqa @@ -166,6 +171,7 @@ class PGVectorStore(BasePydanticVectorStore): text_search_config, cache_ok, embed_dim=embed_dim, + use_jsonb=use_jsonb, ) super().__init__( @@ -212,6 +218,7 @@ class PGVectorStore(BasePydanticVectorStore): cache_ok: bool = False, perform_setup: bool = True, debug: bool = False, + use_jsonb: bool = False, ) -> "PGVectorStore": """Return connection string from database parameters.""" conn_str = ( @@ -232,6 +239,7 @@ class PGVectorStore(BasePydanticVectorStore): cache_ok=cache_ok, perform_setup=perform_setup, debug=debug, + use_jsonb=use_jsonb, ) @property -- GitLab