Skip to content
Snippets Groups Projects
Commit f48054fc authored by James Briggs's avatar James Briggs
Browse files

fix: logic changes

parent 81051ec1
No related branches found
No related tags found
No related merge requests found
...@@ -488,6 +488,8 @@ class PineconeIndex(BaseIndex): ...@@ -488,6 +488,8 @@ class PineconeIndex(BaseIndex):
self.dimensions = self.dimensions or len(embeddings[0]) self.dimensions = self.dimensions or len(embeddings[0])
self.index = self._init_index(force_create=True) self.index = self._init_index(force_create=True)
print([(r, u, fs, m) for r, u, fs, m in zip(routes, utterances, function_schemas, metadata_list)])
vectors_to_upsert = [ vectors_to_upsert = [
PineconeRecord( PineconeRecord(
values=vector, values=vector,
...@@ -635,6 +637,7 @@ class PineconeIndex(BaseIndex): ...@@ -635,6 +637,7 @@ class PineconeIndex(BaseIndex):
if self.index is None: if self.index is None:
raise ValueError("Index has not been initialized.") raise ValueError("Index has not been initialized.")
hash_id = f"sr_hash#{self.namespace}" hash_id = f"sr_hash#{self.namespace}"
print(f"hash_id: {hash_id}")
hash_record = self.index.fetch( hash_record = self.index.fetch(
ids=[hash_id], ids=[hash_id],
namespace="sr_config", namespace="sr_config",
......
...@@ -530,6 +530,7 @@ class RouteLayer: ...@@ -530,6 +530,7 @@ class RouteLayer:
) )
self.routes.append(route) self.routes.append(route)
self._write_hash() # update current hash in index
def list_route_names(self) -> List[str]: def list_route_names(self) -> List[str]:
return [route.name for route in self.routes] return [route.name for route in self.routes]
...@@ -625,15 +626,26 @@ class RouteLayer: ...@@ -625,15 +626,26 @@ class RouteLayer:
logger.error(f"Failed to add routes to the index: {e}") logger.error(f"Failed to add routes to the index: {e}")
raise Exception("Indexing error occurred") from e raise Exception("Indexing error occurred") from e
self._write_hash()
def _get_hash(self) -> ConfigParameter: def _get_hash(self) -> ConfigParameter:
config = self.to_config() config = self.to_config()
return config.get_hash() return config.get_hash()
def _write_hash(self) -> ConfigParameter:
config = self.to_config()
hash_config = config.get_hash()
self.index._write_config(config=hash_config)
return hash_config
def is_synced(self) -> bool: def is_synced(self) -> bool:
"""Check if the local and remote route layer instances are synchronized.""" """Check if the local and remote route layer instances are
# if not self.index.sync: synchronized.
# raise ValueError("Index is not set to sync with remote index.")
:return: True if the local and remote route layers are synchronized,
False otherwise.
:rtype: bool
"""
# first check hash # first check hash
local_hash = self._get_hash() local_hash = self._get_hash()
remote_hash = self.index._read_hash() remote_hash = self.index._read_hash()
...@@ -674,6 +686,8 @@ class RouteLayer: ...@@ -674,6 +686,8 @@ class RouteLayer:
# sort local and remote utterances # sort local and remote utterances
local_utterances.sort() local_utterances.sort()
remote_utterances.sort() remote_utterances.sort()
print(remote_utterances)
print(local_utterances)
# now get diff # now get diff
differ = Differ() differ = Differ()
diff = list(differ.compare(local_utterances, remote_utterances)) diff = list(differ.compare(local_utterances, remote_utterances))
...@@ -740,6 +754,7 @@ class RouteLayer: ...@@ -740,6 +754,7 @@ class RouteLayer:
metadata=data.get("metadata", {}), metadata=data.get("metadata", {}),
) )
) )
self._write_hash()
def _extract_routes_details( def _extract_routes_details(
self, routes: List[Route], include_metadata: bool = False self, routes: List[Route], include_metadata: bool = False
......
...@@ -38,7 +38,7 @@ def init_index( ...@@ -38,7 +38,7 @@ def init_index(
index_cls, index_cls,
dimensions: Optional[int] = None, dimensions: Optional[int] = None,
namespace: Optional[str] = "", namespace: Optional[str] = "",
sync: str = "local", sync: str = None,
): ):
"""We use this function to initialize indexes with different names to avoid """We use this function to initialize indexes with different names to avoid
issues during testing. issues during testing.
...@@ -175,7 +175,7 @@ def test_data(): ...@@ -175,7 +175,7 @@ def test_data():
def get_test_indexes(): def get_test_indexes():
indexes = [LocalIndex] indexes = []
if importlib.util.find_spec("qdrant_client") is not None: if importlib.util.find_spec("qdrant_client") is not None:
indexes.append(QdrantIndex) indexes.append(QdrantIndex)
...@@ -193,7 +193,7 @@ class TestRouteLayer: ...@@ -193,7 +193,7 @@ class TestRouteLayer:
def test_initialization(self, openai_encoder, routes, index_cls): def test_initialization(self, openai_encoder, routes, index_cls):
index = init_index(index_cls) index = init_index(index_cls)
_ = RouteLayer( _ = RouteLayer(
encoder=openai_encoder, routes=routes, top_k=10, index=index, sync=None encoder=openai_encoder, routes=routes, top_k=10, index=index
) )
@pytest.mark.skipif( @pytest.mark.skipif(
...@@ -202,7 +202,7 @@ class TestRouteLayer: ...@@ -202,7 +202,7 @@ class TestRouteLayer:
def test_second_initialization_sync(self, openai_encoder, routes, index_cls): def test_second_initialization_sync(self, openai_encoder, routes, index_cls):
index = init_index(index_cls) index = init_index(index_cls)
route_layer = RouteLayer( route_layer = RouteLayer(
encoder=openai_encoder, routes=routes, top_k=10, index=index, sync=None encoder=openai_encoder, routes=routes, top_k=10, index=index
) )
if index_cls is PineconeIndex: if index_cls is PineconeIndex:
time.sleep(PINECONE_SLEEP) # allow for index to be populated time.sleep(PINECONE_SLEEP) # allow for index to be populated
...@@ -216,7 +216,7 @@ class TestRouteLayer: ...@@ -216,7 +216,7 @@ class TestRouteLayer:
): ):
index = init_index(index_cls) index = init_index(index_cls)
route_layer = RouteLayer( route_layer = RouteLayer(
encoder=openai_encoder, routes=routes_2, top_k=10, index=index, sync=None encoder=openai_encoder, routes=routes_2, top_k=10, index=index
) )
if index_cls is PineconeIndex: if index_cls is PineconeIndex:
time.sleep(PINECONE_SLEEP) # allow for index to be populated time.sleep(PINECONE_SLEEP) # allow for index to be populated
...@@ -225,14 +225,22 @@ class TestRouteLayer: ...@@ -225,14 +225,22 @@ class TestRouteLayer:
@pytest.mark.skipif( @pytest.mark.skipif(
os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
) )
def test_utterance_diff(self, openai_encoder, routes_2, index_cls): def test_utterance_diff(self, openai_encoder, routes, routes_2, index_cls):
index = init_index(index_cls) index = init_index(index_cls)
route_layer = RouteLayer( _ = RouteLayer(
encoder=openai_encoder, routes=routes_2, top_k=10, index=index, sync=None encoder=openai_encoder, routes=routes, top_k=10, index=index
) )
if index_cls is PineconeIndex: if index_cls is PineconeIndex:
time.sleep(PINECONE_SLEEP) # allow for index to be populated time.sleep(PINECONE_SLEEP) # allow for index to be populated
diff = route_layer.get_utterance_diff() route_layer_2 = RouteLayer(
assert "+ Route 1: Hi" in diff encoder=openai_encoder, routes=routes_2, top_k=10, index=index
)
if index_cls is PineconeIndex:
time.sleep(PINECONE_SLEEP) # allow for index to be populated
diff = route_layer_2.get_utterance_diff()
assert " Route 1: Hello" in diff assert " Route 1: Hello" in diff
assert "- Route 2: Hi" in diff assert "+ Route 1: Hi" in diff
assert "+ Route 2: Au revoir" in diff
assert "+ Route 2: Bye" in diff
assert "+ Route 2: Goodbye" in diff
assert " Route 2: Hi" in diff
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment