From c34175fbdcfcc319590e93006f258d8cfbb3c080 Mon Sep 17 00:00:00 2001
From: "Huu Le (Lee)" <39040748+leehuwuj@users.noreply.github.com>
Date: Thu, 11 Jan 2024 10:03:12 +0700
Subject: [PATCH] Feat[cl]: Add postgresql vector store for fastapi (#318)

---
 .../vectordbs/python/pg/__init__.py           |  0
 .../vectordbs/python/pg/constants.py          |  5 +++
 .../components/vectordbs/python/pg/context.py | 14 +++++++
 .../vectordbs/python/pg/generate.py           | 38 ++++++++++++++++++
 .../components/vectordbs/python/pg/index.py   | 16 ++++++++
 .../components/vectordbs/python/pg/utils.py   | 23 +++++++++++
 templates/python.ts                           | 40 ++++++++++++++++---
 7 files changed, 131 insertions(+), 5 deletions(-)
 create mode 100644 templates/components/vectordbs/python/pg/__init__.py
 create mode 100644 templates/components/vectordbs/python/pg/constants.py
 create mode 100644 templates/components/vectordbs/python/pg/context.py
 create mode 100644 templates/components/vectordbs/python/pg/generate.py
 create mode 100644 templates/components/vectordbs/python/pg/index.py
 create mode 100644 templates/components/vectordbs/python/pg/utils.py

diff --git a/templates/components/vectordbs/python/pg/__init__.py b/templates/components/vectordbs/python/pg/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/templates/components/vectordbs/python/pg/constants.py b/templates/components/vectordbs/python/pg/constants.py
new file mode 100644
index 00000000..efc5105a
--- /dev/null
+++ b/templates/components/vectordbs/python/pg/constants.py
@@ -0,0 +1,5 @@
+DATA_DIR = "data"  # directory containing the documents to index
+CHUNK_SIZE = 1024
+CHUNK_OVERLAP = 20
+PGVECTOR_SCHEMA = "public"
+PGVECTOR_TABLE = "llamaindex_embedding"
\ No newline at end of file
diff --git a/templates/components/vectordbs/python/pg/context.py b/templates/components/vectordbs/python/pg/context.py
new file mode 100644
index 00000000..ceb8a50a
--- /dev/null
+++ b/templates/components/vectordbs/python/pg/context.py
@@ -0,0 +1,14 @@
+from llama_index import ServiceContext
+
+from app.context import create_base_context
+from app.engine.constants import CHUNK_SIZE, CHUNK_OVERLAP
+
+
+def create_service_context():
+    base = create_base_context()
+    return ServiceContext.from_defaults(
+        llm=base.llm,
+        embed_model=base.embed_model,
+        chunk_size=CHUNK_SIZE,
+        chunk_overlap=CHUNK_OVERLAP,
+    )
diff --git a/templates/components/vectordbs/python/pg/generate.py b/templates/components/vectordbs/python/pg/generate.py
new file mode 100644
index 00000000..ee07e7a4
--- /dev/null
+++ b/templates/components/vectordbs/python/pg/generate.py
@@ -0,0 +1,38 @@
+from dotenv import load_dotenv
+
+load_dotenv()
+import logging
+
+from app.engine.constants import DATA_DIR
+from app.engine.context import create_service_context
+from app.engine.utils import init_pg_vector_store_from_env
+
+from llama_index import (
+    SimpleDirectoryReader,
+    VectorStoreIndex,
+    StorageContext,
+)
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger()
+
+
+def generate_datasource(service_context):
+    logger.info("Creating new index")
+    # load the documents and create the index
+    documents = SimpleDirectoryReader(DATA_DIR).load_data()
+    store = init_pg_vector_store_from_env()
+    storage_context = StorageContext.from_defaults(vector_store=store)
+    VectorStoreIndex.from_documents(
+        documents,
+        service_context=service_context,
+        storage_context=storage_context,
+        show_progress=True,  # this will show you a progress bar as the embeddings are created
+    )
+    logger.info(
+        f"Successfully created embeddings in the PG vector store, schema={store.schema_name} table={store.table_name}"
+    )
+
+
+if __name__ == "__main__":
+    generate_datasource(create_service_context())
diff --git a/templates/components/vectordbs/python/pg/index.py b/templates/components/vectordbs/python/pg/index.py
new file mode 100644
index 00000000..5c902772
--- /dev/null
+++ b/templates/components/vectordbs/python/pg/index.py
@@ -0,0 +1,16 @@
+import logging
+from llama_index import (
+    VectorStoreIndex,
+)
+from app.engine.context import create_service_context
+from app.engine.utils import init_pg_vector_store_from_env
+
+
+def get_chat_engine():
+    service_context = create_service_context()
+    logger = logging.getLogger("uvicorn")
+    logger.info("Connecting to index from PGVector...")
+    store = init_pg_vector_store_from_env()
+    index = VectorStoreIndex.from_vector_store(store, service_context)
+    logger.info("Finished connecting to index from PGVector.")
+    return index.as_chat_engine(similarity_top_k=5)
diff --git a/templates/components/vectordbs/python/pg/utils.py b/templates/components/vectordbs/python/pg/utils.py
new file mode 100644
index 00000000..4453d13e
--- /dev/null
+++ b/templates/components/vectordbs/python/pg/utils.py
@@ -0,0 +1,23 @@
+import os
+from llama_index.vector_stores import PGVectorStore
+from urllib.parse import urlparse
+from app.engine.constants import PGVECTOR_SCHEMA, PGVECTOR_TABLE
+
+
+def init_pg_vector_store_from_env():
+    original_conn_string = os.environ.get("PG_CONNECTION_STRING")
+    if original_conn_string is None:
+        raise ValueError("PG_CONNECTION_STRING environment variable is not set.")
+
+    # The PGVectorStore requires both two connection strings, one for psycopg2 and one for asyncpg
+    # Update the configured scheme with the psycopg2 and asyncpg schemes
+    original_scheme = urlparse(original_conn_string).scheme + "://"
+    conn_string = original_conn_string.replace(original_scheme, "postgresql+psycopg2://")
+    async_conn_string = original_conn_string.replace(original_scheme, "postgresql+asyncpg://")
+
+    return PGVectorStore(
+        connection_string=conn_string,
+        async_connection_string=async_conn_string,
+        schema_name=PGVECTOR_SCHEMA,
+        table_name=PGVECTOR_TABLE
+    )
diff --git a/templates/python.ts b/templates/python.ts
index b563b2ba..9dba1876 100644
--- a/templates/python.ts
+++ b/templates/python.ts
@@ -7,7 +7,8 @@ import { InstallTemplateArgs, TemplateVectorDB } from "./types";
 
 interface Dependency {
   name: string;
-  version: string;
+  version?: string;
+  extras?: string[];
 }
 
 const getAdditionalDependencies = (vectorDb?: TemplateVectorDB) => {
@@ -21,12 +22,43 @@ const getAdditionalDependencies = (vectorDb?: TemplateVectorDB) => {
       });
       break;
     }
