From c11d013f088a52cfed45b83031553c13cadcada5 Mon Sep 17 00:00:00 2001
From: James Briggs <james.briggs@hotmail.com>
Date: Fri, 8 Nov 2024 18:14:58 +0100
Subject: [PATCH] feat: add fast sync check

---
 semantic_router/index/base.py     | 18 +++++++++++++++++
 semantic_router/index/pinecone.py | 33 +++++++++++++++++++++++++++++--
 semantic_router/layer.py          | 33 +++++++++++++++++++++++++++----
 semantic_router/schema.py         | 20 ++++++++++++++++++-
 4 files changed, 97 insertions(+), 7 deletions(-)

diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py
index 7add4cb4..27112021 100644
--- a/semantic_router/index/base.py
+++ b/semantic_router/index/base.py
@@ -3,6 +3,8 @@ from typing import Any, List, Optional, Tuple, Union, Dict
 import numpy as np
 from pydantic.v1 import BaseModel
 
+from semantic_router.schema import ConfigParameter
+
 
 class BaseIndex(BaseModel):
     """
@@ -145,5 +147,21 @@ class BaseIndex(BaseModel):
         """
         raise NotImplementedError("This method should be implemented by subclasses.")
 
+    def _read_hash(self) -> ConfigParameter:
+        """
+        Read the hash of the previously written index.
+
+        This method should be implemented by subclasses.
+        """
+        raise NotImplementedError("This method should be implemented by subclasses.")
+
+    def _write_config(self, config: ConfigParameter):
+        """
+        Write a config parameter to the index.
+
+        This method should be implemented by subclasses.
+        """
+        raise NotImplementedError("This method should be implemented by subclasses.")
+
     class Config:
         arbitrary_types_allowed = True
diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index 23c5a05a..846890fc 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -11,6 +11,7 @@ import numpy as np
 from pydantic.v1 import BaseModel, Field
 
 from semantic_router.index.base import BaseIndex
+from semantic_router.schema import ConfigParameter
 from semantic_router.utils.logger import logger
 
 
@@ -23,7 +24,7 @@ class PineconeRecord(BaseModel):
     values: List[float]
     route: str
     utterance: str
-    function_schema: str
+    function_schema: str = "{}"
     metadata: Dict[str, Any] = {}  # Additional metadata dictionary
 
     def __init__(self, **data):
@@ -75,7 +76,7 @@ class PineconeIndex(BaseIndex):
         host: str = "",
         namespace: Optional[str] = "",
         base_url: Optional[str] = "https://api.pinecone.io",
-        sync: str = "local",
+        sync: Optional[str] = None,
         init_async_index: bool = False,
     ):
         super().__init__()
@@ -85,6 +86,8 @@ class PineconeIndex(BaseIndex):
         self.cloud = cloud
         self.region = region
         self.host = host
+        if namespace == "sr_config":
+            raise ValueError("Namespace 'sr_config' is reserved for internal use.")
         self.namespace = namespace
         self.type = "pinecone"
         self.api_key = api_key or os.getenv("PINECONE_API_KEY")
@@ -674,6 +677,32 @@ class PineconeIndex(BaseIndex):
         route_names = [result["metadata"]["sr_route"] for result in results["matches"]]
         return np.array(scores), route_names
 
+    def _read_hash(self) -> ConfigParameter:
+        hash_record = self.index.fetch(
+            ids=[f"sr_hash#{self.namespace}"],
+            namespace="sr_config",
+        )
+        if hash_record["vectors"]:
+            return ConfigParameter(
+                field="sr_hash",
+                value=hash_record["vectors"]["sr_hash"]["metadata"]["value"],
+                created_at=hash_record["vectors"]["sr_hash"]["metadata"]["created_at"],
+                namespace=self.namespace,
+            )
+        else:
+            raise ValueError("Configuration for hash parameter not found in index.")
+
+    def _write_config(self, config: ConfigParameter) -> None:
+        """Method to write a config parameter to the remote Pinecone index.
+
+        :param config: The config parameter to write to the index.
+        :type config: ConfigParameter
+        """
+        self.index.upsert(
+            vectors=[config.to_pinecone(dimensions=self.dimensions)],
+            namespace="sr_config",
+        )
+
     async def aquery(
         self,
         vector: np.ndarray,
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 659ae6e6..1bf13dbe 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -2,6 +2,7 @@ import importlib
 import json
 import os
 import random
+import hashlib
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -13,7 +14,7 @@ from semantic_router.index.base import BaseIndex
 from semantic_router.index.local import LocalIndex
 from semantic_router.llms import BaseLLM, OpenAILLM
 from semantic_router.route import Route
-from semantic_router.schema import EncoderType, RouteChoice
+from semantic_router.schema import ConfigParameter, EncoderType, RouteChoice
 from semantic_router.utils.defaults import EncoderDefault
 from semantic_router.utils.logger import logger
 
@@ -170,6 +171,13 @@ class LayerConfig:
             self.routes = [route for route in self.routes if route.name != name]
             logger.info(f"Removed route `{name}`")
 
+    def get_hash(self) -> ConfigParameter:
+        layer = self.to_dict()
+        return ConfigParameter(
+            field="sr_hash",
+            value=hashlib.sha256(json.dumps(layer).encode()).hexdigest(),
+        )
+
 
 class RouteLayer:
     score_threshold: float
@@ -526,15 +534,32 @@ class RouteLayer:
             logger.error(f"Failed to add routes to the index: {e}")
             raise Exception("Indexing error occurred") from e
 
+    def _get_hash(self) -> ConfigParameter:
+        config = self.to_config()
+        return config.get_hash()
+
     def is_synced(self) -> bool:
-        if not self.index.sync:
-            raise ValueError("Index is not set to sync with remote index.")
+        """Check if the local and remote route layer instances are synchronized.
+        """
+        #if not self.index.sync:
+        #    raise ValueError("Index is not set to sync with remote index.")
 
+        # first check hash
+        local_hash = self._get_hash()
+        remote_hash = self.index._read_hash()
+        if local_hash.value == remote_hash.value:
+            return True
+        # TODO: we may be able to remove the below logic
+        # if hashes are different, double check
         local_route_names, local_utterances, local_function_schemas, local_metadata = (
             self._extract_routes_details(self.routes, include_metadata=True)
         )
+        # return result of double check
         return self.index.is_synced(
-            local_route_names, local_utterances, local_function_schemas, local_metadata
+            local_route_names=local_route_names,
+            local_utterances_list=local_utterances,
+            local_function_schemas_list=local_function_schemas,
+            local_metadata_list=local_metadata,
         )
 
     def _add_and_sync_routes(self, routes: List[Route]):
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index b444c988..755792d6 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -1,6 +1,7 @@
+from datetime import datetime
 from enum import Enum
 from typing import List, Optional, Union, Any, Dict
-from pydantic.v1 import BaseModel
+from pydantic.v1 import BaseModel, Field
 
 
 class EncoderType(Enum):
@@ -63,6 +64,23 @@ class DocumentSplit(BaseModel):
     def content(self) -> str:
         return " ".join([doc if isinstance(doc, str) else "" for doc in self.docs])
 
+class ConfigParameter(BaseModel):
+    field: str
+    value: str
+    namespace: str = ""
+    created_at: str = Field(default_factory=lambda: datetime.utcnow().isoformat())
+
+    def to_pinecone(self, dimensions: int):
+        return {
+            "id": f"{self.field}#{self.namespace}",
+            "values": [0.1] * dimensions,
+            "metadata": {
+                "value": self.value,
+                "created_at": self.created_at,
+                "namespace": self.namespace,
+                "field": self.field,
+            },
+        }
 
 class Metric(Enum):
     COSINE = "cosine"
-- 
GitLab