Skip to content
Snippets Groups Projects
Commit 40546927 authored by James Briggs's avatar James Briggs
Browse files

feat: continued refactoring for sync features

parent e27f4446
No related branches found
No related tags found
No related merge requests found
...@@ -4,7 +4,7 @@ import json ...@@ -4,7 +4,7 @@ import json
import numpy as np import numpy as np
from pydantic.v1 import BaseModel from pydantic.v1 import BaseModel
from semantic_router.schema import ConfigParameter from semantic_router.schema import ConfigParameter, Utterance
from semantic_router.route import Route from semantic_router.route import Route
from semantic_router.utils.logger import logger from semantic_router.utils.logger import logger
...@@ -40,7 +40,7 @@ class BaseIndex(BaseModel): ...@@ -40,7 +40,7 @@ class BaseIndex(BaseModel):
""" """
raise NotImplementedError("This method should be implemented by subclasses.") raise NotImplementedError("This method should be implemented by subclasses.")
def get_utterances(self) -> List[Tuple]: def get_utterances(self) -> List[Utterance]:
"""Gets a list of route and utterance objects currently stored in the """Gets a list of route and utterance objects currently stored in the
index, including additional metadata. index, including additional metadata.
...@@ -50,7 +50,7 @@ class BaseIndex(BaseModel): ...@@ -50,7 +50,7 @@ class BaseIndex(BaseModel):
""" """
_, metadata = self._get_all(include_metadata=True) _, metadata = self._get_all(include_metadata=True)
route_tuples = parse_route_info(metadata=metadata) route_tuples = parse_route_info(metadata=metadata)
return route_tuples return [Utterance.from_tuple(x) for x in route_tuples]
def get_routes(self) -> List[Route]: def get_routes(self) -> List[Route]:
"""Gets a list of route objects currently stored in the index. """Gets a list of route objects currently stored in the index.
......
...@@ -2,7 +2,7 @@ from typing import List, Optional, Tuple, Dict ...@@ -2,7 +2,7 @@ from typing import List, Optional, Tuple, Dict
import numpy as np import numpy as np
from semantic_router.schema import ConfigParameter from semantic_router.schema import ConfigParameter, Utterance
from semantic_router.index.base import BaseIndex from semantic_router.index.base import BaseIndex
from semantic_router.linear import similarity_matrix, top_scores from semantic_router.linear import similarity_matrix, top_scores
from semantic_router.utils.logger import logger from semantic_router.utils.logger import logger
...@@ -61,7 +61,7 @@ class LocalIndex(BaseIndex): ...@@ -61,7 +61,7 @@ class LocalIndex(BaseIndex):
if self.sync is not None: if self.sync is not None:
logger.error("Sync remove is not implemented for LocalIndex.") logger.error("Sync remove is not implemented for LocalIndex.")
def get_utterances(self) -> List[Tuple]: def get_utterances(self) -> List[Utterance]:
""" """
Gets a list of route and utterance objects currently stored in the index. Gets a list of route and utterance objects currently stored in the index.
...@@ -70,7 +70,9 @@ class LocalIndex(BaseIndex): ...@@ -70,7 +70,9 @@ class LocalIndex(BaseIndex):
""" """
if self.routes is None or self.utterances is None: if self.routes is None or self.utterances is None:
return [] return []
return list(zip(self.routes, self.utterances)) return [
Utterance.from_tuple(x) for x in zip(self.routes, self.utterances)
]
def describe(self) -> Dict: def describe(self) -> Dict:
return { return {
......
...@@ -15,7 +15,7 @@ from semantic_router.index.base import BaseIndex ...@@ -15,7 +15,7 @@ from semantic_router.index.base import BaseIndex
from semantic_router.index.local import LocalIndex from semantic_router.index.local import LocalIndex
from semantic_router.llms import BaseLLM, OpenAILLM from semantic_router.llms import BaseLLM, OpenAILLM
from semantic_router.route import Route from semantic_router.route import Route
from semantic_router.schema import ConfigParameter, EncoderType, RouteChoice from semantic_router.schema import ConfigParameter, EncoderType, RouteChoice, Utterance, UtteranceDiff
from semantic_router.utils.defaults import EncoderDefault from semantic_router.utils.defaults import EncoderDefault
from semantic_router.utils.logger import logger from semantic_router.utils.logger import logger
...@@ -218,31 +218,23 @@ class LayerConfig: ...@@ -218,31 +218,23 @@ class LayerConfig:
elif ext in [".yaml", ".yml"]: elif ext in [".yaml", ".yml"]:
yaml.safe_dump(self.to_dict(), f) yaml.safe_dump(self.to_dict(), f)
def _get_diff(self, other: "LayerConfig") -> List[str]: def to_utterances(self) -> List[Utterance]:
"""Get the difference between two LayerConfigs. """Convert the routes to a list of Utterance objects.
:param other: The LayerConfig to compare to. :return: A list of Utterance objects.
:type other: LayerConfig :rtype: List[Utterance]
:return: A list of differences between the two LayerConfigs.
:rtype: List[Dict[str, Any]]
""" """
# TODO: formalize diffs into likely LayerDiff objects that can then utterances = []
# output different formats as required to enable smarter syncs for route in self.routes:
self_yaml = yaml.dump(self.to_dict()) utterances.extend([
other_yaml = yaml.dump(other.to_dict()) Utterance(
differ = Differ() route=route.name,
return list(differ.compare(self_yaml.splitlines(), other_yaml.splitlines())) utterance=x,
function_schemas=route.function_schemas,
def show_diff(self, other: "LayerConfig") -> str: metadata=route.metadata
"""Show the difference between two LayerConfigs. ) for x in route.utterances
])
:param other: The LayerConfig to compare to. return utterances
:type other: LayerConfig
:return: A string showing the difference between the two LayerConfigs.
:rtype: str
"""
diff = self._get_diff(other)
return "\n".join(diff)
def add(self, route: Route): def add(self, route: Route):
self.routes.append(route) self.routes.append(route)
...@@ -283,6 +275,7 @@ class RouteLayer: ...@@ -283,6 +275,7 @@ class RouteLayer:
index: Optional[BaseIndex] = None, # type: ignore index: Optional[BaseIndex] = None, # type: ignore
top_k: int = 5, top_k: int = 5,
aggregation: str = "sum", aggregation: str = "sum",
auto_sync: Optional[str] = None,
): ):
self.index: BaseIndex = index if index is not None else LocalIndex() self.index: BaseIndex = index if index is not None else LocalIndex()
if encoder is None: if encoder is None:
...@@ -310,14 +303,17 @@ class RouteLayer: ...@@ -310,14 +303,17 @@ class RouteLayer:
f"Unsupported aggregation method chosen: {aggregation}. Choose either 'SUM', 'MEAN', or 'MAX'." f"Unsupported aggregation method chosen: {aggregation}. Choose either 'SUM', 'MEAN', or 'MAX'."
) )
self.aggregation_method = self._set_aggregation_method(self.aggregation) self.aggregation_method = self._set_aggregation_method(self.aggregation)
self.auto_sync = auto_sync
# set route score thresholds if not already set # set route score thresholds if not already set
for route in self.routes: for route in self.routes:
if route.score_threshold is None: if route.score_threshold is None:
route.score_threshold = self.score_threshold route.score_threshold = self.score_threshold
# if routes list has been passed, we initialize index now # if routes list has been passed, we initialize index now
if self.index.sync: if self.auto_sync:
# initialize index now # initialize index now
dims = self.encoder.dimensions
self.index._init_index(force_create=True, dimensions=dims)
if len(self.routes) > 0: if len(self.routes) > 0:
self._add_and_sync_routes(routes=self.routes) self._add_and_sync_routes(routes=self.routes)
else: else:
...@@ -447,6 +443,100 @@ class RouteLayer: ...@@ -447,6 +443,100 @@ class RouteLayer:
return route_choices return route_choices
def sync(self, sync_mode: str, force: bool = False) -> List[str]:
    """Runs a sync of the local routes with the remote index.

    :param sync_mode: The mode to sync the routes with the remote index.
        Passed through to ``UtteranceDiff.get_sync_strategy``, which raises
        a ``ValueError`` for values outside ``SyncMode``.
    :type sync_mode: str
    :param force: Whether to force the sync even if the local and remote
        hashes already match. Defaults to False.
    :type force: bool, optional
    :return: A list of diffs describing the addressed differences between
        the local and remote route layers.
    :rtype: List[str]
    """
    local_utterances = self.to_config().to_utterances()
    if not force and self.is_synced():
        logger.warning("Local and remote route layers are already synchronized.")
        # create utterance diff to return, but just using local instance
        # for speed (no call to the remote index is made)
        diff = UtteranceDiff.from_utterances(
            local_utterances=local_utterances,
            remote_utterances=local_utterances,
        )
        return diff.to_utterance_str()
    # otherwise we continue with the sync, first creating a diff
    remote_utterances = self.index.get_utterances()
    diff = UtteranceDiff.from_utterances(
        local_utterances=local_utterances,
        remote_utterances=remote_utterances,
    )
    # generate sync strategy for the *requested* mode; the previous code
    # called a non-existent `to_sync_strategy()` and ignored `sync_mode`
    sync_strategy = diff.get_sync_strategy(sync_mode=sync_mode)
    # and execute
    self._execute_sync_strategy(sync_strategy)
    return diff.to_utterance_str()
def _execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]]):
    """Applies a sync strategy to the remote index and the local layer.

    The strategy holds a ``"remote"`` and a ``"local"`` entry, each with a
    ``"delete"`` and an ``"upsert"`` list of utterances. Steps run in the
    order remote delete, remote upsert, local delete, local upsert, and
    the hash is written afterwards.

    :param strategy: The sync strategy to execute.
    :type strategy: Dict[str, Dict[str, List[Utterance]]]
    """
    remote_plan = strategy["remote"]

    remote_removals = remote_plan["delete"]
    if remote_removals:
        # group the utterance texts to remove under their route name
        grouped = {}  # type: ignore
        for item in remote_removals:
            if item.route not in grouped:
                grouped[item.route] = []
            grouped[item.route].append(item.utterance)
        self.index._remove_and_sync(grouped)

    remote_additions = remote_plan["upsert"]
    if remote_additions:
        texts = [item.utterance for item in remote_additions]
        vectors = self.encoder(texts)
        names = [item.route for item in remote_additions]
        schemas = [item.function_schemas for item in remote_additions]
        metadatas = [item.metadata for item in remote_additions]
        self.index.add(
            embeddings=vectors,
            routes=names,
            utterances=texts,
            function_schemas=schemas,
            metadata_list=metadatas,
        )

    local_plan = strategy["local"]
    local_removals = local_plan["delete"]
    if local_removals:
        self._local_delete(utterances=local_removals)
    local_additions = local_plan["upsert"]
    if local_additions:
        self._local_upsert(utterances=local_additions)

    # update hash
    self._write_hash()