+    case "pg": {
+      dependencies.push({
+        name: "llama-index",
+        extras: ["postgres"],
+      });
+    }
   }
 
   return dependencies;
 };
 
-const addDependencies = async (
+const mergePoetryDependencies = (
+  dependencies: Dependency[],
+  existingDependencies: any,
+) => {
+  for (const dependency of dependencies) {
+    let value = existingDependencies[dependency.name] ?? {};
+
+    // default string value is equal to attribute "version"
+    if (typeof value === "string") {
+      value = { version: value };
+    }
+
+    value.version = dependency.version ?? value.version;
+    value.extras = dependency.extras ?? value.extras;
+
+    if (value.version === undefined) {
+      throw new Error(
+        `Dependency "${dependency.name}" is missing attribute "version"!`,
+      );
+    }
+
+    existingDependencies[dependency.name] = value;
+  }
+};
+
+export const addDependencies = async (
   projectDir: string,
   dependencies: Dependency[],
 ) => {
@@ -42,9 +74,7 @@ const addDependencies = async (
     // Modify toml dependencies
     const tool = fileParsed.tool as any;
     const existingDependencies = tool.poetry.dependencies as any;
-    for (const dependency of dependencies) {
-      existingDependencies[dependency.name] = dependency.version;
-    }
+    mergePoetryDependencies(dependencies, existingDependencies);
 
     // Write toml file
     const newFileContent = stringify(fileParsed);
-- 
GitLab