From 5013bb899df4dc84381cd8b599c0ca246b2a8698 Mon Sep 17 00:00:00 2001 From: jamescalam <james.briggs@hotmail.com> Date: Sat, 23 Nov 2024 23:15:47 +0100 Subject: [PATCH] feat: abstracted semantic routelayer --- docs/indexes/pinecone-sync-routes.ipynb | 424 +++++++++++++++++++----- semantic_router/routers/base.py | 34 +- semantic_router/routers/hybrid.py | 3 +- semantic_router/routers/semantic.py | 35 +- 4 files changed, 392 insertions(+), 104 deletions(-) diff --git a/docs/indexes/pinecone-sync-routes.ipynb b/docs/indexes/pinecone-sync-routes.ipynb index 14e78905..6cb9b4eb 100644 --- a/docs/indexes/pinecone-sync-routes.ipynb +++ b/docs/indexes/pinecone-sync-routes.ipynb @@ -6,7 +6,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install -qU \"semantic-router[pinecone]==0.1.0.dev1\"" + "!pip install -qU \"semantic-router[pinecone]==0.1.0.dev2\"" ] }, { @@ -30,9 +30,18 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/jamesbriggs/Library/Caches/pypoetry/virtualenvs/semantic-router-C1zr4a78-py3.12/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "from semantic_router import Route\n", "\n", @@ -68,7 +77,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -93,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -108,8 +117,6 @@ "pc_index = PineconeIndex(\n", " dimensions=1536,\n", " init_async_index=True, # enables asynchronous methods, it's optional\n", - " sync=None, # defines whether we sync between local and remote route layers\n", - " # when sync is None, no sync is performed\n", ")\n", "pc_index.index = pc_index._init_index(force_create=True)" ] @@ -130,14 +137,45 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ - "from semantic_router.layer import RouteLayer\n", + "encoder.score_threshold = None" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[33m2024-11-23 23:10:13 WARNING semantic_router.utils.logger TEMP | add:\n", + "chitchat: how are things going?\n", + "chitchat: how's the weather today?\n", + "chitchat: let's go to the chippy\n", + "chitchat: lovely weather today\n", + "chitchat: the weather is horrendous\n", + "politics: don't you just hate the president\n", + "politics: don't you just love the president\n", + "politics: isn't politics the best thing ever\n", + "politics: they will save the country!\n", + "politics: they're going to destroy this country!\n", + "politics: why don't you tell me about your political opinions\u001b[0m\n" + ] + } + ], + "source": [ + "from semantic_router.routers import RouteLayer\n", "import time\n", "\n", - "rl = RouteLayer(encoder=encoder, routes=routes, index=pc_index)\n", + "rl = RouteLayer(\n", + " encoder=encoder, routes=routes, index=pc_index,\n", + " auto_sync=\"local\"\n", + ")\n", "# due to pinecone indexing latency we wait 3 seconds\n", "time.sleep(3)" ] @@ -151,23 +189,16 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 6, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "hash_id: sr_hash#\n" - ] - }, { "data": { "text/plain": [ - "False" + "True" ] }, - "execution_count": 28, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -185,7 +216,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -204,23 +235,16 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 9, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "hash_id: sr_hash#\n" - ] - }, { "data": { "text/plain": [ "False" ] }, - "execution_count": 30, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -238,27 +262,26 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[\"- politics: don't you just hate the president\",\n", - " \"- politics: don't you just love the president\",\n", - " \"- politics: isn't politics the best thing ever\",\n", - " '- politics: they will save the country!',\n", - " \"- politics: they're going to destroy this country!\",\n", - " \"- politics: why don't you tell me about your political opinions\",\n", - " '+ Route 1: Hello',\n", - " '+ Route 1: Hi',\n", - " '+ Route 2: Au revoir',\n", - " '+ Route 2: Bye',\n", - " '+ Route 2: Goodbye',\n", - " '+ Route 3: Boo']" + "['+ chitchat: how are things going?',\n", + " \"+ chitchat: how's the weather today?\",\n", + " \"+ chitchat: let's go to the chippy\",\n", + " '+ chitchat: lovely weather today',\n", + " '+ chitchat: the weather is horrendous',\n", + " \" politics: don't you just hate the president\",\n", + " \" politics: don't you just love the president\",\n", + " \" politics: isn't politics the best thing ever\",\n", + " ' politics: they will save the country!',\n", + " \" politics: they're going to destroy this country!\",\n", + " \" politics: why don't you tell me about your political opinions\"]" ] }, - "execution_count": 31, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -283,9 +306,30 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "[Utterance(route='chitchat', utterance='how are things going?', function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='chitchat', utterance=\"how's the weather today?\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='chitchat', utterance='the weather is horrendous', function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='chitchat', utterance='lovely weather today', function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='chitchat', utterance=\"let's go to the chippy\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance=\"don't you just hate the president\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance=\"don't you just love the president\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance=\"they're going to destroy this country!\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance='they will save the country!', function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance=\"isn't politics the best thing ever\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance=\"why don't you tell me about your political opinions\", function_schemas=None, metadata={}, diff_tag=' ')]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "remote_utterances = rl.index.get_utterances()\n", "remote_utterances" @@ -293,9 +337,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "[Utterance(route='politics', utterance=\"isn't politics the best thing ever\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance=\"why don't you tell me about your political opinions\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance=\"don't you just love the president\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance=\"don't you just hate the president\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance=\"they're going to destroy this country!\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance='they will save the country!', function_schemas=None, metadata={}, diff_tag=' ')]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "local_utterances = rl.to_config().to_utterances()\n", "local_utterances" @@ -310,7 +370,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -330,9 +390,30 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "[Utterance(route='chitchat', utterance='how are things going?', function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance=\"how's the weather today?\", function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance=\"let's go to the chippy\", function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance='lovely weather today', function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance='the weather is horrendous', function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='politics', utterance=\"don't you just hate the president\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance=\"don't you just love the president\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance=\"isn't politics the best thing ever\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance='they will save the country!', function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance=\"they're going to destroy this country!\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance=\"why don't you tell me about your political opinions\", function_schemas=None, metadata={}, diff_tag=' ')]" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "diff.diff" ] @@ -354,9 +435,24 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "[Utterance(route='chitchat', utterance='how are things going?', function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance=\"how's the weather today?\", function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance=\"let's go to the chippy\", function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance='lovely weather today', function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance='the weather is horrendous', function_schemas=None, metadata={}, diff_tag='+')]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "diff.get_tag(\"+\")" ] @@ -370,9 +466,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "diff.get_tag(\"-\")" ] @@ -386,9 +493,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "[Utterance(route='politics', utterance=\"don't you just hate the president\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance=\"don't you just love the president\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance=\"isn't politics the best thing ever\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance='they will save the country!', function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance=\"they're going to destroy this country!\", function_schemas=None, metadata={}, diff_tag=' '),\n", + " Utterance(route='politics', utterance=\"why don't you tell me about your political opinions\", function_schemas=None, metadata={}, diff_tag=' ')]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "diff.get_tag(\" \")" ] @@ -423,45 +546,142 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'remote': {'upsert': [],\n", + " 'delete': [Utterance(route='chitchat', utterance='how are things going?', function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance=\"how's the weather today?\", function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance=\"let's go to the chippy\", function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance='lovely weather today', function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance='the weather is horrendous', function_schemas=None, metadata={}, diff_tag='+')]},\n", + " 'local': {'upsert': [], 'delete': []}}" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "diff.get_sync_strategy(\"local\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'remote': {'upsert': [], 'delete': []},\n", + " 'local': {'upsert': [Utterance(route='chitchat', utterance='how are things going?', function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance=\"how's the weather today?\", function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance=\"let's go to the chippy\", function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance='lovely weather today', function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance='the weather is horrendous', function_schemas=None, metadata={}, diff_tag='+')],\n", + " 'delete': []}}" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "diff.get_sync_strategy(\"remote\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'remote': {'upsert': [], 'delete': []},\n", + " 'local': {'upsert': [Utterance(route='chitchat', utterance='how are things going?', function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance=\"how's the weather today?\", function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance=\"let's go to the chippy\", function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance='lovely weather today', function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance='the weather is horrendous', function_schemas=None, metadata={}, diff_tag='+')],\n", + " 'delete': []}}" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "diff.get_sync_strategy(\"merge-force-remote\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-11-23 23:14:16 INFO semantic_router.utils.logger local_only_mapper: {}\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'remote': {'upsert': [],\n", + " 'delete': [Utterance(route='chitchat', utterance='how are things going?', function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance=\"how's the weather today?\", function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance=\"let's go to the chippy\", function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance='lovely weather today', function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance='the weather is horrendous', function_schemas=None, metadata={}, diff_tag='+')]},\n", + " 'local': {'upsert': [], 'delete': []}}" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "diff.get_sync_strategy(\"merge-force-local\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'remote': {'upsert': [Utterance(route='chitchat', utterance='how are things going?', function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance=\"how's the weather today?\", function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance=\"let's go to the chippy\", function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance='lovely weather today', function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance='the weather is horrendous', function_schemas=None, metadata={}, diff_tag='+')],\n", + " 'delete': []},\n", + " 'local': {'upsert': [Utterance(route='chitchat', utterance='how are things going?', function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance=\"how's the weather today?\", function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance=\"let's go to the chippy\", function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance='lovely weather today', function_schemas=None, metadata={}, diff_tag='+'),\n", + " Utterance(route='chitchat', utterance='the weather is horrendous', function_schemas=None, metadata={}, diff_tag='+')],\n", + " 'delete': []}}" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "diff.get_sync_strategy(\"merge\")" ] @@ -475,9 +695,18 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 23, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[33m2024-11-23 23:14:25 WARNING semantic_router.utils.logger TEMP | _remove_and_sync:\n", + "chitchat: ['how are things going?', \"how's the weather today?\", \"let's go to the chippy\", 'lovely weather today', 'the weather is horrendous']\u001b[0m\n" + ] + } + ], "source": [ "strategy = diff.get_sync_strategy(\"local\")\n", "rl._execute_sync_strategy(strategy=strategy)" @@ -485,9 +714,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 24, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "time.sleep(3)\n", "rl.is_synced()" @@ -502,9 +742,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 25, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "[\" politics: don't you just hate the president\",\n", + " \" politics: don't you just love the president\",\n", + " \" politics: isn't politics the best thing ever\",\n", + " ' politics: they will save the country!',\n", + " \" politics: they're going to destroy this country!\",\n", + " \" politics: why don't you tell me about your political opinions\"]" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "rl.get_utterance_diff()" ] @@ -520,9 +776,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 26, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "[\" politics: don't you just hate the president\",\n", + " \" politics: don't you just love the president\",\n", + " \" politics: isn't politics the best thing ever\",\n", + " ' politics: they will save the country!',\n", + " \" politics: they're going to destroy this country!\",\n", + " \" politics: why don't you tell me about your political opinions\"]" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "rl.sync(sync_mode=\"local\")" ] @@ -544,7 +816,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.12.7" } }, "nbformat": 4, diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py index 100663da..4fd7b58b 100644 --- a/semantic_router/routers/base.py +++ b/semantic_router/routers/base.py @@ -278,7 +278,7 @@ class LayerConfig: class BaseRouteLayer(BaseModel): encoder: BaseEncoder index: BaseIndex = Field(default_factory=BaseIndex) - score_threshold: Optional[float] = None + score_threshold: Optional[float] = Field(default=None) routes: List[Route] = [] llm: Optional[BaseLLM] = None top_k: int = 5 @@ -289,10 +289,6 @@ class BaseRouteLayer(BaseModel): class Config: arbitrary_types_allowed = True - @validator("score_threshold", pre=True, always=True) - def set_score_threshold(cls, v): - return float(v) if v is not None else None - @validator("index", pre=True, always=True) def set_index(cls, v): return v if v is not None else LocalIndex() @@ -326,12 +322,13 @@ class BaseRouteLayer(BaseModel): self.encoder = encoder self.llm = llm self.routes = routes if routes else [] - if self.encoder.score_threshold is None: - raise ValueError( - "No score threshold provided for encoder. Please set the score threshold " - "in the encoder config." - ) - self.score_threshold = self.encoder.score_threshold + if self.encoder.score_threshold is not None: + self.score_threshold = self.encoder.score_threshold + if self.score_threshold is None: + logger.warning( + "No score threshold value found in encoder. Using the default " + "'None' value can lead to unexpected results." + ) self.top_k = top_k if self.top_k < 1: raise ValueError(f"top_k needs to be >= 1, but was: {self.top_k}.") @@ -365,6 +362,21 @@ class BaseRouteLayer(BaseModel): sync_strategy = diff.get_sync_strategy(self.auto_sync) self._execute_sync_strategy(sync_strategy) + def _set_score_threshold(self): + """Set the score threshold for the layer based on the encoder + score threshold. + + When no score threshold is used a default `None` value + is used, which means that a route will always be returned when + the layer is called.""" + if self.encoder.score_threshold is not None: + self.score_threshold = self.encoder.score_threshold + if self.score_threshold is None: + logger.warning( + "No score threshold value found in encoder. Using the default " + "'None' value can lead to unexpected results." + ) + def check_for_matching_routes(self, top_class: str) -> Optional[Route]: matching_route = next( (route for route in self.routes if route.name == top_class), None diff --git a/semantic_router/routers/hybrid.py b/semantic_router/routers/hybrid.py index b65607e7..7ea3eddb 100644 --- a/semantic_router/routers/hybrid.py +++ b/semantic_router/routers/hybrid.py @@ -3,7 +3,6 @@ import asyncio from pydantic.v1 import validator, Field import numpy as np -from numpy.linalg import norm from semantic_router.encoders import ( BaseEncoder, @@ -69,6 +68,8 @@ class HybridRouteLayer(BaseRouteLayer): if routes: for route in routes: self.add(route) + # set score threshold using default method + self._set_score_threshold() # TODO: we can't really use this with hybrid... @validator("sparse_encoder", pre=True, always=True) def set_sparse_encoder(cls, v): diff --git a/semantic_router/routers/semantic.py b/semantic_router/routers/semantic.py index 499bdd73..2104d431 100644 --- a/semantic_router/routers/semantic.py +++ b/semantic_router/routers/semantic.py @@ -4,7 +4,7 @@ import os import random import hashlib from typing import Any, Dict, List, Optional, Tuple, Union -from pydantic.v1 import BaseModel +from pydantic.v1 import validator, BaseModel, Field import numpy as np import yaml # type: ignore @@ -16,6 +16,7 @@ from semantic_router.index.local import LocalIndex from semantic_router.index.pinecone import PineconeIndex from semantic_router.llms import BaseLLM, OpenAILLM from semantic_router.route import Route +from semantic_router.routers.base import BaseRouteLayer from semantic_router.schema import ( ConfigParameter, EncoderType, @@ -272,14 +273,12 @@ class LayerConfig: ) -class RouteLayer(BaseModel): - score_threshold: float - encoder: BaseEncoder - index: BaseIndex - llm: Optional[BaseLLM] = None - top_k: int = 5 - aggregation: str = "mean" - auto_sync: Optional[str] = None +class RouteLayer(BaseRouteLayer): + index: BaseIndex = Field(default_factory=LocalIndex) + + @validator("index", pre=True, always=True) + def set_index(cls, v): + return v if v is not None else LocalIndex() def __init__( self, @@ -291,7 +290,15 @@ class RouteLayer(BaseModel): aggregation: str = "mean", auto_sync: Optional[str] = None, ): - self.index: BaseIndex = index if index is not None else LocalIndex() + super().__init__( + encoder=encoder, + llm=llm, + routes=routes.copy() if routes else [], + index=index, + top_k=top_k, + aggregation=aggregation, + auto_sync=auto_sync, + ) if encoder is None: logger.warning( "No encoder provided. Using default OpenAIEncoder. Ensure " @@ -302,12 +309,8 @@ class RouteLayer(BaseModel): self.encoder = encoder self.llm = llm self.routes = routes if routes else [] - if self.encoder.score_threshold is None: - raise ValueError( - "No score threshold provided for encoder. Please set the score threshold " - "in the encoder config." - ) - self.score_threshold = self.encoder.score_threshold + # set score threshold using default method + self._set_score_threshold() self.top_k = top_k if self.top_k < 1: raise ValueError(f"top_k needs to be >= 1, but was: {self.top_k}.") -- GitLab