def _local_upsert(self, utterances: List[Utterance]):
"""Adds new routes to the RouteLayer.
:param utterances: The utterances to add to the local RouteLayer.
:type utterances: List[Utterance]
"""
new_routes = {}
for utt_obj in utterances:
if utt_obj.route not in new_routes.keys():
new_routes[utt_obj.route] = Route(
name=utt_obj.route,
utterances=[utt_obj.utterance],
function_schemas=utt_obj.function_schemas,
metadata=utt_obj.metadata
)
else:
new_routes[utt_obj.route].utterances.append(utt_obj.utterance)
self.routes.extend(list(new_routes.values()))
def _local_delete(self, utterances: List[Utterance]):
"""Deletes routes from the local RouteLayer.
:param utterances: The utterances to delete from the local RouteLayer.
:type utterances: List[Utterance]
"""
route_names = set([utt.route for utt in utterances])
self.routes = [route for route in self.routes if route.name not in route_names]
def _retrieve_top_route( def _retrieve_top_route(
self, vector: List[float], route_filter: Optional[List[str]] = None self, vector: List[float], route_filter: Optional[List[str]] = None
) -> Tuple[Optional[Route], List[float]]: ) -> Tuple[Optional[Route], List[float]]:
...@@ -735,97 +825,27 @@ class RouteLayer: ...@@ -735,97 +825,27 @@ class RouteLayer:
"route2: utterance4", which do not exist locally. "route2: utterance4", which do not exist locally.
""" """
# first we get remote and local utterances # first we get remote and local utterances
remote_utterances = [f"{x[0]}: {x[1]}" for x in self.index.get_utterances()] remote_utterances = self.index.get_utterances()
local_routes, local_utterance_arr, _ = self._extract_routes_details( local_utterances = self.to_config().to_utterances()
self.routes, include_metadata=False
)
local_utterances = [
f"{x[0]}: {x[1]}" for x in zip(local_routes, local_utterance_arr)
]
# sort local and remote utterances
local_utterances.sort()
remote_utterances.sort()
# now get diff
differ = Differ()
diff = list(differ.compare(local_utterances, remote_utterances))
return diff
def _add_and_sync_routes(self, routes: List[Route]):
# get current local hash
current_local_hash = self._get_hash()
current_remote_hash = self.index._read_hash()
if current_remote_hash.value == "":
# if remote hash is empty, the index is to be initialized
current_remote_hash = current_local_hash
# create embeddings for all routes and sync at startup with remote ones based on sync setting
local_route_names, local_utterances, local_function_schemas, local_metadata = (
self._extract_routes_details(routes, include_metadata=True)
)
routes_to_add, routes_to_delete, layer_routes_dict = self.index._sync_index( diff_obj = UtteranceDiff.from_utterances(
local_route_names, local_utterances=local_utterances, remote_utterances=remote_utterances
local_utterances,
local_function_schemas,
local_metadata,
dimensions=self.index.dimensions or len(self.encoder(["dummy"])[0]),
) )
return diff_obj.to_utterance_str()
data_to_delete = {} # type: ignore def _add_and_sync_routes(self, routes: List[Route]):
for route, utterance in routes_to_delete: self.routes.extend(routes)
data_to_delete.setdefault(route, []).append(utterance) # first we get remote and local utterances
self.index._remove_and_sync(data_to_delete) remote_utterances = self.index.get_utterances()
local_utterances = self.to_config().to_utterances()
# Prepare data for addition
if routes_to_add:
(
route_names_to_add,
all_utterances_to_add,
function_schemas_to_add,
metadata_to_add,
) = map(list, zip(*routes_to_add))
else:
(
route_names_to_add,
all_utterances_to_add,
function_schemas_to_add,
metadata_to_add,
) = ([], [], [], [])
embedded_utterances_to_add = (
self.encoder(all_utterances_to_add) if all_utterances_to_add else []
)
self.index.add( diff_obj = UtteranceDiff.from_utterances(
embeddings=embedded_utterances_to_add, local_utterances=local_utterances, remote_utterances=remote_utterances
routes=route_names_to_add,
utterances=all_utterances_to_add,
function_schemas=function_schemas_to_add,
metadata_list=metadata_to_add,
) )
sync_strategy = diff_obj.get_sync_strategy(sync_mode=self.auto_sync)
# Update local route layer state self._execute_sync_strategy(strategy=sync_strategy)
self.routes = [] # update remote hash
for route, data in layer_routes_dict.items(): self._write_hash()
function_schemas = data.get("function_schemas", None)
if function_schemas is not None:
function_schemas = [function_schemas]
self.routes.append(
Route(
name=route,
utterances=data.get("utterances", []),
function_schemas=function_schemas,
metadata=data.get("metadata", {}),
)
)
# update hash IF index and local hash were aligned
if current_local_hash.value == current_remote_hash.value:
self._write_hash()
else:
logger.warning(
"Local and remote route layers were not aligned. Remote hash "
"not updated. Use `RouteLayer.get_utterance_diff()` to see "
"details."
)
def _extract_routes_details( def _extract_routes_details(
self, routes: List[Route], include_metadata: bool = False self, routes: List[Route], include_metadata: bool = False
......
from datetime import datetime from datetime import datetime
from difflib import Differ
from enum import Enum from enum import Enum
from typing import List, Optional, Union, Any, Dict from typing import List, Optional, Union, Any, Dict, Tuple
from pydantic.v1 import BaseModel, Field from pydantic.v1 import BaseModel, Field
...@@ -18,6 +19,9 @@ class EncoderType(Enum): ...@@ -18,6 +19,9 @@ class EncoderType(Enum):
GOOGLE = "google" GOOGLE = "google"
BEDROCK = "bedrock" BEDROCK = "bedrock"
@staticmethod
def to_list() -> List[str]:
    """Return the value of every ``EncoderType`` member, in definition order.

    Declared as a static method so it can be called on the class as well
    as on a member; without the decorator a call on a member passes the
    member as an unexpected positional argument.
    """
    return [encoder.value for encoder in EncoderType]
class EncoderInfo(BaseModel): class EncoderInfo(BaseModel):
name: str name: str
...@@ -86,6 +90,234 @@ class ConfigParameter(BaseModel): ...@@ -86,6 +90,234 @@ class ConfigParameter(BaseModel):
} }
class Utterance(BaseModel):
    # name of the route this utterance belongs to
    route: str
    # the utterance text itself
    utterance: str
    function_schemas: Optional[List[Dict]] = None
    metadata: Optional[Dict] = None
    # diff marker used when rendering the utterance as a diff line
    diff_tag: str = " "

    @classmethod
    def from_tuple(cls, tuple_obj: Tuple):
        """Build an Utterance from a positional tuple.

        Positions 0 and 1 are the route and the utterance. Positions 2 and
        3, when present, are the function schemas and the metadata. The
        order is not checked, so a tuple in any other order yields an
        invalid Utterance.

        :param tuple_obj: A tuple containing route, utterance, function schemas and metadata.
        :type tuple_obj: Tuple
        :return: An Utterance object.
        :rtype: Utterance
        """
        size = len(tuple_obj)
        return cls(
            route=tuple_obj[0],
            utterance=tuple_obj[1],
            function_schemas=tuple_obj[2] if size > 2 else None,
            metadata=tuple_obj[3] if size > 3 else None,
        )

    def to_tuple(self):
        """Return ``(route, utterance, function_schemas, metadata)``."""
        return (self.route, self.utterance, self.function_schemas, self.metadata)

    def to_str(self, include_metadata: bool = False):
        """Render as ``"route: utterance"``, optionally followed by the
        function schemas and metadata separated by ``" | "``.
        """
        label = f"{self.route}: {self.utterance}"
        if not include_metadata:
            return label
        return f"{label} | {self.function_schemas} | {self.metadata}"

    def to_diff_str(self):
        """Render the short string form prefixed by the diff tag and a space."""
        rendered = self.to_str()
        return f"{self.diff_tag} {rendered}"
