From bbd5b8ddd65983675260343d500a0c0f078ab2f4 Mon Sep 17 00:00:00 2001
From: "Huu Le (Lee)" <39040748+leehuwuj@users.noreply.github.com>
Date: Wed, 22 May 2024 16:12:44 +0700
Subject: [PATCH] fix: Reuse PG vector store to avoid recreating sqlalchemy
 engine (#95)

---
 .changeset/slow-rivers-breathe.md             |  5 ++
 .../components/vectordbs/python/none/index.py | 14 +++++-
 .../vectordbs/python/pg/vectordb.py           | 49 +++++++++++--------
 .../types/streaming/fastapi/pyproject.toml    |  1 +
 4 files changed, 46 insertions(+), 23 deletions(-)
 create mode 100644 .changeset/slow-rivers-breathe.md

diff --git a/.changeset/slow-rivers-breathe.md b/.changeset/slow-rivers-breathe.md
new file mode 100644
index 00000000..e65551ed
--- /dev/null
+++ b/.changeset/slow-rivers-breathe.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Fix postgres connection leaking issue
diff --git a/templates/components/vectordbs/python/none/index.py b/templates/components/vectordbs/python/none/index.py
index 7e9482c8..f7949d66 100644
--- a/templates/components/vectordbs/python/none/index.py
+++ b/templates/components/vectordbs/python/none/index.py
@@ -1,12 +1,22 @@
-import logging
 import os
+import logging
+from datetime import timedelta
 
+from cachetools import cached, TTLCache
 from llama_index.core.storage import StorageContext
 from llama_index.core.indices import load_index_from_storage
 
 logger = logging.getLogger("uvicorn")
 
 
+@cached(
+    TTLCache(maxsize=10, ttl=timedelta(minutes=5).total_seconds()),
+    key=lambda *args, **kwargs: "global_storage_context",
+)
+def get_storage_context(persist_dir: str) -> StorageContext:
+    return StorageContext.from_defaults(persist_dir=persist_dir)
+
+
 def get_index():
     storage_dir = os.getenv("STORAGE_DIR", "storage")
     # check if storage already exists
@@ -14,7 +24,7 @@ def get_index():
         return None
     # load the existing index
     logger.info(f"Loading index from {storage_dir}...")
-    storage_context = StorageContext.from_defaults(persist_dir=storage_dir)
+    storage_context = get_storage_context(storage_dir)
     index = load_index_from_storage(storage_context)
     logger.info(f"Finished loading index from {storage_dir}")
     return index
diff --git a/templates/components/vectordbs/python/pg/vectordb.py b/templates/components/vectordbs/python/pg/vectordb.py
index f7e0c11a..58e4c7ec 100644
--- a/templates/components/vectordbs/python/pg/vectordb.py
+++ b/templates/components/vectordbs/python/pg/vectordb.py
@@ -5,26 +5,33 @@ from urllib.parse import urlparse
 PGVECTOR_SCHEMA = "public"
 PGVECTOR_TABLE = "llamaindex_embedding"
 
+vector_store: PGVectorStore = None
+
 
 def get_vector_store():
-    original_conn_string = os.environ.get("PG_CONNECTION_STRING")
-    if original_conn_string is None or original_conn_string == "":
-        raise ValueError("PG_CONNECTION_STRING environment variable is not set.")
-
-    # The PGVectorStore requires both two connection strings, one for psycopg2 and one for asyncpg
-    # Update the configured scheme with the psycopg2 and asyncpg schemes
-    original_scheme = urlparse(original_conn_string).scheme + "://"
-    conn_string = original_conn_string.replace(
-        original_scheme, "postgresql+psycopg2://"
-    )
-    async_conn_string = original_conn_string.replace(
-        original_scheme, "postgresql+asyncpg://"
-    )
-
-    return PGVectorStore(
-        connection_string=conn_string,
-        async_connection_string=async_conn_string,
-        schema_name=PGVECTOR_SCHEMA,
-        table_name=PGVECTOR_TABLE,
-        embed_dim=int(os.environ.get("EMBEDDING_DIM", 1024)),
-    )
+    global vector_store
+
+    if vector_store is None:
+        original_conn_string = os.environ.get("PG_CONNECTION_STRING")
+        if original_conn_string is None or original_conn_string == "":
+            raise ValueError("PG_CONNECTION_STRING environment variable is not set.")
+
+        # The PGVectorStore requires both two connection strings, one for psycopg2 and one for asyncpg
+        # Update the configured scheme with the psycopg2 and asyncpg schemes
+        original_scheme = urlparse(original_conn_string).scheme + "://"
+        conn_string = original_conn_string.replace(
+            original_scheme, "postgresql+psycopg2://"
+        )
+        async_conn_string = original_conn_string.replace(
+            original_scheme, "postgresql+asyncpg://"
+        )
+
+        vector_store = PGVectorStore(
+            connection_string=conn_string,
+            async_connection_string=async_conn_string,
+            schema_name=PGVECTOR_SCHEMA,
+            table_name=PGVECTOR_TABLE,
+            embed_dim=int(os.environ.get("EMBEDDING_DIM", 1024)),
+        )
+
+    return vector_store
diff --git a/templates/types/streaming/fastapi/pyproject.toml b/templates/types/streaming/fastapi/pyproject.toml
index 286b7228..33737fdc 100644
--- a/templates/types/streaming/fastapi/pyproject.toml
+++ b/templates/types/streaming/fastapi/pyproject.toml
@@ -16,6 +16,7 @@ python-dotenv = "^1.0.0"
 aiostream = "^0.5.2"
 llama-index = "0.10.28"
 llama-index-core = "0.10.28"
+cachetools = "^5.3.3"
 
 [build-system]
 requires = ["poetry-core"]
-- 
GitLab