From 2de0badc7e009a8eabe587c8f48a58fae8cd58e9 Mon Sep 17 00:00:00 2001
From: Yannic Spreen-Ledebur <35889034+spreeni@users.noreply.github.com>
Date: Tue, 14 Nov 2023 23:33:21 +0100
Subject: [PATCH] Add use_jsonb flag for metadata data type in
 PostgresVectorStore (#8910)

---
 llama_index/vector_stores/postgres.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/llama_index/vector_stores/postgres.py b/llama_index/vector_stores/postgres.py
index 7d25103287..0365ff7c59 100644
--- a/llama_index/vector_stores/postgres.py
+++ b/llama_index/vector_stores/postgres.py
@@ -31,13 +31,14 @@ def get_data_model(
     text_search_config: str,
     cache_okay: bool,
     embed_dim: int = 1536,
+    use_jsonb: bool = False,
 ) -> Any:
     """
     This part create a dynamic sqlalchemy model with a new table.
     """
     from pgvector.sqlalchemy import Vector
     from sqlalchemy import Column, Computed
-    from sqlalchemy.dialects.postgresql import BIGINT, JSON, TSVECTOR, VARCHAR
+    from sqlalchemy.dialects.postgresql import BIGINT, JSON, JSONB, TSVECTOR, VARCHAR
     from sqlalchemy.schema import Index
     from sqlalchemy.types import TypeDecorator
 
@@ -49,13 +50,15 @@ def get_data_model(
     class_name = "Data%s" % index_name  # dynamic class name
     indexname = "%s_idx" % index_name  # dynamic class name
 
+    metadata_dtype = JSONB if use_jsonb else JSON
+
     if hybrid_search:
 
         class HybridAbstractData(base):  # type: ignore
             __abstract__ = True  # this line is necessary
             id = Column(BIGINT, primary_key=True, autoincrement=True)
             text = Column(VARCHAR, nullable=False)
-            metadata_ = Column(JSON)
+            metadata_ = Column(metadata_dtype)
             node_id = Column(VARCHAR)
             embedding = Column(Vector(embed_dim))  # type: ignore
             text_search_tsv = Column(  # type: ignore
@@ -82,7 +85,7 @@ def get_data_model(
             __abstract__ = True  # this line is necessary
             id = Column(BIGINT, primary_key=True, autoincrement=True)
             text = Column(VARCHAR, nullable=False)
-            metadata_ = Column(JSON)
+            metadata_ = Column(metadata_dtype)
             node_id = Column(VARCHAR)
             embedding = Column(Vector(embed_dim))  # type: ignore
 
@@ -111,6 +114,7 @@ class PGVectorStore(BasePydanticVectorStore):
     cache_ok: bool
     perform_setup: bool
     debug: bool
+    use_jsonb: bool
 
     _base: Any = PrivateAttr()
     _table_class: Any = PrivateAttr()
@@ -132,6 +136,7 @@ class PGVectorStore(BasePydanticVectorStore):
         cache_ok: bool = False,
         perform_setup: bool = True,
         debug: bool = False,
+        use_jsonb: bool = False,
     ) -> None:
         try:
             import asyncpg  # noqa
@@ -166,6 +171,7 @@ class PGVectorStore(BasePydanticVectorStore):
             text_search_config,
             cache_ok,
             embed_dim=embed_dim,
+            use_jsonb=use_jsonb,
         )
 
         super().__init__(
@@ -212,6 +218,7 @@ class PGVectorStore(BasePydanticVectorStore):
         cache_ok: bool = False,
         perform_setup: bool = True,
         debug: bool = False,
+        use_jsonb: bool = False,
     ) -> "PGVectorStore":
         """Return connection string from database parameters."""
         conn_str = (
@@ -232,6 +239,7 @@ class PGVectorStore(BasePydanticVectorStore):
             cache_ok=cache_ok,
             perform_setup=perform_setup,
             debug=debug,
+            use_jsonb=use_jsonb,
         )
 
     @property
-- 
GitLab