class SyncMode(Enum):
    """Synchronization modes for local (route layer) and remote (index)
    instances.
    """

    ERROR = "error"
    REMOTE = "remote"
    LOCAL = "local"
    MERGE_FORCE_REMOTE = "merge-force-remote"
    MERGE_FORCE_LOCAL = "merge-force-local"
    MERGE = "merge"

    @staticmethod
    def to_list() -> List[str]:
        """Return the string value of every sync mode, in definition order.

        Declared as a static method so that it works when called on the
        class and on a member alike.
        """
        return [mode.value for mode in SyncMode]
class UtteranceDiff(BaseModel):
    """Diff between the utterances held locally and those held remotely.

    Each entry of ``diff`` is an Utterance whose ``diff_tag`` is ``"-"``
    (local only), ``"+"`` (remote only) or ``" "`` (present in both).
    """

    diff: List[Utterance]

    @classmethod
    def from_utterances(
        cls,
        local_utterances: List[Utterance],
        remote_utterances: List[Utterance],
    ):
        """Build a diff from local and remote utterances.

        Utterances are compared by their ``to_str()`` form (route and
        utterance text). The ``diff_tag`` of the passed Utterance objects
        is set as part of building the diff.

        :param local_utterances: Utterances of the local route layer.
        :type local_utterances: List[Utterance]
        :param remote_utterances: Utterances of the remote index.
        :type remote_utterances: List[Utterance]
        :return: The diff of remote compared to local.
        :rtype: UtteranceDiff
        """
        local_utterances_map = {x.to_str(): x for x in local_utterances}
        remote_utterances_map = {x.to_str(): x for x in remote_utterances}
        # sort local and remote utterances
        local_utterances_str = sorted(local_utterances_map)
        remote_utterances_str = sorted(remote_utterances_map)
        # get diff
        differ = Differ()
        diff_obj = differ.compare(local_utterances_str, remote_utterances_str)
        # create UtteranceDiff list
        utterance_diffs = []
        for line in diff_obj:
            utterance_diff_tag = line[0]
            if utterance_diff_tag == "?":
                # Differ emits "? " hint lines for near-identical strings;
                # they are not utterances and exist in neither map
                continue
            utterance_str = line[2:]
            if utterance_diff_tag == "+":
                utterance = remote_utterances_map[utterance_str]
            else:
                utterance = local_utterances_map[utterance_str]
            utterance.diff_tag = utterance_diff_tag
            utterance_diffs.append(utterance)
        return cls(diff=utterance_diffs)

    def to_utterance_str(self) -> List[str]:
        """Outputs the utterance diff as a list of diff strings. Returns a list
        of strings showing what is different in the remote when compared to the
        local. For example:

        ["  route1: utterance1",
         "  route1: utterance2",
         "- route2: utterance3",
         "- route2: utterance4"]

        Tells us that the remote is missing "route2: utterance3" and "route2:
        utterance4", which do exist locally. If we see:

        ["  route1: utterance1",
         "  route1: utterance2",
         "+ route2: utterance3",
         "+ route2: utterance4"]

        This diff tells us that the remote has "route2: utterance3" and
        "route2: utterance4", which do not exist locally.
        """
        return [x.to_diff_str() for x in self.diff]

    def get_tag(self, diff_tag: str) -> List[Utterance]:
        """Get all utterances with a given diff tag.

        :param diff_tag: The diff tag to filter by. Must be one of "+", "-", or
            " ".
        :type diff_tag: str
        :return: A list of Utterance objects.
        :rtype: List[Utterance]
        :raises ValueError: If ``diff_tag`` is not one of the three tags.
        """
        if diff_tag not in ["+", "-", " "]:
            raise ValueError("diff_tag must be one of '+', '-', or ' '")
        return [x for x in self.diff if x.diff_tag == diff_tag]

    def get_sync_strategy(self, sync_mode: str) -> dict:
        """Generates the synchronization plan for local and remote
        instances.

        The plan has the shape
        ``{"remote": {"upsert": [...], "delete": [...]},
        "local": {"upsert": [...], "delete": [...]}}``.

        :param sync_mode: The mode to sync the routes with the remote index.
        :type sync_mode: str
        :return: A dictionary describing the synchronization strategy.
        :rtype: dict
        :raises ValueError: If ``sync_mode`` is not a ``SyncMode`` value, or
            if the mode is "error" and local and remote differ.
        """
        if sync_mode not in SyncMode.to_list():
            raise ValueError(f"sync_mode must be one of {SyncMode.to_list()}")
        local_only = self.get_tag("-")
        remote_only = self.get_tag("+")
        local_and_remote = self.get_tag(" ")
        if sync_mode == "error":
            if local_only or remote_only:
                raise ValueError(
                    "There are utterances that exist in the local or remote "
                    "instance that do not exist in the other instance. Please "
                    "sync the routes before running this command."
                )
            else:
                return {
                    "remote": {"upsert": [], "delete": []},
                    "local": {"upsert": [], "delete": []},
                }
        elif sync_mode == "local":
            # push local-only utterances, remove remote-only ones
            return {
                "remote": {"upsert": local_only, "delete": remote_only},
                "local": {"upsert": [], "delete": []},
            }
        elif sync_mode == "remote":
            # pull remote-only utterances, remove local-only ones
            return {
                "remote": {"upsert": [], "delete": []},
                "local": {"upsert": remote_only, "delete": local_only},
            }
        elif sync_mode == "merge-force-remote":
            # get set of route names that exist in both local and remote
            routes_in_both = {utt.route for utt in local_and_remote}
            # get remote utterances that belong to routes_in_both
            remote_to_keep = [
                utt for utt in remote_only if utt.route in routes_in_both
            ]
            # get remote utterances that do NOT belong to routes_in_both
            remote_to_delete = [
                utt for utt in remote_only if utt.route not in routes_in_both
            ]
            return {
                "remote": {"upsert": local_only, "delete": remote_to_delete},
                "local": {"upsert": remote_to_keep, "delete": []},
            }
        elif sync_mode == "merge-force-local":
            # get set of route names that exist in both local and remote
            routes_in_both = {utt.route for utt in local_and_remote}
            # get local utterances that belong to routes_in_both
            local_to_keep = [
                utt for utt in local_only if utt.route in routes_in_both
            ]
            # get local utterances that do NOT belong to routes_in_both
            local_to_delete = [
                utt for utt in local_only if utt.route not in routes_in_both
            ]
            return {
                "remote": {"upsert": local_to_keep, "delete": []},
                "local": {"upsert": remote_only, "delete": local_to_delete},
            }
        elif sync_mode == "merge":
            # keep everything from both sides
            return {
                "remote": {"upsert": local_only, "delete": []},
                "local": {"upsert": remote_only, "delete": []},
            }
