From e39be29c8e07e96858a652dcda337faf59209bc8 Mon Sep 17 00:00:00 2001
From: Hao Chen <chenh1987@gmail.com>
Date: Tue, 6 Feb 2024 11:07:15 -0800
Subject: [PATCH] Add persist and load method for Colbert Index (#10477)

* Add persist and load method for Colbert Index

* Remove notebook output
---
 docs/examples/managed/ColBERT_demo.ipynb      | 217 ++++++++++++++++++
 .../indices/managed/colbert_index/base.py     |  35 +++
 2 files changed, 252 insertions(+)
 create mode 100644 docs/examples/managed/ColBERT_demo.ipynb

diff --git a/docs/examples/managed/ColBERT_demo.ipynb b/docs/examples/managed/ColBERT_demo.ipynb
new file mode 100644
index 0000000000..0ebbd770be
--- /dev/null
+++ b/docs/examples/managed/ColBERT_demo.ipynb
@@ -0,0 +1,217 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "5c82c41f",
+   "metadata": {},
+   "source": [
+    "## ColBERT-V2 Managed Index and Retrieval"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a6a077d7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!git -C ColBERT/ pull || git clone https://github.com/stanford-futuredata/ColBERT.git\n",
+    "import sys\n",
+    "\n",
+    "sys.path.insert(0, \"ColBERT/\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "875c98ee",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install faiss-cpu"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fea6ae14",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install torch==2.1.2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b4b135c4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llama_index import SimpleDirectoryReader, ServiceContext\n",
+    "from llama_index.indices import ColbertIndex\n",
+    "from llama_index.llms import OpenAI"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7ae72cb6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "OPENAI_API_TOKEN = \"sk-\"\n",
+    "os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_TOKEN"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8d7dc1b0",
+   "metadata": {},
+   "source": [
+    "### Build ColBERT-V2 end-to-end Index"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8bc7b689",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm = OpenAI(temperature=0, model=\"gpt-3.5-turbo\")\n",
+    "service_context = ServiceContext.from_defaults(llm=llm)\n",
+    "\n",
+    "documents = SimpleDirectoryReader(\"../data/paul_graham/\").load_data()\n",
+    "index = ColbertIndex.from_documents(\n",
+    "    documents=documents, service_context=service_context\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "fd87d213",
+   "metadata": {},
+   "source": [
+    "### Persist ColBERT Index"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ffbf63f4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "index.persist(persist_dir=\"paul_graham_colbert_index\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e3c337d7",
+   "metadata": {},
+   "source": [
+    "### Load ColBERT Index from disk"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "70800485",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "reload_index = ColbertIndex.load_from_disk(\n",
+    "    persist_dir=\"paul_graham_colbert_index\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "49a81996",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "The author attended the Accademia di Belli Arti in Florence.\n"
+     ]
+    }
+   ],
+   "source": [
+    "query_engine = reload_index.as_query_engine(similarity_top_k=3)\n",
+    "response = query_engine.query(\"Which program did this author attend?\")\n",
+    "print(response.response)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "34c51d3e",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Node ID: 6f824032-4065-44fb-a1fb-b48a59e4376b\n",
+      "Text: What I Worked On  February 2021  Before college the two main\n",
+      "things I worked on, outside of school, were writing and programming. I\n",
+      "didn't write essays. I wrote what beginning writers were supposed to\n",
+      "write then, and probably still are: short stories. My stories were\n",
+      "awful. They had hardly any plot, just characters with strong feelings,\n",
+      "which I ...\n",
+      "Score:  30.306\n",
+      "\n",
+      "Node ID: e6dcaf4c-7cc3-43e8-ba5d-96fd2eac85ed\n",
+      "Text: I didn't want to drop out of grad school, but how else was I\n",
+      "going to get out? I remember when my friend Robert Morris got kicked\n",
+      "out of Cornell for writing the internet worm of 1988, I was envious\n",
+      "that he'd found such a spectacular way to get out of grad school.\n",
+      "Then one day in April 1990 a crack appeared in the wall. I ran into\n",
+      "professor Chea...\n",
+      "Score:  26.641\n",
+      "\n",
+      "Node ID: 6aaa4637-79ee-45c0-999e-6cc1997c3fc7\n",
+      "Text: [15] We got 225 applications for the Summer Founders Program,\n",
+      "and we were surprised to find that a lot of them were from people\n",
+      "who'd already graduated, or were about to that spring. Already this\n",
+      "SFP thing was starting to feel more serious than we'd intended.  We\n",
+      "invited about 20 of the 225 groups to interview in person, and from\n",
+      "those we picked...\n",
+      "Score:  26.093\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "for node in response.source_nodes:\n",
+    "    print(node)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "llama-index-env",
+   "language": "python",
+   "name": "llama-index-env"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/llama_index/indices/managed/colbert_index/base.py b/llama_index/indices/managed/colbert_index/base.py
index bf678303ab..3e35dbbfd3 100644
--- a/llama_index/indices/managed/colbert_index/base.py
+++ b/llama_index/indices/managed/colbert_index/base.py
@@ -1,3 +1,6 @@
+import os
+import shutil
+from pathlib import Path
 from typing import Any, Dict, List, Optional, Sequence
 
 from llama_index.core.base_retriever import BaseRetriever
@@ -139,6 +142,38 @@ class ColbertIndex(BaseIndex[IndexDict]):
     #     for doc in docs:
     #         doc.score = math.exp(doc.score) / Z
 
+    def persist(self, persist_dir: str) -> None:
+        # Check if the destination directory exists
+        if os.path.exists(persist_dir):
+            # Remove the existing destination directory
+            shutil.rmtree(persist_dir)
+
+        # Copy PLAID vectors
+        shutil.copytree(
+            Path(self.index_path) / self.index_name, Path(persist_dir) / self.index_name
+        )
+        self._storage_context.persist(persist_dir=persist_dir)
+
+    @classmethod
+    def load_from_disk(cls, persist_dir: str, index_name: str = "") -> "ColbertIndex":
+        from colbert import Searcher
+        from colbert.infra import ColBERTConfig
+
+        colbert_config = ColBERTConfig.load_from_index(Path(persist_dir) / index_name)
+        searcher = Searcher(
+            index=index_name, index_root=persist_dir, config=colbert_config
+        )
+        sc = StorageContext.from_defaults(persist_dir=persist_dir)
+        colbert_index = ColbertIndex(
+            index_struct=sc.index_store.index_structs()[0], storage_context=sc
+        )
+        docs_pos_to_node_id = {
+            int(k): v for k, v in colbert_index.index_struct.nodes_dict.items()
+        }
+        colbert_index._docs_pos_to_node_id = docs_pos_to_node_id
+        colbert_index.store = searcher
+        return colbert_index
+
     def query(self, query_str: str, top_k: int = 10) -> List[NodeWithScore]:
         """
         Query the Colbert v2 + Plaid store.
-- 
GitLab