From 87d6ae4b160f66efbdaaabd0234129dc50c6a669 Mon Sep 17 00:00:00 2001 From: jamescalam <james.briggs@hotmail.com> Date: Tue, 26 Nov 2024 22:39:28 +0100 Subject: [PATCH] feat: new sparse embedding support and abstractions --- docs/examples/pinecone-hybrid.ipynb | 196 ++++++++++++++++------------ semantic_router/encoders/aurelio.py | 17 +-- semantic_router/index/pinecone.py | 27 ++-- semantic_router/routers/base.py | 92 ++++++++----- semantic_router/routers/hybrid.py | 47 ++----- semantic_router/routers/semantic.py | 100 +------------- semantic_router/schema.py | 38 +++++- 7 files changed, 245 insertions(+), 272 deletions(-) diff --git a/docs/examples/pinecone-hybrid.ipynb b/docs/examples/pinecone-hybrid.ipynb index f18eef6e..134b6e0d 100644 --- a/docs/examples/pinecone-hybrid.ipynb +++ b/docs/examples/pinecone-hybrid.ipynb @@ -174,9 +174,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-11-24 19:41:05 - pinecone_plugin_interface.logging - INFO - discover_namespace_packages.py:12 - discover_subpackages() - Discovering subpackages in _NamespacePath(['/Users/jamesbriggs/Library/Caches/pypoetry/virtualenvs/semantic-router-C1zr4a78-py3.12/lib/python3.12/site-packages/pinecone_plugins'])\n", - "2024-11-24 19:41:05 - pinecone_plugin_interface.logging - INFO - discover_plugins.py:9 - discover_plugins() - Looking for plugins in pinecone_plugins.inference\n", - "2024-11-24 19:41:05 - pinecone_plugin_interface.logging - INFO - installation.py:10 - install_plugins() - Installing plugin inference into Pinecone\n" + "2024-11-26 22:34:54 - pinecone_plugin_interface.logging - INFO - discover_namespace_packages.py:12 - discover_subpackages() - Discovering subpackages in _NamespacePath(['/Users/jamesbriggs/Library/Caches/pypoetry/virtualenvs/semantic-router-C1zr4a78-py3.12/lib/python3.12/site-packages/pinecone_plugins'])\n", + "2024-11-26 22:34:54 - pinecone_plugin_interface.logging - INFO - discover_plugins.py:9 - discover_plugins() - Looking for plugins in pinecone_plugins.inference\n", + "2024-11-26 22:34:54 - pinecone_plugin_interface.logging - INFO - installation.py:10 - install_plugins() - Installing plugin inference into Pinecone\n" ] } ], @@ -205,29 +205,7 @@ "cell_type": "code", "execution_count": 7, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-11-24 19:41:15 - httpx - INFO - _client.py:1013 - _send_single_request() - HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n", - "2024-11-24 19:41:17 - semantic_router.utils.logger - WARNING - pinecone.py:247 - add() - TEMP | add:\n", - "politics: isn't politics the best thing ever\n", - "politics: why don't you tell me about your political opinions\n", - "politics: don't you just love the president\n", - "politics: don't you just hate the president\n", - "politics: they're going to destroy this country!\n", - "politics: they will save the country!\n", - "2024-11-24 19:41:17 - httpx - INFO - _client.py:1013 - _send_single_request() - HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n", - "2024-11-24 19:41:18 - semantic_router.utils.logger - WARNING - pinecone.py:247 - add() - TEMP | add:\n", - "chitchat: how's the weather today?\n", - "chitchat: how are things going?\n", - "chitchat: lovely weather today\n", - "chitchat: the weather is horrendous\n", - "chitchat: let's go to the chippy\n" - ] - } - ], + "outputs": [], "source": [ "from semantic_router.routers import HybridRouter\n", "\n", @@ -248,23 +226,16 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-11-24 19:42:06 - semantic_router.utils.logger - WARNING - pinecone.py:424 - _read_hash() - Configuration for hash parameter not found in index.\n" - ] - }, { "data": { "text/plain": [ - "False" + "True" ] }, - "execution_count": 9, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -282,26 +253,26 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "['- chitchat: how are things going?',\n", - " \"- chitchat: how's the weather today?\",\n", - " \"- chitchat: let's go to the chippy\",\n", - " '- chitchat: lovely weather today',\n", - " '- chitchat: the weather is horrendous',\n", - " \"- politics: don't you just hate the president\",\n", - " \"- politics: don't you just love the president\",\n", - " \"- politics: isn't politics the best thing ever\",\n", - " '- politics: they will save the country!',\n", - " \"- politics: they're going to destroy this country!\",\n", - " \"- politics: why don't you tell me about your political opinions\"]" + "[' chitchat: how are things going?',\n", + " \" chitchat: how's the weather today?\",\n", + " \" chitchat: let's go to the chippy\",\n", + " ' chitchat: lovely weather today',\n", + " ' chitchat: the weather is horrendous',\n", + " \" politics: don't you just hate the president\",\n", + " \" politics: don't you just love the president\",\n", + " \" politics: isn't politics the best thing ever\",\n", + " ' politics: they will save the country!',\n", + " \" politics: they're going to destroy this country!\",\n", + " \" politics: why don't you tell me about your political opinions\"]" ] }, - "execution_count": 10, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -319,16 +290,26 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[]" + "[Utterance(route='chitchat', utterance='how are things going?', function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='chitchat', utterance=\"how's the weather today?\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='chitchat', utterance='the weather is horrendous', function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='chitchat', utterance='lovely weather today', function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='chitchat', utterance=\"let's go to the chippy\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance=\"don't you just hate the president\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance=\"don't you just love the president\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance=\"they're going to destroy this country!\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance='they will save the country!', function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance=\"isn't politics the best thing ever\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance=\"why don't you tell me about your political opinions\", function_schemas=None, metadata={}, diff_tag=' ')]" ] }, - "execution_count": 12, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -346,31 +327,9 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 11, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-11-24 19:48:29 - httpx - INFO - _client.py:1013 - _send_single_request() - HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n", - "2024-11-24 19:48:31 - semantic_router.utils.logger - WARNING - pinecone.py:247 - add() - TEMP | add:\n", - "politics: isn't politics the best thing ever\n", - "politics: why don't you tell me about your political opinions\n", - "politics: don't you just love the president\n", - "politics: don't you just hate the president\n", - "politics: they're going to destroy this country!\n", - "politics: they will save the country!\n", - "2024-11-24 19:48:31 - httpx - INFO - _client.py:1013 - _send_single_request() - HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n", - "2024-11-24 19:48:32 - semantic_router.utils.logger - WARNING - pinecone.py:247 - add() - TEMP | add:\n", - "chitchat: how's the weather today?\n", - "chitchat: how are things going?\n", - "chitchat: lovely weather today\n", - "chitchat: the weather is horrendous\n", - "chitchat: let's go to the chippy\n" - ] - } - ], + "outputs": [], "source": [ "router = HybridRouter(\n", " encoder=encoder,\n", @@ -390,16 +349,16 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "False" + "True" ] }, - "execution_count": 16, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -408,6 +367,36 @@ "router.is_synced()" ] }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[' chitchat: how are things going?',\n", + " \" chitchat: how's the weather today?\",\n", + " \" chitchat: let's go to the chippy\",\n", + " ' chitchat: lovely weather today',\n", + " ' chitchat: the weather is horrendous',\n", + " \" politics: don't you just hate the president\",\n", + " \" politics: don't you just love the president\",\n", + " \" politics: isn't politics the best thing ever\",\n", + " ' politics: they will save the country!',\n", + " \" politics: they're going to destroy this country!\",\n", + " \" politics: why don't you tell me about your political opinions\"]" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "router.get_utterance_diff()" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -417,9 +406,54 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-11-26 22:35:56 - httpx - INFO - _client.py:1013 - _send_single_request() - HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "data": { + "text/plain": [ + "RouteChoice(name=None, function_call=None, similarity_score=None)" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "router(\"it's raining cats and dogs today\")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-11-26 22:35:20 - httpx - INFO - _client.py:1013 - _send_single_request() - HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "data": { + "text/plain": [ + "RouteChoice(name=None, function_call=None, similarity_score=None)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "router(\"I'm interested in learning about llama 2\")" ] diff --git a/semantic_router/encoders/aurelio.py b/semantic_router/encoders/aurelio.py index c257b514..bc150e50 100644 --- a/semantic_router/encoders/aurelio.py +++ b/semantic_router/encoders/aurelio.py @@ -5,6 +5,7 @@ from pydantic.v1 import Field from aurelio_sdk import AurelioClient, AsyncAurelioClient, EmbeddingResponse from semantic_router.encoders.base import BaseEncoder +from semantic_router.schema import SparseEmbedding class AurelioSparseEncoder(BaseEncoder): @@ -28,19 +29,15 @@ class AurelioSparseEncoder(BaseEncoder): self.client = AurelioClient(api_key=api_key) self.async_client = AsyncAurelioClient(api_key=api_key) - def __call__(self, docs: list[str]) -> list[dict[int, float]]: + def __call__(self, docs: list[str]) -> list[SparseEmbedding]: res: EmbeddingResponse = self.client.embedding(input=docs, model=self.name) - embeds = [r.embedding.model_dump() for r in res.data] - # convert sparse vector to {index: value} format - sparse_dicts = [{i: v for i, v in zip(e["indices"], e["values"])} for e in embeds] - return sparse_dicts + embeds = [SparseEmbedding.from_aurelio(r.embedding) for r in res.data] + return embeds - async def acall(self, docs: list[str]) -> list[dict[int, float]]: + async def acall(self, docs: list[str]) -> list[SparseEmbedding]: res: EmbeddingResponse = await self.async_client.embedding(input=docs, model=self.name) - embeds = [r.embedding.model_dump() for r in res.data] - # convert sparse vector to {index: value} format - sparse_dicts = [{i: v for i, v in zip(e["indices"], e["values"])} for e in embeds] - return sparse_dicts + embeds = [SparseEmbedding.from_aurelio(r.embedding) for r in res.data] + return embeds def fit(self, docs: List[str]): raise NotImplementedError("AurelioSparseEncoder does not support fit.") diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 2d432d33..5eb0aecb 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -11,7 +11,7 @@ import numpy as np from pydantic.v1 import BaseModel, Field from semantic_router.index.base import BaseIndex -from semantic_router.schema import ConfigParameter +from semantic_router.schema import ConfigParameter, SparseEmbedding from semantic_router.utils.logger import logger @@ -243,8 +243,6 @@ class PineconeIndex(BaseIndex): sparse_embeddings: Optional[List[dict[int, float]]] = None, ): """Add vectors to Pinecone in batches.""" - temp = "\n".join([f"{x[0]}: {x[1]}" for x in zip(routes, utterances)]) - logger.warning("TEMP | add:\n" + temp) if self.index is None: self.dimensions = self.dimensions or len(embeddings[0]) self.index = self._init_index(force_create=True) @@ -272,10 +270,6 @@ class PineconeIndex(BaseIndex): self._batch_upsert(batch) def _remove_and_sync(self, routes_to_delete: dict): - temp = "\n".join( - [f"{route}: {utterances}" for route, utterances in routes_to_delete.items()] - ) - logger.warning("TEMP | _remove_and_sync:\n" + temp) for route, utterances in routes_to_delete.items(): remote_routes = self._get_routes_with_ids(route_name=route) ids_to_delete = [ @@ -364,6 +358,7 @@ class PineconeIndex(BaseIndex): vector: np.ndarray, top_k: int = 5, route_filter: Optional[List[str]] = None, + sparse_vector: dict[int, float] | SparseEmbedding | None = None, **kwargs: Any, ) -> Tuple[np.ndarray, List[str]]: """Search the index for the query vector and return the top_k results. @@ -374,10 +369,10 @@ class PineconeIndex(BaseIndex): :type top_k: int, optional :param route_filter: A list of route names to filter the search results, defaults to None. :type route_filter: Optional[List[str]], optional + :param sparse_vector: An optional sparse vector to include in the query. + :type sparse_vector: Optional[SparseEmbedding] :param kwargs: Additional keyword arguments for the query, including sparse_vector. :type kwargs: Any - :keyword sparse_vector: An optional sparse vector to include in the query. - :type sparse_vector: Optional[dict] :return: A tuple containing an array of scores and a list of route names. :rtype: Tuple[np.ndarray, List[str]] :raises ValueError: If the index is not populated. @@ -389,9 +384,13 @@ class PineconeIndex(BaseIndex): filter_query = {"sr_route": {"$in": route_filter}} else: filter_query = None + if sparse_vector is not None: + if isinstance(sparse_vector, dict): + sparse_vector = SparseEmbedding.from_dict(sparse_vector) + sparse_vector = sparse_vector.to_pinecone() results = self.index.query( vector=[query_vector_list], - sparse_vector=kwargs.get("sparse_vector", None), + sparse_vector=sparse_vector, top_k=top_k, filter=filter_query, include_metadata=True, @@ -653,6 +652,8 @@ class PineconeIndex(BaseIndex): ) def __len__(self): - return self.index.describe_index_stats()["namespaces"][self.namespace][ - "vector_count" - ] + namespace_stats = self.index.describe_index_stats()["namespaces"].get(self.namespace) + if namespace_stats: + return namespace_stats["vector_count"] + else: + return 0 diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py index f9170051..f7ab5226 100644 --- a/semantic_router/routers/base.py +++ b/semantic_router/routers/base.py @@ -250,8 +250,39 @@ class RouterConfig: return utterances def add(self, route: Route): + """Add a route to the local SemanticRouter and index. + + :param route: The route to add. + :type route: Route + """ + current_local_hash = self._get_hash() + current_remote_hash = self.index._read_hash() + if current_remote_hash.value == "": + # if remote hash is empty, the index is to be initialized + current_remote_hash = current_local_hash + embedded_utterances = self.encoder(route.utterances) + self.index.add( + embeddings=embedded_utterances, + routes=[route.name] * len(route.utterances), + utterances=route.utterances, + function_schemas=( + route.function_schemas * len(route.utterances) + if route.function_schemas + else [{}] * len(route.utterances) + ), + metadata_list=[route.metadata if route.metadata else {}] + * len(route.utterances), + ) + self.routes.append(route) - logger.info(f"Added route `{route.name}`") + if current_local_hash.value == current_remote_hash.value: + self._write_hash() # update current hash in index + else: + logger.warning( + "Local and remote route layers were not aligned. Remote hash " + "not updated. Use `SemanticRouter.get_utterance_diff()` to see " + "details." + ) def get(self, name: str) -> Optional[Route]: for route in self.routes: @@ -289,10 +320,6 @@ class BaseRouter(BaseModel): class Config: arbitrary_types_allowed = True - @validator("index", pre=True, always=True) - def set_index(cls, v): - return v if v is not None else LocalIndex() - def __init__( self, encoder: Optional[BaseEncoder] = None, @@ -321,14 +348,11 @@ class BaseRouter(BaseModel): else: self.encoder = encoder self.llm = llm - self.routes = routes if routes else [] - if self.encoder.score_threshold is not None: - self.score_threshold = self.encoder.score_threshold - if self.score_threshold is None: - logger.warning( - "No score threshold value found in encoder. Using the default " - "'None' value can lead to unexpected results." - ) + self.routes = routes.copy() if routes else [] + # initialize index + self._set_index(index=index) + # set score threshold using default method + self._set_score_threshold() self.top_k = top_k if self.top_k < 1: raise ValueError(f"top_k needs to be >= 1, but was: {self.top_k}.") @@ -344,15 +368,29 @@ class BaseRouter(BaseModel): for route in self.routes: if route.score_threshold is None: route.score_threshold = self.score_threshold - # if routes list has been passed, we initialize index now + # run initialize index now if auto sync is active + if self.auto_sync: + self._init_index_state() + + def _set_index(self, index: Optional[BaseIndex]): + if index is None: + logger.warning("No index provided. Using default LocalIndex.") + self.index = LocalIndex() + else: + self.index = index + + def _init_index_state(self): + """Initializes an index (where required) and runs auto_sync if active. + """ + # initialize index now, check if we need dimensions + if self.index.dimensions is None: + dims = len(self.encoder(["test"])[0]) + self.index.dimensions = dims + # now init index + if isinstance(self.index, PineconeIndex): + self.index.index = self.index._init_index(force_create=True) + # run auto sync if active if self.auto_sync: - # initialize index now, check if we need dimensions - if self.index.dimensions is None: - dims = len(self.encoder(["test"])[0]) - self.index.dimensions = dims - # now init index - if isinstance(self.index, PineconeIndex): - self.index.index = self.index._init_index(force_create=True) local_utterances = self.to_config().to_utterances() remote_utterances = self.index.get_utterances() diff = UtteranceDiff.from_utterances( @@ -585,8 +623,6 @@ class BaseRouter(BaseModel): new_routes[utt_obj.route].utterances.append(utt_obj.utterance) new_routes[utt_obj.route].function_schemas = utt_obj.function_schemas new_routes[utt_obj.route].metadata = utt_obj.metadata - temp = "\n".join([f"{name}: {r.utterances}" for name, r in new_routes.items()]) - logger.warning("TEMP | _local_upsert:\n" + temp) self.routes = list(new_routes.values()) def _local_delete(self, utterances: List[Utterance]): @@ -599,8 +635,6 @@ class BaseRouter(BaseModel): route_dict: dict[str, List[str]] = {} for utt in utterances: route_dict.setdefault(utt.route, []).append(utt.utterance) - temp = "\n".join([f"{r}: {u}" for r, u in route_dict.items()]) - logger.warning("TEMP | _local_delete:\n" + temp) # iterate over current routes and delete specific utterance if found new_routes = [] for route in self.routes: @@ -622,17 +656,9 @@ class BaseRouter(BaseModel): metadata=route.metadata, ) ) - logger.warning( - f"TEMP | _local_delete OLD | {route.name}: {route.utterances}" - ) - logger.warning( - f"TEMP | _local_delete NEW | {route.name}: {new_routes[-1].utterances}" - ) else: # the route is not in the route_dict, so we keep it as is new_routes.append(route) - temp = "\n".join([f"{r}: {u}" for r, u in route_dict.items()]) - logger.warning("TEMP | _local_delete:\n" + temp) self.routes = new_routes diff --git a/semantic_router/routers/hybrid.py b/semantic_router/routers/hybrid.py index 0b9271a7..a0a96671 100644 --- a/semantic_router/routers/hybrid.py +++ b/semantic_router/routers/hybrid.py @@ -11,7 +11,7 @@ from semantic_router.encoders import ( ) from semantic_router.route import Route from semantic_router.index.hybrid_local import HybridLocalIndex -from semantic_router.schema import RouteChoice +from semantic_router.schema import RouteChoice, SparseEmbedding from semantic_router.utils.logger import logger from semantic_router.routers.base import BaseRouter from semantic_router.llms import BaseLLM @@ -36,10 +36,13 @@ class HybridRouter(BaseRouter): auto_sync: Optional[str] = None, alpha: float = 0.3, ): + if index is None: + logger.warning("No index provided. Using default HybridLocalIndex.") + index = HybridLocalIndex() super().__init__( encoder=encoder, llm=llm, - #routes=routes.copy(), + routes=routes, index=index, top_k=top_k, aggregation=aggregation, @@ -49,28 +52,14 @@ class HybridRouter(BaseRouter): self._set_sparse_encoder(sparse_encoder=sparse_encoder) # set alpha self.alpha = alpha - # create copy of routes - routes_copy = routes.copy() # fit sparse encoder if needed if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr( self.sparse_encoder, "fit" ): - self.sparse_encoder.fit(routes_copy) - # initialize index if not provided - self._set_index(index=index) - # add routes if we have them - if routes_copy: - for route in routes_copy: - self.add(route) - # set score threshold using default method - self._set_score_threshold() # TODO: we can't really use this with hybrid... - - def _set_index(self, index: Optional[HybridLocalIndex]): - if index is None: - logger.warning("No index provided. Using default HybridLocalIndex.") - self.index = HybridLocalIndex() - else: - self.index = index + self.sparse_encoder.fit(self.routes) + # run initialize index now if auto sync is active + if self.auto_sync: + self._init_index_state() def _set_sparse_encoder(self, sparse_encoder: Optional[BaseEncoder]): if sparse_encoder is None: @@ -121,7 +110,7 @@ class HybridRouter(BaseRouter): vector: Optional[List[float]] = None, simulate_static: bool = False, route_filter: Optional[List[str]] = None, - sparse_vector: Optional[dict[int, float]] = None, + sparse_vector: dict[int, float] | SparseEmbedding | None = None, ) -> RouteChoice: # if no vector provided, encode text to get vector if vector is None: @@ -148,22 +137,6 @@ class HybridRouter(BaseRouter): else: return RouteChoice() - def add(self, route: Route): - self.routes += [route] - - route_names = [route.name] * len(route.utterances) - - # create embeddings for all routes - dense_embeds, sparse_embeds = self._encode(route.utterances) - self.index.add( - embeddings=dense_embeds, - sparse_embeddings=sparse_embeds, - routes=route_names, # TODO: aligning names of routes v route_names - utterances=route.utterances, - ) - # TODO: in some places we say vector, sparse_vector and in others - # TODO: we say embeddings, sparse_embeddings - def _convex_scaling(self, dense: np.ndarray, sparse: list[dict[int, float]]): # scale sparse and dense vecs scaled_dense = np.array(dense) * self.alpha diff --git a/semantic_router/routers/semantic.py b/semantic_router/routers/semantic.py index 45493e67..f0df2717 100644 --- a/semantic_router/routers/semantic.py +++ b/semantic_router/routers/semantic.py @@ -53,12 +53,6 @@ def is_valid(layer_config: str) -> bool: class SemanticRouter(BaseRouter): - index: BaseIndex = Field(default_factory=LocalIndex) - - @validator("index", pre=True, always=True) - def set_index(cls, v): - return v if v is not None else LocalIndex() - def __init__( self, encoder: Optional[BaseEncoder] = None, @@ -72,56 +66,15 @@ class SemanticRouter(BaseRouter): super().__init__( encoder=encoder, llm=llm, - routes=routes.copy() if routes else [], + routes=routes if routes else [], index=index, top_k=top_k, aggregation=aggregation, auto_sync=auto_sync, ) - if encoder is None: - logger.warning( - "No encoder provided. Using default OpenAIEncoder. Ensure " - "that you have set OPENAI_API_KEY in your environment." - ) - self.encoder = OpenAIEncoder() - else: - self.encoder = encoder - self.llm = llm - self.routes = routes if routes else [] - # set score threshold using default method - self._set_score_threshold() - self.top_k = top_k - if self.top_k < 1: - raise ValueError(f"top_k needs to be >= 1, but was: {self.top_k}.") - self.aggregation = aggregation - if self.aggregation not in ["sum", "mean", "max"]: - raise ValueError( - f"Unsupported aggregation method chosen: {aggregation}. Choose either 'SUM', 'MEAN', or 'MAX'." - ) - self.aggregation_method = self._set_aggregation_method(self.aggregation) - self.auto_sync = auto_sync - - # set route score thresholds if not already set - for route in self.routes: - if route.score_threshold is None: - route.score_threshold = self.score_threshold - # if routes list has been passed, we initialize index now + # run initialize index now if auto sync is active if self.auto_sync: - # initialize index now, check if we need dimensions - if self.index.dimensions is None: - dims = len(self.encoder(["test"])[0]) - self.index.dimensions = dims - # now init index - if isinstance(self.index, PineconeIndex): - self.index.index = self.index._init_index(force_create=True) - local_utterances = self.to_config().to_utterances() - remote_utterances = self.index.get_utterances() - diff = UtteranceDiff.from_utterances( - local_utterances=local_utterances, - remote_utterances=remote_utterances, - ) - sync_strategy = diff.get_sync_strategy(self.auto_sync) - self._execute_sync_strategy(sync_strategy) + self._init_index_state() def check_for_matching_routes(self, top_class: str) -> Optional[Route]: matching_route = next( @@ -331,8 +284,6 @@ class SemanticRouter(BaseRouter): new_routes[utt_obj.route].utterances.append(utt_obj.utterance) new_routes[utt_obj.route].function_schemas = utt_obj.function_schemas new_routes[utt_obj.route].metadata = utt_obj.metadata - temp = "\n".join([f"{name}: {r.utterances}" for name, r in new_routes.items()]) - logger.warning("TEMP | _local_upsert:\n" + temp) self.routes = list(new_routes.values()) def _local_delete(self, utterances: List[Utterance]): @@ -345,8 +296,6 @@ class SemanticRouter(BaseRouter): route_dict: dict[str, List[str]] = {} for utt in utterances: route_dict.setdefault(utt.route, []).append(utt.utterance) - temp = "\n".join([f"{r}: {u}" for r, u in route_dict.items()]) - logger.warning("TEMP | _local_delete:\n" + temp) # iterate over current routes and delete specific utterance if found new_routes = [] for route in self.routes: @@ -368,17 +317,9 @@ class SemanticRouter(BaseRouter): metadata=route.metadata, ) ) - logger.warning( - f"TEMP | _local_delete OLD | {route.name}: {route.utterances}" - ) - logger.warning( - f"TEMP | _local_delete NEW | {route.name}: {new_routes[-1].utterances}" - ) else: # the route is not in the route_dict, so we keep it as is new_routes.append(route) - temp = "\n".join([f"{r}: {u}" for r, u in route_dict.items()]) - logger.warning("TEMP | _local_delete:\n" + temp) self.routes = new_routes @@ -449,41 +390,6 @@ class SemanticRouter(BaseRouter): encoder = AutoEncoder(type=config.encoder_type, name=config.encoder_name).model return cls(encoder=encoder, routes=config.routes, index=index) - def add(self, route: Route): - """Add a route to the local SemanticRouter and index. - - :param route: The route to add. - :type route: Route - """ - current_local_hash = self._get_hash() - current_remote_hash = self.index._read_hash() - if current_remote_hash.value == "": - # if remote hash is empty, the index is to be initialized - current_remote_hash = current_local_hash - embedded_utterances = self.encoder(route.utterances) - self.index.add( - embeddings=embedded_utterances, - routes=[route.name] * len(route.utterances), - utterances=route.utterances, - function_schemas=( - route.function_schemas * len(route.utterances) - if route.function_schemas - else [{}] * len(route.utterances) - ), - metadata_list=[route.metadata if route.metadata else {}] - * len(route.utterances), - ) - - self.routes.append(route) - if current_local_hash.value == current_remote_hash.value: - self._write_hash() # update current hash in index - else: - logger.warning( - "Local and remote route layers were not aligned. Remote hash " - "not updated. Use `SemanticRouter.get_utterance_diff()` to see " - "details." - ) - def list_route_names(self) -> List[str]: return [route.name for route in self.routes] diff --git a/semantic_router/schema.py b/semantic_router/schema.py index eea86b2e..2d00572f 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -4,7 +4,7 @@ from enum import Enum from typing import List, Optional, Union, Any, Dict, Tuple from pydantic.v1 import BaseModel, Field from semantic_router.utils.logger import logger - +from aurelio_sdk.schema import BM25Embedding class EncoderType(Enum): AURELIO = "aurelio" @@ -404,3 +404,39 @@ class Metric(Enum): DOTPRODUCT = "dotproduct" EUCLIDEAN = "euclidean" MANHATTAN = "manhattan" + + +class SparseValue(BaseModel): + index: int + value: float + + +class SparseEmbedding(BaseModel): + embedding: List[SparseValue] + + def to_dict(self): + return {x.index: x.value for x in self.embedding} + + def to_pinecone(self): + return { + "indices": [x.index for x in self.embedding], + "values": [x.value for x in self.embedding], + } + + @classmethod + def from_dict(cls, sparse_dict: dict): + return cls(embedding=[SparseValue(index=i, value=v) for i, v in sparse_dict.items()]) + + @classmethod + def from_aurelio(cls, embedding: BM25Embedding): + return cls(embedding=[ + SparseValue( + index=i, + value=v + ) for i, v in zip(embedding.indices, embedding.values) + ]) + + # dictionary interface + def items(self): + return [(x.index, x.value) for x in self.embedding] + -- GitLab