class Metric(Enum): class Metric(Enum):
COSINE = "cosine" COSINE = "cosine"
DOTPRODUCT = "dotproduct" DOTPRODUCT = "dotproduct"
......
...@@ -238,6 +238,7 @@ class TestRouteLayer: ...@@ -238,6 +238,7 @@ class TestRouteLayer:
assert route_layer_openai.score_threshold == 0.3 assert route_layer_openai.score_threshold == 0.3
def test_delete_index(self, openai_encoder, routes, index_cls): def test_delete_index(self, openai_encoder, routes, index_cls):
# TODO merge .delete_index() and .delete_all() and get working
index = init_index(index_cls) index = init_index(index_cls)
route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index) route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index)
if index_cls is PineconeIndex: if index_cls is PineconeIndex:
......
...@@ -237,3 +237,101 @@ class TestRouteLayer: ...@@ -237,3 +237,101 @@ class TestRouteLayer:
assert "+ Route 2: Bye" in diff assert "+ Route 2: Bye" in diff
assert "+ Route 2: Goodbye" in diff assert "+ Route 2: Goodbye" in diff
assert " Route 2: Hi" in diff assert " Route 2: Hi" in diff
@pytest.mark.skipif(
    os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
)
def test_auto_sync_local(self, openai_encoder, routes, routes_2, routes_4, index_cls):
    """With auto_sync="local" the index should end up holding the local routes."""
    if index_cls is PineconeIndex:
        # TEST LOCAL
        pinecone_index = init_index(index_cls)
        route_layer = RouteLayer(
            encoder=openai_encoder, routes=routes_2, index=pinecone_index,
            auto_sync="local"
        )
        time.sleep(PINECONE_SLEEP)  # allow for index to be populated
        # get_utterances() yields Utterance objects, which never compare
        # equal to plain tuples, so compare their tuple form instead
        stored = [utt.to_tuple() for utt in route_layer.index.get_utterances()]
        assert stored == [
            ("Route 1", "Hello", None, {}),
            ("Route 2", "Hi", None, {}),
        ], "The routes in the index should match the local routes"
