From d2800980534f1dd2a7811a98b9830d2778be1b54 Mon Sep 17 00:00:00 2001 From: Vits <vittorio.mayellaro.dev@gmail.com> Date: Mon, 5 Aug 2024 09:54:15 +0200 Subject: [PATCH] Added Dockerfile and postgres.compose.yaml for postgres container setup example. Adjusted notebook example in docs. Extensively modified Postgres Index and added docstrings. Removed psycopg and updated PostgreSQL and psycopg2 dependencies. --- docs/indexes/postgres/Dockerfile | 13 + docs/indexes/postgres/postgres.compose.yaml | 16 + docs/indexes/postgres/postgres.ipynb | 259 ++++++++++++---- poetry.lock | 126 ++------ pyproject.toml | 2 +- semantic_router/index/postgres.py | 313 +++++++++++++++++--- 6 files changed, 523 insertions(+), 206 deletions(-) create mode 100644 docs/indexes/postgres/Dockerfile diff --git a/docs/indexes/postgres/Dockerfile b/docs/indexes/postgres/Dockerfile new file mode 100644 index 00000000..25c756e1 --- /dev/null +++ b/docs/indexes/postgres/Dockerfile @@ -0,0 +1,13 @@ +FROM postgres:latest + +RUN apt-get update && \ + apt-get install -y build-essential postgresql-server-dev-all git && \ + git clone https://github.com/pgvector/pgvector.git && \ + cd pgvector && \ + make && \ + make install && \ + cd .. && \ + rm -rf pgvector && \ + apt-get remove -y build-essential postgresql-server-dev-all git && \ + apt-get autoremove -y && \ + apt-get clean diff --git a/docs/indexes/postgres/postgres.compose.yaml b/docs/indexes/postgres/postgres.compose.yaml index e69de29b..cc0bdb29 100644 --- a/docs/indexes/postgres/postgres.compose.yaml +++ b/docs/indexes/postgres/postgres.compose.yaml @@ -0,0 +1,16 @@ +version: '3.8' + +services: + pgvector: + build: .
+ environment: + POSTGRES_DB: semantic_router + POSTGRES_USER: admin + POSTGRES_PASSWORD: root + volumes: + - db_data:/var/lib/postgresql/data + ports: + - "5432:5432" + +volumes: + db_data: diff --git a/docs/indexes/postgres/postgres.ipynb b/docs/indexes/postgres/postgres.ipynb index c0de3ddf..db1e2663 100644 --- a/docs/indexes/postgres/postgres.ipynb +++ b/docs/indexes/postgres/postgres.ipynb @@ -4,56 +4,53 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Postgres pgvector index example\n", + "# Postgres pgvector Index Example\n", "\n", - "**Note**: You'll require docker to be installed locally, or a remote instance of Postgres with the pgvector extension installed." + "**Note**: You'll require Docker to be installed locally, or a remote instance of Postgres with the pgvector extension installed." ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Start the Postgres instance with the pgvector extension using Docker Compose\n", + "!echo \"Running Docker Compose to start Postgres instance with pgvector extension\"\n", + "!docker compose -f ./docs/indexes/postgres/postgres.compose.yaml up -d" + ] + }, + { + "cell_type": "code", + "execution_count": 1, "metadata": {}, "outputs": [ { - "name": "stdout", + "name": "stderr", "output_type": "stream", "text": [ - "/Users/frankjames/Projects/semantic-router/docs/indexes/postgres\n", - "\u001b[1A\u001b[1B\u001b[0G\u001b[?25l[+] Running 0/0\n", - " \u001b[33m⠋\u001b[0m Container pgvector Starting \u001b[34m0.1s \u001b[0m\n", - "\u001b[?25h\u001b[1A\u001b[1A\u001b[0G\u001b[?25l\u001b[34m[+] Running 1/1\u001b[0m\n", - " \u001b[32m✔\u001b[0m Container pgvector \u001b[32mStarted\u001b[0m \u001b[34m0.1s \u001b[0m\n", - "\u001b[?25h" + "/home/vittorio/.cache/pypoetry/virtualenvs/semantic-router-EZimjtOW-py3.10/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ - "!echo \"Running docker compose to start postgres instance with pgvector extension\"\n", - "!docker compose -f ./postgres.compose.yaml up --detach" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + "# Import necessary modules\n", "from semantic_router import Route\n", "\n", - "# we could use this as a guide for our chatbot to avoid political conversations\n", + "# Define routes to guide the chatbot's responses\n", "politics = Route(\n", " name=\"politics\",\n", " utterances=[\n", " \"isn't politics the best thing ever\",\n", " \"why don't you tell me about your political opinions\",\n", - " \"don't you just love the president\" \"don't you just hate the president\",\n", + " \"don't you just love the president\", \"don't you just hate the president\",\n", " \"they're going to destroy this country!\",\n", " \"they will save the country!\",\n", " ],\n", ")\n", "\n", - "# this could be used as an indicator to our chatbot to switch to a more\n", - "# conversational prompt\n", + "# Define a chitchat route for general conversations\n", "chitchat = Route(\n", " name=\"chitchat\",\n", " utterances=[\n", @@ -65,21 +62,23 @@ " ],\n", ")\n", "\n", - "# we place both of our decisions together into single list\n", + "# Combine both routes into a single list\n", "routes = [politics, chitchat]" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ + "# Import necessary modules\n", "import os\n", "from getpass import getpass\n", "from semantic_router.encoders import OpenAIEncoder\n", "\n", - "# get at platform.openai.com\n", + "# Set OpenAI API key for the encoder\n", + "# You can get your API key from platform.openai.com\n", "os.environ[\"OPENAI_API_KEY\"] = os.environ.get(\"OPENAI_API_KEY\") or getpass(\n", 
" \"Enter OpenAI API key: \"\n", ")\n", @@ -88,27 +87,32 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ + "# Import the Postgres index module\n", "from semantic_router.index.postgres import PostgresIndex\n", "import os\n", "\n", + "# Set Postgres connection string\n", "os.environ[\"POSTGRES_CONNECTION_STRING\"] = (\n", - " \"postgresql://user:password@localhost:5432/semantic_router\"\n", + " \"postgresql://admin:root@localhost:5432/semantic_router\"\n", ")\n", + "# Initialize the Postgres index\n", "postgres_index = PostgresIndex()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ + "# Import the RouteLayer class\n", "from semantic_router.layer import RouteLayer\n", "\n", + "# Initialize the RouteLayer with the encoder, routes, and index\n", "rl = RouteLayer(encoder=encoder, routes=routes, index=postgres_index)" ] }, @@ -116,24 +120,49 @@ "cell_type": "markdown", "metadata": {}, "source": [ + "## Check Route Layer and Index Information\n", "We can check our route layer and index information." 
] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "['politics', 'chitchat']" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ + "# List the names of the defined routes\n", "rl.list_route_names()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "10" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ + "# Check the total number of entries in the index\n", "len(rl.index)" ] }, @@ -141,15 +170,32 @@ "cell_type": "markdown", "metadata": {}, "source": [ + "## View All Records for a Given Route\n", "We can also view all of the records for a given route:" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "['politics#af3cd8c8-defb-5940-b66a-d59e8dbbada7',\n", + " 'politics#ce06418e-133f-5484-83fb-d8d3f136fc5d',\n", + " 'politics#c24f2d13-c2d5-5ddf-9169-03646cc3dad0',\n", + " 'politics#5f21f564-3d1e-58d9-a0b5-9e0c79e33d72',\n", + " 'politics#f248c2fa-4ab0-5bf9-8678-b25f7945705a']" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ + "# Get all records for the 'politics' route\n", "rl.index._get_route_ids(route_name=\"politics\")" ] }, @@ -157,33 +203,70 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "And query:" + "## Query the Routes\n", + "We can query the routes to get the appropriate responses." 
] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "RouteChoice(name='politics', function_call=None, similarity_score=None)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "rl(\"I like voting\")" + "# Query the route layer with a statement related to politics\n", + "rl(\"I like voting. What do you think about the president?\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "'chitchat'" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ + "# Query the route layer with a chitchat statement\n", "rl(\"how's the weather today?\").name" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "'chitchat'" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ + "# Query the route layer with another chitchat statement\n", "rl(\"where are you?\").name" ] }, @@ -191,24 +274,49 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We can delete or update routes." + "## Delete or Update Routes\n", + "We can delete or update routes as needed." 
] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "10" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ + "# Check the total number of entries in the index before deletion\n", "len(rl.index)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "5" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ + "# Delete the 'chitchat' route and check the index length after deletion\n", "import time\n", "\n", "rl.delete(route_name=\"chitchat\")\n", @@ -218,28 +326,65 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "None\n" + ] + } + ], "source": [ - "rl(\"how's the weather today?\").name" + "# Attempt to query the deleted 'chitchat' route\n", + "print(rl(\"how's the weather today?\").name)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "[('politics', \"isn't politics the best thing ever\"),\n", + " ('politics', \"why don't you tell me about your political opinions\"),\n", + " ('politics', \"don't you just love the president\"),\n", + " ('politics', \"they're going to destroy this country!\"),\n", + " ('politics', 'they will save the country!')]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ + "# Get all the current routes and their utterances\n", "rl.index.get_routes()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + 
"text/plain": [ + "{'type': 'postgres', 'dimensions': 1536, 'total_vector_count': 5}" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ + "# Describe the index to get details like type, dimensions, and total vector count\n", "rl.index.describe()" ] } @@ -260,7 +405,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/poetry.lock b/poetry.lock index 2d116d98..81d55634 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "aiohttp" @@ -2771,6 +2771,7 @@ description = "Nvidia JIT LTO Library" optional = true python-versions = ">=3" files = [ + {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_aarch64.whl", hash = "sha256:004186d5ea6a57758fd6d57052a123c73a4815adf365eb8dd6a85c9eaa7535ff"}, {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d9714f27c1d0f0895cd8915c07a87a1d0029a0aa36acaf9156952ec2a8a12189"}, {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:c3401dc8543b52d3a8158007a0c1ab4e9c768fcbd24153a48c86972102197ddd"}, ] @@ -3281,99 +3282,25 @@ files = [ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] [[package]] -name = "psycopg" -version = "3.1.19" -description = "PostgreSQL database adapter for Python" -optional = true -python-versions = ">=3.7" -files = [ - {file = "psycopg-3.1.19-py3-none-any.whl", hash = "sha256:dca5e5521c859f6606686432ae1c94e8766d29cc91f2ee595378c510cc5b0731"}, - {file = "psycopg-3.1.19.tar.gz", hash = "sha256:92d7b78ad82426cdcf1a0440678209faa890c6e1721361c2f8901f0dccd62961"}, -] - -[package.dependencies] -psycopg-binary = {version = "3.1.19", optional = true, markers = 
"implementation_name != \"pypy\" and extra == \"binary\""} -typing-extensions = ">=4.1" -tzdata = {version = "*", markers = "sys_platform == \"win32\""} - -[package.extras] -binary = ["psycopg-binary (==3.1.19)"] -c = ["psycopg-c (==3.1.19)"] -dev = ["black (>=24.1.0)", "codespell (>=2.2)", "dnspython (>=2.1)", "flake8 (>=4.0)", "mypy (>=1.4.1)", "types-setuptools (>=57.4)", "wheel (>=0.37)"] -docs = ["Sphinx (>=5.0)", "furo (==2022.6.21)", "sphinx-autobuild (>=2021.3.14)", "sphinx-autodoc-typehints (>=1.12)"] -pool = ["psycopg-pool"] -test = ["anyio (>=3.6.2,<4.0)", "mypy (>=1.4.1)", "pproxy (>=2.7)", "pytest (>=6.2.5)", "pytest-cov (>=3.0)", "pytest-randomly (>=3.5)"] - -[[package]] -name = "psycopg-binary" -version = "3.1.19" -description = "PostgreSQL database adapter for Python -- C optimisation distribution" -optional = true +name = "psycopg2" +version = "2.9.9" +description = "psycopg2 - Python-PostgreSQL Database Adapter" +optional = false python-versions = ">=3.7" files = [ - {file = "psycopg_binary-3.1.19-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7204818f05151dd08f8f851defb01972ec9d2cc925608eb0de232563f203f354"}, - {file = "psycopg_binary-3.1.19-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6d4e67fd86758dbeac85641419a54f84d74495a8683b58ad5dfad08b7fc37a8f"}, - {file = "psycopg_binary-3.1.19-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e12173e34b176e93ad2da913de30f774d5119c2d4d4640c6858d2d77dfa6c9bf"}, - {file = "psycopg_binary-3.1.19-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:052f5193304066318853b4b2e248f523c8f52b371fc4e95d4ef63baee3f30955"}, - {file = "psycopg_binary-3.1.19-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:29008f3f8977f600b8a7fb07c2e041b01645b08121760609cc45e861a0364dc9"}, - {file = "psycopg_binary-3.1.19-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:7c6a9a651a08d876303ed059c9553df18b3c13c3406584a70a8f37f1a1fe2709"}, - {file = "psycopg_binary-3.1.19-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:91a645e6468c4f064b7f4f3b81074bdd68fe5aa2b8c5107de15dcd85ba6141be"}, - {file = "psycopg_binary-3.1.19-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:5c6956808fd5cf0576de5a602243af8e04594b25b9a28675feddc71c5526410a"}, - {file = "psycopg_binary-3.1.19-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:1622ca27d5a7a98f7d8f35e8b146dc7efda4a4b6241d2edf7e076bd6bcecbeb4"}, - {file = "psycopg_binary-3.1.19-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a100482950a55228f648bd382bb71bfaff520002f29845274fccbbf02e28bd52"}, - {file = "psycopg_binary-3.1.19-cp310-cp310-win_amd64.whl", hash = "sha256:955ca8905c0251fc4af7ce0a20999e824a25652f53a558ab548b60969f1f368e"}, - {file = "psycopg_binary-3.1.19-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1cf49e91dcf699b8a449944ed898ef1466b39b92720613838791a551bc8f587a"}, - {file = "psycopg_binary-3.1.19-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:964c307e400c5f33fa762ba1e19853e048814fcfbd9679cc923431adb7a2ead2"}, - {file = "psycopg_binary-3.1.19-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3433924e1b14074798331dc2bfae2af452ed7888067f2fc145835704d8981b15"}, - {file = "psycopg_binary-3.1.19-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:00879d4c6be4b3afc510073f48a5e960f797200e261ab3d9bd9b7746a08c669d"}, - {file = "psycopg_binary-3.1.19-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:34a6997c80f86d3dd80a4f078bb3b200079c47eeda4fd409d8899b883c90d2ac"}, - {file = "psycopg_binary-3.1.19-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0106e42b481677c41caa69474fe530f786dcef88b11b70000f0e45a03534bc8f"}, - {file = "psycopg_binary-3.1.19-cp311-cp311-musllinux_1_1_aarch64.whl", hash = 
"sha256:81efe09ba27533e35709905c3061db4dc9fb814f637360578d065e2061fbb116"}, - {file = "psycopg_binary-3.1.19-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d312d6dddc18d9c164e1893706269c293cba1923118349d375962b1188dafb01"}, - {file = "psycopg_binary-3.1.19-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:bfd2c734da9950f7afaad5f132088e0e1478f32f042881fca6651bb0c8d14206"}, - {file = "psycopg_binary-3.1.19-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8a732610a5a6b4f06dadcf9288688a8ff202fd556d971436a123b7adb85596e2"}, - {file = "psycopg_binary-3.1.19-cp311-cp311-win_amd64.whl", hash = "sha256:321814a9a3ad785855a821b842aba08ca1b7de7dfb2979a2f0492dca9ec4ae70"}, - {file = "psycopg_binary-3.1.19-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:4aa0ca13bb8a725bb6d12c13999217fd5bc8b86a12589f28a74b93e076fbb959"}, - {file = "psycopg_binary-3.1.19-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:469424e354ebcec949aa6aa30e5a9edc352a899d9a68ad7a48f97df83cc914cf"}, - {file = "psycopg_binary-3.1.19-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b04f5349313529ae1f1c42fe1aa0443faaf50fdf12d13866c2cc49683bfa53d0"}, - {file = "psycopg_binary-3.1.19-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:959feabddc7fffac89b054d6f23f3b3c62d7d3c90cd414a02e3747495597f150"}, - {file = "psycopg_binary-3.1.19-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e9da624a6ca4bc5f7fa1f03f8485446b5b81d5787b6beea2b4f8d9dbef878ad7"}, - {file = "psycopg_binary-3.1.19-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1823221a6b96e38b15686170d4fc5b36073efcb87cce7d3da660440b50077f6"}, - {file = "psycopg_binary-3.1.19-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:866db42f986298f0cf15d805225eb8df2228bf19f7997d7f1cb5f388cbfc6a0f"}, - {file = "psycopg_binary-3.1.19-cp312-cp312-musllinux_1_1_i686.whl", hash = 
"sha256:738c34657305b5973af6dbb6711b07b179dfdd21196d60039ca30a74bafe9648"}, - {file = "psycopg_binary-3.1.19-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:fb9758473200384a04374d0e0cac6f451218ff6945a024f65a1526802c34e56e"}, - {file = "psycopg_binary-3.1.19-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:0e991632777e217953ac960726158987da684086dd813ac85038c595e7382c91"}, - {file = "psycopg_binary-3.1.19-cp312-cp312-win_amd64.whl", hash = "sha256:1d87484dd42c8783c44a30400949efb3d81ef2487eaa7d64d1c54df90cf8b97a"}, - {file = "psycopg_binary-3.1.19-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:d1d1723d7449c12bb61aca7eb6e0c6ab2863cd8dc0019273cc4d4a1982f84bdb"}, - {file = "psycopg_binary-3.1.19-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e538a8671005641fa195eab962f85cf0504defbd3b548c4c8fc27102a59f687b"}, - {file = "psycopg_binary-3.1.19-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c50592bc8517092f40979e4a5d934f96a1737a77724bb1d121eb78b614b30fc8"}, - {file = "psycopg_binary-3.1.19-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:95f16ae82bc242b76cd3c3e5156441e2bd85ff9ec3a9869d750aad443e46073c"}, - {file = "psycopg_binary-3.1.19-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aebd1e98e865e9a28ce0cb2c25b7dfd752f0d1f0a423165b55cd32a431dcc0f4"}, - {file = "psycopg_binary-3.1.19-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:49cd7af7d49e438a39593d1dd8cab106a1912536c2b78a4d814ebdff2786094e"}, - {file = "psycopg_binary-3.1.19-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:affebd61aa3b7a8880fd4ac3ee94722940125ff83ff485e1a7c76be9adaabb38"}, - {file = "psycopg_binary-3.1.19-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:d1bac282f140fa092f2bbb6c36ed82270b4a21a6fc55d4b16748ed9f55e50fdb"}, - {file = "psycopg_binary-3.1.19-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = 
"sha256:1285aa54449e362b1d30d92b2dc042ad3ee80f479cc4e323448d0a0a8a1641fa"}, - {file = "psycopg_binary-3.1.19-cp37-cp37m-win_amd64.whl", hash = "sha256:6cff31af8155dc9ee364098a328bab688c887c732c66b8d027e5b03818ca0287"}, - {file = "psycopg_binary-3.1.19-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d9b689c4a17dd3130791dcbb8c30dbf05602f7c2d56c792e193fb49adc7bf5f8"}, - {file = "psycopg_binary-3.1.19-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:017518bd2de4851adc826a224fb105411e148ad845e11355edd6786ba3dfedf5"}, - {file = "psycopg_binary-3.1.19-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c35fd811f339a3cbe7f9b54b2d9a5e592e57426c6cc1051632a62c59c4810208"}, - {file = "psycopg_binary-3.1.19-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:38ed45ec9673709bfa5bc17f140e71dd4cca56d4e58ef7fd50d5a5043a4f55c6"}, - {file = "psycopg_binary-3.1.19-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:433f1c256108f9e26f480a8cd6ddb0fb37dbc87d7f5a97e4540a9da9b881f23f"}, - {file = "psycopg_binary-3.1.19-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ed61e43bf5dc8d0936daf03a19fef3168d64191dbe66483f7ad08c4cea0bc36b"}, - {file = "psycopg_binary-3.1.19-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:4ae8109ff9fdf1fa0cb87ab6645298693fdd2666a7f5f85660df88f6965e0bb7"}, - {file = "psycopg_binary-3.1.19-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:a53809ee02e3952fae7977c19b30fd828bd117b8f5edf17a3a94212feb57faaf"}, - {file = "psycopg_binary-3.1.19-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9d39d5ffc151fb33bcd55b99b0e8957299c0b1b3e5a1a5f4399c1287ef0051a9"}, - {file = "psycopg_binary-3.1.19-cp38-cp38-win_amd64.whl", hash = "sha256:e14bc8250000921fcccd53722f86b3b3d1b57db901e206e49e2ab2afc5919c2d"}, - {file = "psycopg_binary-3.1.19-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cd88c5cea4efe614d5004fb5f5dcdea3d7d59422be796689e779e03363102d24"}, - {file = 
"psycopg_binary-3.1.19-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:621a814e60825162d38760c66351b4df679fd422c848b7c2f86ad399bff27145"}, - {file = "psycopg_binary-3.1.19-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:46e50c05952b59a214e27d3606f6d510aaa429daed898e16b8a37bfbacc81acc"}, - {file = "psycopg_binary-3.1.19-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:03354a9db667c27946e70162cb0042c3929154167f3678a30d23cebfe0ad55b5"}, - {file = "psycopg_binary-3.1.19-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:703c2f3b79037581afec7baa2bdbcb0a1787f1758744a7662099b0eca2d721cb"}, - {file = "psycopg_binary-3.1.19-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:6469ebd9e93327e9f5f36dcf8692fb1e7aeaf70087c1c15d4f2c020e0be3a891"}, - {file = "psycopg_binary-3.1.19-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:85bca9765c04b6be90cb46e7566ffe0faa2d7480ff5c8d5e055ac427f039fd24"}, - {file = "psycopg_binary-3.1.19-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:a836610d5c75e9cff98b9fdb3559c007c785c09eaa84a60d5d10ef6f85f671e8"}, - {file = "psycopg_binary-3.1.19-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ef8de7a1d9fb3518cc6b58e3c80b75a824209ad52b90c542686c912db8553dad"}, - {file = "psycopg_binary-3.1.19-cp39-cp39-win_amd64.whl", hash = "sha256:76fcd33342f38e35cd6b5408f1bc117d55ab8b16e5019d99b6d3ce0356c51717"}, + {file = "psycopg2-2.9.9-cp310-cp310-win32.whl", hash = "sha256:38a8dcc6856f569068b47de286b472b7c473ac7977243593a288ebce0dc89516"}, + {file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"}, + {file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"}, + {file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"}, + {file = 
"psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"}, + {file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"}, + {file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"}, + {file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"}, + {file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"}, + {file = "psycopg2-2.9.9-cp38-cp38-win_amd64.whl", hash = "sha256:bac58c024c9922c23550af2a581998624d6e02350f4ae9c5f0bc642c633a2d5e"}, + {file = "psycopg2-2.9.9-cp39-cp39-win32.whl", hash = "sha256:c92811b2d4c9b6ea0285942b2e7cac98a59e166d59c588fe5cfe1eda58e72d59"}, + {file = "psycopg2-2.9.9-cp39-cp39-win_amd64.whl", hash = "sha256:de80739447af31525feddeb8effd640782cf5998e1a4e9192ebdf829717e3913"}, + {file = "psycopg2-2.9.9.tar.gz", hash = "sha256:d1454bde93fb1e224166811694d600e746430c006fbb031ea06ecc2ea41bf156"}, ] [[package]] @@ -4849,17 +4776,6 @@ files = [ {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] -[[package]] -name = "tzdata" -version = "2024.1" -description = "Provider of IANA time zone data" -optional = true -python-versions = ">=2" -files = [ - {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"}, - {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, -] - [[package]] name = "urllib3" version = "1.26.19" @@ -5054,7 +4970,7 @@ hybrid = ["pinecone-text"] local = ["llama-cpp-python", "tokenizers", "torch", "transformers"] mistralai = ["mistralai"] pinecone = ["pinecone-client"] -postgres = 
["psycopg"] +postgres = [] processing = ["matplotlib"] qdrant = ["qdrant-client"] vision = ["pillow", "torch", "torchvision", "transformers"] @@ -5062,4 +4978,4 @@ vision = ["pillow", "torch", "torchvision", "transformers"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "02d8883fa539b8fd95e6e5c4d6bfba4e3b84509fb96cd91b796db40397046a20" +content-hash = "078a0e297649999123d74c1ee0e4fee597471fdf5734da4b51307e2529521dd7" diff --git a/pyproject.toml b/pyproject.toml index 3c3328b0..2246805c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,9 +41,9 @@ google-cloud-aiplatform = {version = "^1.45.0", optional = true} requests-mock = "^1.12.1" boto3 = { version = "^1.34.98", optional = true } botocore = {version = "^1.34.110", optional = true} -psycopg = { version = "^3.1.19", optional = true, extras = ["binary"] } aiohttp = "^3.9.5" fastembed = {version = "^0.3.0", optional = true} +psycopg2 = "^2.9.9" [tool.poetry.extras] hybrid = ["pinecone-text"] diff --git a/semantic_router/index/postgres.py b/semantic_router/index/postgres.py index 710fd07c..13c67ffd 100644 --- a/semantic_router/index/postgres.py +++ b/semantic_router/index/postgres.py @@ -1,16 +1,21 @@ -from semantic_router.index.base import BaseIndex -import psycopg -from psycopg.connection import Connection -from pydantic import BaseModel -from typing import Any, List, Optional, Tuple, Dict, Union -from enum import Enum -from semantic_router.schema import Metric -import numpy as np import os import uuid +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import psycopg2 +from pydantic import BaseModel + +from semantic_router.index.base import BaseIndex +from semantic_router.schema import Metric class MetricPgVecOperatorMap(Enum): + """ + Enum to map the metric to PostgreSQL vector operators. 
+ """ + cosine = "<=>" dotproduct = "<#>" # inner product euclidean = "<->" # L2 distance @@ -18,28 +23,67 @@ class MetricPgVecOperatorMap(Enum): def parse_vector(vector_str: Union[str, Any]) -> List[float]: + """ + Parses a vector from a string or other representation. + + :param vector_str: The string or object representation of a vector. + :type vector_str: Union[str, Any] + :return: A list of floats representing the vector. + :rtype: List[float] + """ if isinstance(vector_str, str): - vector_str = str(vector_str) vector_str = vector_str.strip('()"[]') return list(map(float, vector_str.split(","))) else: return vector_str +def clean_route_name(route_name: str) -> str: + """ + Cleans and formats the route name by stripping spaces and replacing them with hyphens. + + :param route_name: The original route name. + :type route_name: str + :return: The cleaned and formatted route name. + :rtype: str + """ + return route_name.strip().replace(" ", "-") + + class PostgresIndexRecord(BaseModel): + """ + Model to represent a record in the Postgres index. + """ + id: str = "" route: str utterance: str vector: List[float] def __init__(self, **data) -> None: + """ + Initializes a new Postgres index record with given data. + + :param data: Field values for the record. + :type data: dict + """ super().__init__(**data) clean_route = self.route.strip().replace(" ", "-") + if len(clean_route) > 255: + raise ValueError( + f"The cleaned route name '{clean_route}' exceeds the 255 character limit." + ) route_namespace_uuid = uuid.uuid5(uuid.NAMESPACE_DNS, clean_route) hashed_uuid = uuid.uuid5(route_namespace_uuid, self.utterance) - self.id = str(hashed_uuid) + self.id = clean_route + "#" + str(hashed_uuid) def to_dict(self) -> Dict: + """ + Converts the record to a dictionary. + + :return: A dictionary representation of the record. 
+ :rtype: Dict + """ return { "id": self.id, "vector": self.vector, @@ -50,7 +94,7 @@ class PostgresIndexRecord(BaseModel): class PostgresIndex(BaseIndex): """ - Postgres implementation of Index + Postgres implementation of Index. """ connection_string: Optional[str] = None @@ -59,7 +103,7 @@ class PostgresIndex(BaseIndex): dimensions: int = 1536 metric: Metric = Metric.COSINE namespace: Optional[str] = "" - conn: Optional[Connection] = None + conn: Optional[psycopg2.extensions.connection] = None type: str = "postgres" def __init__( @@ -71,95 +115,136 @@ class PostgresIndex(BaseIndex): metric: Metric = Metric.COSINE, namespace: Optional[str] = "", ): + """ + Initializes the Postgres index with the specified parameters. + + :param connection_string: The connection string for the PostgreSQL database. + :type connection_string: Optional[str] + :param index_prefix: The prefix for the index table name. + :type index_prefix: str + :param index_name: The name of the index table. + :type index_name: str + :param dimensions: The number of dimensions for the vectors. + :type dimensions: int + :param metric: The metric used for vector comparisons. + :type metric: Metric + :param namespace: An optional namespace for the index. 
+ :type namespace: Optional[str] + """ super().__init__() if connection_string: self.connection_string = connection_string else: - connection_string = os.environ["POSTGRES_CONNECTION_STRING"] + connection_string = os.environ.get("POSTGRES_CONNECTION_STRING") if not connection_string: raise ValueError("No connection string provided") - else: - self.connection_string = str(connection_string) + self.connection_string = connection_string self.index_prefix = index_prefix self.index_name = index_name self.dimensions = dimensions self.metric = metric self.namespace = namespace - self.conn = psycopg.connect(conninfo=self.connection_string) + self.conn = psycopg2.connect(dsn=self.connection_string) self.setup_index() def _get_table_name(self) -> str: + """ + Returns the name of the table for the index. + + :return: The table name. + :rtype: str + """ return f"{self.index_prefix}{self.index_name}" def _get_metric_operator(self) -> str: + """ + Returns the PostgreSQL operator for the specified metric. + + :return: The PostgreSQL operator. + :rtype: str + """ return MetricPgVecOperatorMap[self.metric.value].value def _get_score_query(self, embeddings_str: str) -> str: """ Creates the select statement required to return the embeddings distance. + + :param embeddings_str: The string representation of the embeddings. + :type embeddings_str: str + :return: The SQL query part for scoring. 
+ :rtype: str """ - opperator = self._get_metric_operator() + operator = self._get_metric_operator() if self.metric == Metric.COSINE: - return f"1 - (vector {opperator} {embeddings_str}) AS score" + return f"1 - (vector {operator} {embeddings_str}) AS score" elif self.metric == Metric.DOTPRODUCT: - return f"(vector {opperator} {embeddings_str}) * -1 AS score" + return f"(vector {operator} {embeddings_str}) * -1 AS score" elif self.metric == Metric.EUCLIDEAN: - return f"vector {opperator} {embeddings_str} AS score" + return f"vector {operator} {embeddings_str} AS score" elif self.metric == Metric.MANHATTAN: - return f"vector {opperator} {embeddings_str} AS score" + return f"vector {operator} {embeddings_str} AS score" else: raise ValueError(f"Unsupported metric: {self.metric}") def setup_index(self) -> None: + """ + Sets up the index by creating the table and vector extension if they do not exist. + + :raises ValueError: If the existing table's vector dimensions do not match the expected dimensions. + :raises TypeError: If the database connection is not established. + """ table_name = self._get_table_name() if not self._check_embeddings_dimensions(): raise ValueError( f"The length of the vector embeddings in the existing table {table_name} does not match the expected dimensions of {self.dimensions}." 
) - if not isinstance(self.conn, psycopg.Connection): + if not isinstance(self.conn, psycopg2.extensions.connection): raise TypeError("Index has not established a connection to Postgres") with self.conn.cursor() as cur: cur.execute( f""" CREATE EXTENSION IF NOT EXISTS vector; CREATE TABLE IF NOT EXISTS {table_name} ( - id uuid PRIMARY KEY, - route TEXT, + id VARCHAR(255) PRIMARY KEY, + route VARCHAR(255), utterance TEXT, vector VECTOR({self.dimensions}) ); COMMENT ON COLUMN {table_name}.vector IS '{self.dimensions}'; - """ + """ ) self.conn.commit() def _check_embeddings_dimensions(self) -> bool: """ - True where the length of the vector embeddings in the table matches the expected dimensions, or no table yet exists. + Checks if the length of the vector embeddings in the table matches the expected dimensions, or if no table exists. + + :return: True if the dimensions match or the table does not exist, False otherwise. + :rtype: bool + :raises ValueError: If the vector column comment does not contain a valid integer. 
""" table_name = self._get_table_name() - if not isinstance(self.conn, psycopg.Connection): + if not isinstance(self.conn, psycopg2.extensions.connection): raise TypeError("Index has not established a connection to Postgres") with self.conn.cursor() as cur: cur.execute( f"SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name='{table_name}');" ) fetch_result = cur.fetchone() - exists = fetch_result[0] if fetch_result is not None else None + exists = fetch_result[0] if fetch_result else None if not exists: return True cur.execute( - f"""SELECT col_description('{table_name}':: regclass, attnum) AS column_comment - FROM pg_attribute - WHERE attrelid = '{table_name}':: regclass - AND attname='vector'""" + f"""SELECT col_description('{table_name}'::regclass, attnum) AS column_comment + FROM pg_attribute + WHERE attrelid = '{table_name}'::regclass + AND attname='vector'""" ) result = cur.fetchone() dimension_comment = result[0] if result else None if dimension_comment: try: vector_length = int(dimension_comment.split()[-1]) - print(vector_length) return vector_length == self.dimensions except ValueError: raise ValueError( @@ -171,6 +256,18 @@ class PostgresIndex(BaseIndex): def add( self, embeddings: List[List[float]], routes: List[str], utterances: List[Any] ) -> None: + """ + Adds vectors to the index. + + :param embeddings: A list of vector embeddings to add. + :type embeddings: List[List[float]] + :param routes: A list of route names corresponding to the embeddings. + :type routes: List[str] + :param utterances: A list of utterances corresponding to the embeddings. + :type utterances: List[Any] + :raises ValueError: If the vector embeddings being added do not match the expected dimensions. + :raises TypeError: If the database connection is not established. 
+ """ table_name = self._get_table_name() new_embeddings_length = len(embeddings[0]) if new_embeddings_length != self.dimensions: @@ -181,7 +278,7 @@ class PostgresIndex(BaseIndex): PostgresIndexRecord(vector=vector, route=route, utterance=utterance) for vector, route, utterance in zip(embeddings, routes, utterances) ] - if not isinstance(self.conn, psycopg.Connection): + if not isinstance(self.conn, psycopg2.extensions.connection): raise TypeError("Index has not established a connection to Postgres") with self.conn.cursor() as cur: cur.executemany( @@ -194,22 +291,38 @@ class PostgresIndex(BaseIndex): self.conn.commit() def delete(self, route_name: str) -> None: + """ + Deletes records with the specified route name. + + :param route_name: The name of the route to delete records for. + :type route_name: str + :raises TypeError: If the database connection is not established. + """ table_name = self._get_table_name() - if not isinstance(self.conn, psycopg.Connection): + if not isinstance(self.conn, psycopg2.extensions.connection): raise TypeError("Index has not established a connection to Postgres") with self.conn.cursor() as cur: cur.execute(f"DELETE FROM {table_name} WHERE route = '{route_name}'") self.conn.commit() def describe(self) -> Dict: + """ + Describes the index by returning its type, dimensions, and total vector count. + + :return: A dictionary containing the index's type, dimensions, and total vector count. + :rtype: Dict + :raises TypeError: If the database connection is not established. 
+ """ table_name = self._get_table_name() - if not isinstance(self.conn, psycopg.Connection): + if not isinstance(self.conn, psycopg2.extensions.connection): raise TypeError("Index has not established a connection to Postgres") with self.conn.cursor() as cur: cur.execute(f"SELECT COUNT(*) FROM {table_name}") count = cur.fetchone() if count is None: count = 0 + else: + count = count[0] # Extract the actual count from the tuple return { "type": self.type, "dimensions": self.dimensions, @@ -223,33 +336,147 @@ class PostgresIndex(BaseIndex): route_filter: Optional[List[str]] = None, ) -> Tuple[np.ndarray, List[str]]: """ - Search the index for the query and return top_k results. + Searches the index for the query vector and returns the top_k results. + + :param vector: The query vector. + :type vector: np.ndarray + :param top_k: The number of top results to return. + :type top_k: int + :param route_filter: Optional list of routes to filter the results by. + :type route_filter: Optional[List[str]] + :return: A tuple containing the scores and routes of the top_k results. + :rtype: Tuple[np.ndarray, List[str]] + :raises TypeError: If the database connection is not established. 
""" table_name = self._get_table_name() - if not isinstance(self.conn, psycopg.Connection): + if not isinstance(self.conn, psycopg2.extensions.connection): raise TypeError("Index has not established a connection to Postgres") with self.conn.cursor() as cur: filter_query = f" AND route = ANY({route_filter})" if route_filter else "" - # create the string representation of vector + # Create the string representation of vector vector_str = f"'[{','.join(map(str, vector.tolist()))}]'" score_query = self._get_score_query(vector_str) - opperator = self._get_metric_operator() + operator = self._get_metric_operator() cur.execute( - f"SELECT route, {score_query} FROM {table_name} WHERE true{filter_query} ORDER BY vector {opperator} {vector_str} LIMIT {top_k}" + f"SELECT route, {score_query} FROM {table_name} WHERE true{filter_query} ORDER BY vector {operator} {vector_str} LIMIT {top_k}" ) results = cur.fetchall() - print(results) return np.array([result[1] for result in results]), [ result[0] for result in results ] + def _get_route_ids(self, route_name: str): + """ + Retrieves all vector IDs for a specific route. + + :param route_name: The name of the route to retrieve IDs for. + :type route_name: str + :return: A list of vector IDs. + :rtype: List[str] + """ + clean_route = clean_route_name(route_name) + ids, _ = self._get_all(route_name=f"{clean_route}") + return ids + + def _get_all( + self, route_name: Optional[str] = None, include_metadata: bool = False + ): + """ + Retrieves all vector IDs and optionally metadata from the Postgres index. + + :param route_name: Optional route name to filter the results by. + :type route_name: Optional[str] + :param include_metadata: Whether to include metadata in the results. + :type include_metadata: bool + :return: A tuple containing the list of vector IDs and optionally metadata. + :rtype: Tuple[List[str], List[Dict]] + :raises TypeError: If the database connection is not established. 
+ """ + table_name = self._get_table_name() + if not isinstance(self.conn, psycopg2.extensions.connection): + raise TypeError("Index has not established a connection to Postgres") + + query = "SELECT id" + if include_metadata: + query += ", route, utterance" + query += f" FROM {table_name}" + + if route_name: + query += f" WHERE route LIKE '{route_name}%'" + + all_vector_ids = [] + metadata = [] + + with self.conn.cursor() as cur: + cur.execute(query) + results = cur.fetchall() + for row in results: + all_vector_ids.append(row[0]) + if include_metadata: + metadata.append({"sr_route": row[1], "sr_utterance": row[2]}) + + return all_vector_ids, metadata + + def get_routes(self) -> List[Tuple]: + """ + Gets a list of route and utterance objects currently stored in the index. + + :return: A list of (route_name, utterance) tuples. + :rtype: List[Tuple] + """ + # Get all records with metadata + _, metadata = self._get_all(include_metadata=True) + # Create a list of (route_name, utterance) tuples + route_tuples = [(x["sr_route"], x["sr_utterance"]) for x in metadata] + return route_tuples + + def delete_all(self): + """ + Deletes all records from the Postgres index. + + :raises TypeError: If the database connection is not established. + """ + table_name = self._get_table_name() + if not isinstance(self.conn, psycopg2.extensions.connection): + raise TypeError("Index has not established a connection to Postgres") + with self.conn.cursor() as cur: + cur.execute(f"DELETE FROM {table_name}") + self.conn.commit() + def delete_index(self) -> None: + """ + Deletes the entire table for the index. + + :raises TypeError: If the database connection is not established. 
+ """ table_name = self._get_table_name() - if not isinstance(self.conn, psycopg.Connection): + if not isinstance(self.conn, psycopg2.extensions.connection): raise TypeError("Index has not established a connection to Postgres") with self.conn.cursor() as cur: cur.execute(f"DROP TABLE IF EXISTS {table_name}") self.conn.commit() + def __len__(self): + """ + Returns the total number of vectors in the index. + + :return: The total number of vectors. + :rtype: int + :raises TypeError: If the database connection is not established. + """ + table_name = self._get_table_name() + if not isinstance(self.conn, psycopg2.extensions.connection): + raise TypeError("Index has not established a connection to Postgres") + with self.conn.cursor() as cur: + cur.execute(f"SELECT COUNT(*) FROM {table_name}") + count = cur.fetchone() + if count is None: + return 0 + return count[0] + class Config: + """ + Configuration for the Pydantic BaseModel. + """ + arbitrary_types_allowed = True -- GitLab