diff --git a/docs/00-introduction.ipynb b/docs/00-introduction.ipynb index 30e08413915ded118a514620d302dfd6f02dbf51..e96e75b89978a0c9c168e995b8bfb8895a690b21 100644 --- a/docs/00-introduction.ipynb +++ b/docs/00-introduction.ipynb @@ -189,7 +189,7 @@ "source": [ "from semantic_router.routers import SemanticRouter\n", "\n", - "sr = SemanticRouter(encoder=encoder, routes=routes)" + "sr = SemanticRouter(encoder=encoder, routes=routes, auto_sync=\"local\")" ] }, { diff --git a/semantic_router/encoders/bedrock.py b/semantic_router/encoders/bedrock.py index 5ec3381e31a3be32ed23e2dbb283db899c78c3e6..7adf60c717252b71da0d9f604aec43d2dc5226b9 100644 --- a/semantic_router/encoders/bedrock.py +++ b/semantic_router/encoders/bedrock.py @@ -17,7 +17,7 @@ Classes: """ import json -from typing import List, Optional, Any +from typing import Dict, List, Optional, Any, Union import os from time import sleep import tiktoken @@ -138,11 +138,14 @@ class BedrockEncoder(DenseEncoder): ) from err return bedrock_client - def __call__(self, docs: List[str]) -> List[List[float]]: + def __call__( + self, docs: List[Union[str, Dict]], model_kwargs: Optional[Dict] = None + ) -> List[List[float]]: """Generates embeddings for the given documents. Args: docs: A list of strings representing the documents to embed. + model_kwargs: A dictionary of model-specific inference parameters. Returns: A list of lists, where each inner list contains the embedding values for a @@ -168,13 +171,29 @@ class BedrockEncoder(DenseEncoder): embeddings = [] if self.name and "amazon" in self.name: for doc in docs: - embedding_body = json.dumps( - { - "inputText": doc, - } - ) + + embedding_body = {} + + if isinstance(doc, dict): + embedding_body["inputText"] = doc.get("text") + embedding_body["inputImage"] = doc.get( + "image" + ) # expects a base64-encoded image + else: + embedding_body["inputText"] = doc + + # Add model-specific inference parameters + if model_kwargs: + embedding_body = embedding_body | model_kwargs + + # Clean up null values + embedding_body = {k: v for k, v in embedding_body.items() if v} + + # Format payload + embedding_body_payload: str = json.dumps(embedding_body) + response = self.client.invoke_model( - body=embedding_body, + body=embedding_body_payload, modelId=self.name, accept="application/json", contentType="application/json", @@ -184,9 +203,16 @@ class BedrockEncoder(DenseEncoder): elif self.name and "cohere" in self.name: chunked_docs = self.chunk_strings(docs) for chunk in chunked_docs: - chunk = json.dumps( - {"texts": chunk, "input_type": self.input_type} - ) + chunk = {"texts": chunk, "input_type": self.input_type} + + # Add model-specific inference parameters + # Note: if specified, input_type will be overwritten by model_kwargs + if model_kwargs: + chunk = chunk | model_kwargs + + # Format payload + chunk = json.dumps(chunk) + response = self.client.invoke_model( body=chunk, modelId=self.name, diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index 884106c0fbe5b6fa79f201b761d2b18368282be5..243ad433fce6203b368e18f1ab00784149452931 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -1,3 +1,5 @@ +from datetime import datetime +import time from typing import Any, List, Optional, Tuple, Union, Dict import json @@ -157,26 +159,91 @@ class BaseIndex(BaseModel): logger.warning("This method should be implemented by subclasses.") self.index = None - def _read_hash(self) -> ConfigParameter: - """ - Read the hash of the previously written index. + def _read_config(self, field: str, scope: str | None = None) -> ConfigParameter: + """Read a config parameter from the index. - This method should be implemented by subclasses. + :param field: The field to read. + :type field: str + :param scope: The scope to read. + :type scope: str | None + :return: The config parameter that was read. + :rtype: ConfigParameter """ logger.warning("This method should be implemented by subclasses.") return ConfigParameter( - field="sr_hash", + field=field, value="", - namespace="", + scope=scope, ) - def _write_config(self, config: ConfigParameter): + def _read_hash(self) -> ConfigParameter: + """Read the hash of the previously written index. + + :return: The config parameter that was read. + :rtype: ConfigParameter """ - Write a config parameter to the index. + return self._read_config(field="sr_hash") - This method should be implemented by subclasses. + def _write_config(self, config: ConfigParameter) -> ConfigParameter: + """Write a config parameter to the index. + + :param config: The config parameter to write. + :type config: ConfigParameter + :return: The config parameter that was written. + :rtype: ConfigParameter """ logger.warning("This method should be implemented by subclasses.") + return config + + def lock( + self, value: bool, wait: int = 0, scope: str | None = None + ) -> ConfigParameter: + """Lock/unlock the index for a given scope (if applicable). If index + already locked/unlocked, raises ValueError. + + :param scope: The scope to lock. + :type scope: str | None + :param wait: The number of seconds to wait for the index to be unlocked, if + set to 0, will raise an error if index is already locked/unlocked. + :type wait: int + :return: The config parameter that was locked. + :rtype: ConfigParameter + """ + start_time = datetime.now() + while True: + if self._is_locked(scope=scope) != value: + # in this case, we can set the lock value + break + if (datetime.now() - start_time).total_seconds() < wait: + # wait for 2.5 seconds before checking again + time.sleep(2.5) + else: + raise ValueError( + f"Index is already {'locked' if value else 'unlocked'}." + ) + lock_param = ConfigParameter( + field="sr_lock", + value=str(value), + scope=scope, + ) + self._write_config(lock_param) + return lock_param + + def _is_locked(self, scope: str | None = None) -> bool: + """Check if the index is locked for a given scope (if applicable). + + :param scope: The scope to check. + :type scope: str | None + :return: True if the index is locked, False otherwise. + :rtype: bool + """ + lock_config = self._read_config(field="sr_lock", scope=scope) + if lock_config.value == "True": + return True + elif lock_config.value == "False" or not lock_config.value: + return False + else: + raise ValueError(f"Invalid lock value: {lock_config.value}") def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False): """ diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index b4ba144e89fefe739f5a1d0dbc1ac84be58460f6..469a41411f8f1779f5ecfb9fc1dbd727d349085e 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -405,39 +405,43 @@ class PineconeIndex(BaseIndex): route_names = [result["metadata"]["sr_route"] for result in results["matches"]] return np.array(scores), route_names - def _read_hash(self) -> ConfigParameter: + def _read_config(self, field: str, scope: str | None = None) -> ConfigParameter: + scope = scope or self.namespace if self.index is None: return ConfigParameter( - field="sr_hash", + field=field, value="", - namespace=self.namespace, + scope=scope, ) - hash_id = f"sr_hash#{self.namespace}" - hash_record = self.index.fetch( - ids=[hash_id], + config_id = f"{field}#{scope}" + config_record = self.index.fetch( + ids=[config_id], namespace="sr_config", ) - if hash_record["vectors"]: + if config_record["vectors"]: return ConfigParameter( - field="sr_hash", - value=hash_record["vectors"][hash_id]["metadata"]["value"], - created_at=hash_record["vectors"][hash_id]["metadata"]["created_at"], - namespace=self.namespace, + field=field, + value=config_record["vectors"][config_id]["metadata"]["value"], + created_at=config_record["vectors"][config_id]["metadata"][ + "created_at" + ], + scope=scope, ) else: - logger.warning("Configuration for hash parameter not found in index.") + logger.warning(f"Configuration for {field} parameter not found in index.") return ConfigParameter( - field="sr_hash", + field=field, value="", - namespace=self.namespace, + scope=scope, ) - def _write_config(self, config: ConfigParameter) -> None: + def _write_config(self, config: ConfigParameter) -> ConfigParameter: """Method to write a config parameter to the remote Pinecone index. :param config: The config parameter to write to the index. :type config: ConfigParameter """ + config.scope = config.scope or self.namespace if self.index is None: raise ValueError("Index has not been initialized.") if self.dimensions is None: @@ -446,6 +450,7 @@ class PineconeIndex(BaseIndex): vectors=[config.to_pinecone(dimensions=self.dimensions)], namespace="sr_config", ) + return config async def aquery( self, diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py index 3628fda8553a8a6e6979d6684e2fdfaa1611ead5..328cf2b77c4ff45bed8c861dda72407568841451 100644 --- a/semantic_router/routers/base.py +++ b/semantic_router/routers/base.py @@ -543,7 +543,7 @@ class BaseRouter(BaseModel): route = self.check_for_matching_routes(top_class) return route, top_class_scores - def sync(self, sync_mode: str, force: bool = False) -> List[str]: + def sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]: """Runs a sync of the local routes with the remote index. :param sync_mode: The mode to sync the routes with the remote index. @@ -551,6 +551,10 @@ class BaseRouter(BaseModel): :param force: Whether to force the sync even if the local and remote hashes already match. Defaults to False. :type force: bool, optional + :param wait: The number of seconds to wait for the index to be unlocked + before proceeding with the sync. If set to 0, will raise an error if + index is already locked/unlocked. + :type wait: int :return: A list of diffs describing the addressed differences between the local and remote route layers. :rtype: List[str] @@ -565,7 +569,9 @@ class BaseRouter(BaseModel): remote_utterances=local_utterances, ) return diff.to_utterance_str() - # otherwise we continue with the sync, first creating a diff + # otherwise we continue with the sync, first locking the index + _ = self.index.lock(value=True, wait=wait) + # first creating a diff local_utterances = self.to_config().to_utterances() remote_utterances = self.index.get_utterances() diff = UtteranceDiff.from_utterances( @@ -576,6 +582,8 @@ class BaseRouter(BaseModel): sync_strategy = diff.get_sync_strategy(sync_mode=sync_mode) # and execute self._execute_sync_strategy(sync_strategy) + # unlock index after sync + _ = self.index.lock(value=False) return diff.to_utterance_str() def _execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]]): @@ -781,6 +789,9 @@ class BaseRouter(BaseModel): :param route_name: the name of the route to be deleted :type str: """ + # ensure index is not locked + if self.index._is_locked(): + raise ValueError("Index is locked. Cannot delete route.") current_local_hash = self._get_hash() current_remote_hash = self.index._read_hash() if current_remote_hash.value == "": diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 2a94b3559f572dd878355ad19b1ab3ffdbd25108..273043f55ced8f5fdedb2b91e2ef9b0d35ee3d17 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timezone from difflib import Differ from enum import Enum import json @@ -62,12 +62,13 @@ class Message(BaseModel): class ConfigParameter(BaseModel): field: str value: str - namespace: Optional[str] = None - created_at: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) + scope: Optional[str] = None + created_at: str = Field( + default_factory=lambda: datetime.now(timezone.utc).isoformat() + ) def to_pinecone(self, dimensions: int): - if self.namespace is None: - namespace = "" + namespace = self.scope or "" return { "id": f"{self.field}#{namespace}", "values": [0.1] * dimensions, diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py index 8e73de34e740562981716e2aafb1a32f29e86620..4439598f3bde912bc956ab1dfe7dc8a59cb607c1 100644 --- a/tests/unit/test_sync.py +++ b/tests/unit/test_sync.py @@ -485,5 +485,64 @@ class TestSemanticRouter: Utterance(route="Route 3", utterance="Boo"), ], "The routes in the index should match the local routes" - # clear index - route_layer.index.index.delete(namespace="", delete_all=True) + @pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + def test_sync_lock_prevents_concurrent_sync( + self, openai_encoder, routes, index_cls + ): + """Test that sync lock prevents concurrent synchronization operations""" + index = init_index(index_cls) + route_layer = SemanticRouter( + encoder=openai_encoder, + routes=routes, + index=index, + auto_sync=None, + ) + + # Acquire sync lock + route_layer.index.lock(value=True) + if index_cls is PineconeIndex: + time.sleep(PINECONE_SLEEP) + + # Attempt to sync while lock is held should raise exception + with pytest.raises(Exception): + route_layer.sync("local") + + # Release lock + route_layer.index.lock(value=False) + if index_cls is PineconeIndex: + time.sleep(PINECONE_SLEEP) + + # Should succeed after lock is released + route_layer.sync("local") + if index_cls is PineconeIndex: + time.sleep(PINECONE_SLEEP) + assert route_layer.is_synced() + + @pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + def test_sync_lock_auto_releases(self, openai_encoder, routes, index_cls): + """Test that sync lock is automatically released after sync operations""" + index = init_index(index_cls) + route_layer = SemanticRouter( + encoder=openai_encoder, + routes=routes, + index=index, + auto_sync=None, + ) + + # Initial sync should acquire and release lock + route_layer.sync("local") + if index_cls is PineconeIndex: + time.sleep(PINECONE_SLEEP) + + # Lock should be released, allowing another sync + route_layer.sync("local") # Should not raise exception + if index_cls is PineconeIndex: + time.sleep(PINECONE_SLEEP) + assert route_layer.is_synced() + + # clear index + route_layer.index.index.delete(namespace="", delete_all=True)