@pytest.mark.skipif(
    os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
)
def test_auto_sync_remote(self, openai_encoder, routes, index_cls):
    """Checks the utterances stored in the index after an auto_sync="remote" init."""
    if index_cls is PineconeIndex:
        # TEST REMOTE
        pinecone_index = init_index(index_cls)
        route_layer = RouteLayer(
            encoder=openai_encoder, routes=routes, index=pinecone_index,
            auto_sync="remote"
        )
        time.sleep(PINECONE_SLEEP)  # allow for index to be populated
        # get_utterances() yields Utterance objects, which never compare
        # equal to plain tuples, so compare their tuple form instead
        stored = [utt.to_tuple() for utt in route_layer.index.get_utterances()]
        assert stored == [
            ("Route 1", "Hello", None, {}),
            ("Route 2", "Hi", None, {}),
        ], "The routes in the index should match the local routes"
@pytest.mark.skipif(
    os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
)
def test_auto_sync_merge_force_remote(self, openai_encoder, routes, index_cls):
    """Checks the utterances stored in the index after an
    auto_sync="merge-force-remote" init.
    """
    if index_cls is PineconeIndex:
        # TEST MERGE FORCE REMOTE
        pinecone_index = init_index(index_cls)
        route_layer = RouteLayer(
            encoder=openai_encoder, routes=routes, index=pinecone_index,
            auto_sync="merge-force-remote"
        )
        time.sleep(PINECONE_SLEEP)  # allow for index to be populated
        # get_utterances() yields Utterance objects, which never compare
        # equal to plain tuples, so compare their tuple form instead
        stored = [utt.to_tuple() for utt in route_layer.index.get_utterances()]
        assert stored == [
            ("Route 1", "Hello", None, {}),
            ("Route 2", "Hi", None, {}),
        ], "The routes in the index should match the local routes"
@pytest.mark.skipif(
    os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
)
def test_auto_sync_merge_force_local(self, openai_encoder, routes, index_cls):
    """Checks the utterances stored in the index after an
    auto_sync="merge-force-local" init.
    """
    if index_cls is PineconeIndex:
        # TEST MERGE FORCE LOCAL
        pinecone_index = init_index(index_cls)
        route_layer = RouteLayer(
            encoder=openai_encoder, routes=routes, index=pinecone_index,
            auto_sync="merge-force-local"
        )
        time.sleep(PINECONE_SLEEP)  # allow for index to be populated
        # get_utterances() yields Utterance objects, which never compare
        # equal to plain tuples, so compare their tuple form instead
        stored = [utt.to_tuple() for utt in route_layer.index.get_utterances()]
        assert stored == [
            ("Route 1", "Hello", None, {"type": "default"}),
            ("Route 1", "Hi", None, {"type": "default"}),
            ("Route 2", "Bye", None, {}),
            ("Route 2", "Au revoir", None, {}),
            ("Route 2", "Goodbye", None, {}),
        ], "The routes in the index should match the local routes"
@pytest.mark.skipif(
    os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
)
def test_auto_sync_merge(self, openai_encoder, routes_4, index_cls):
    """Checks the utterances stored in the index after an auto_sync="merge" init."""
    if index_cls is PineconeIndex:
        # TEST MERGE
        pinecone_index = init_index(index_cls)
        route_layer = RouteLayer(
            encoder=openai_encoder, routes=routes_4, index=pinecone_index,
            auto_sync="merge"
        )
        time.sleep(PINECONE_SLEEP)  # allow for index to be populated
        # get_utterances() yields Utterance objects, which never compare
        # equal to plain tuples, so compare their tuple form instead
        stored = [utt.to_tuple() for utt in route_layer.index.get_utterances()]
        assert stored == [
            ("Route 1", "Hello", None, {"type": "default"}),
            ("Route 1", "Hi", None, {"type": "default"}),
            ("Route 1", "Goodbye", None, {"type": "default"}),
            ("Route 2", "Bye", None, {}),
            ("Route 2", "Asparagus", None, {}),
            ("Route 2", "Au revoir", None, {}),
            ("Route 2", "Goodbye", None, {}),
        ], "The routes in the index should match the local routes"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment