Commit e365910a authored by jamescalam

feat: new route layer naming convention

parent c4ed95f3
%% Cell type:code id: tags:
``` python
!pip install -qU "semantic-router[pinecone]==0.1.0.dev2"
```
%% Cell type:markdown id: tags:
# Syncing Routes with Pinecone Index
%% Cell type:markdown id: tags:
When using the `PineconeIndex`, our `RouteLayer` is stored in two places:

* We keep route layer metadata locally.
* Vectors, alongside a backup of our metadata, are stored remotely in Pinecone.

By storing some data locally and some remotely we achieve improved persistence and the ability to recover our local state if lost. However, this does come with challenges around keeping our local and remote instances synchronized. Fortunately, we have [several synchronization options](https://docs.aurelio.ai/semantic-router/route_layer/sync.html). In this example, we'll see how to use these options to keep our local and remote Pinecone instances synchronized.
%% Cell type:code id: tags:
``` python
from semantic_router import Route

# we could use this as a guide for our chatbot to avoid political conversations
politics = Route(
    name="politics",
    utterances=[
        "isn't politics the best thing ever",
        "why don't you tell me about your political opinions",
        "don't you just love the president",
        "don't you just hate the president",
        "they're going to destroy this country!",
        "they will save the country!",
    ],
)

# this could be used as an indicator to our chatbot to switch to a more
# conversational prompt
chitchat = Route(
    name="chitchat",
    utterances=[
        "how's the weather today?",
        "how are things going?",
        "lovely weather today",
        "the weather is horrendous",
        "let's go to the chippy",
    ],
)

# we place both of our decisions together into a single list
routes = [politics, chitchat]
```
%% Output
/Users/jamesbriggs/Library/Caches/pypoetry/virtualenvs/semantic-router-C1zr4a78-py3.12/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
%% Cell type:code id: tags:
``` python
import os
from getpass import getpass

from semantic_router.encoders import OpenAIEncoder

# get at platform.openai.com
os.environ["OPENAI_API_KEY"] = os.environ.get("OPENAI_API_KEY") or getpass(
    "Enter OpenAI API key: "
)
encoder = OpenAIEncoder(name="text-embedding-3-small")
```
%% Cell type:markdown id: tags:
For our `PineconeIndex` we do the exact same thing, i.e., we initialize as usual:
%% Cell type:code id: tags:
``` python
import os

from semantic_router.index.pinecone import PineconeIndex

# get at app.pinecone.io
os.environ["PINECONE_API_KEY"] = os.environ.get("PINECONE_API_KEY") or getpass(
    "Enter Pinecone API key: "
)

pc_index = PineconeIndex(
    dimensions=1536,
    init_async_index=True,  # enables asynchronous methods, it's optional
)
pc_index.index = pc_index._init_index(force_create=True)
```
%% Cell type:markdown id: tags:
## RouteLayer
%% Cell type:markdown id: tags:
The `RouteLayer` class supports both sync and async operations by default, so we initialize as usual (a brief async sketch follows the initialization below):
%% Cell type:code id: tags:
``` python
from semantic_router.routers import RouteLayer
import time

rl = RouteLayer(encoder=encoder, routes=routes, index=pc_index, auto_sync="local")

# due to pinecone indexing latency we wait 3 seconds
time.sleep(3)
```
%% Output
2024-11-23 23:46:42 WARNING semantic_router.utils.logger TEMP | add:
chitchat: how are things going?
chitchat: how's the weather today?
chitchat: let's go to the chippy
chitchat: lovely weather today
chitchat: the weather is horrendous
2024-11-23 23:46:50 WARNING semantic_router.utils.logger TEMP | add:
chitchat: how are things going?
chitchat: how's the weather today?
chitchat: let's go to the chippy
chitchat: lovely weather today
chitchat: the weather is horrendous
politics: don't you just hate the president
politics: don't you just love the president
politics: isn't politics the best thing ever
politics: they will save the country!
politics: they're going to destroy this country!
politics: why don't you tell me about your political opinions
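%% Cell type:markdown id: tags:
As a brief aside, because we passed `init_async_index=True` to our `PineconeIndex`, the route layer can also be queried asynchronously. The cell below is a minimal sketch only; it assumes the async `acall` method is available on `RouteLayer` in this version of the library.
%% Cell type:code id: tags:
``` python
# a sketch of async usage; assumes `rl.acall` exists in this version of semantic-router
async def aclassify(text: str):
    # asynchronous counterpart of calling rl(text) directly
    return await rl.acall(text)

# in a notebook cell we could simply run:
# await aclassify("how's the weather today?")
```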
%% Cell type:markdown id: tags:
Let's see if our local and remote instances are synchronized...
%% Cell type:code id: tags:
``` python
rl.is_synced()
```
%% Output
True
%% Cell type:markdown id: tags:
It looks like everything is synced! Let's try deleting our local route layer, initializing it with just the politics route, and checking again.
%% Cell type:code id: tags:
``` python
del rl
rl = RouteLayer(encoder=encoder, routes=[politics], index=pc_index)
time.sleep(3)
```
%% Cell type:markdown id: tags:
Let's try `rl.is_synced()` again:
%% Cell type:code id: tags:
``` python
rl.is_synced()
```
%% Output
False
%% Cell type:markdown id: tags:
We can use the `get_utterance_diff` method to see exactly _why_ our local and remote are not synced.
%% Cell type:code id: tags:
``` python
rl.get_utterance_diff()
```
%% Output
['+ chitchat: how are things going?',
 "+ chitchat: how's the weather today?",
 "+ chitchat: let's go to the chippy",
 '+ chitchat: lovely weather today',
 '+ chitchat: the weather is horrendous',
 " politics: don't you just hate the president",
 " politics: don't you just love the president",
 " politics: isn't politics the best thing ever",
 ' politics: they will save the country!',
 " politics: they're going to destroy this country!",
 " politics: why don't you tell me about your political opinions"]
%% Cell type:markdown id: tags:
## Handling Synchronization
%% Cell type:markdown id: tags:
We may want to handle the resynchronization ourselves, and to do that we ideally want a more structured version of the utterance diff returned above. To create that, we first need to get a list of utterance objects from our remote and local instances:
%% Cell type:code id: tags:
``` python
remote_utterances = rl.index.get_utterances()
remote_utterances
```
%% Output
[Utterance(route='chitchat', utterance='how are things going?', function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='chitchat', utterance="how's the weather today?", function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='chitchat', utterance='the weather is horrendous', function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='chitchat', utterance='lovely weather today', function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='chitchat', utterance="let's go to the chippy", function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='politics', utterance="don't you just hate the president", function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='politics', utterance="don't you just love the president", function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='politics', utterance="they're going to destroy this country!", function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='politics', utterance='they will save the country!', function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='politics', utterance="isn't politics the best thing ever", function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='politics', utterance="why don't you tell me about your political opinions", function_schemas=None, metadata={}, diff_tag=' ')]
%% Cell type:code id: tags:
``` python
local_utterances = rl.to_config().to_utterances()
local_utterances
```
%% Output
[Utterance(route='politics', utterance="isn't politics the best thing ever", function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='politics', utterance="why don't you tell me about your political opinions", function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='politics', utterance="don't you just love the president", function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='politics', utterance="don't you just hate the president", function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='politics', utterance="they're going to destroy this country!", function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='politics', utterance='they will save the country!', function_schemas=None, metadata={}, diff_tag=' ')]
%% Cell type:markdown id: tags:
We can add the `diff_tag` attribute to each of these utterances by loading both lists into an `UtteranceDiff` object:
%% Cell type:code id: tags:
``` python
from semantic_router.schema import UtteranceDiff

diff = UtteranceDiff.from_utterances(
    local_utterances=local_utterances, remote_utterances=remote_utterances
)
```
%% Cell type:markdown id: tags:
`UtteranceDiff` objects include all diff information inside the `diff` attribute (which is a list of `Utterance` objects):
%% Cell type:code id: tags:
``` python
diff.diff
```
%% Output
[Utterance(route='chitchat', utterance='how are things going?', function_schemas=None, metadata={}, diff_tag='+'),
 Utterance(route='chitchat', utterance="how's the weather today?", function_schemas=None, metadata={}, diff_tag='+'),
 Utterance(route='chitchat', utterance="let's go to the chippy", function_schemas=None, metadata={}, diff_tag='+'),
 Utterance(route='chitchat', utterance='lovely weather today', function_schemas=None, metadata={}, diff_tag='+'),
 Utterance(route='chitchat', utterance='the weather is horrendous', function_schemas=None, metadata={}, diff_tag='+'),
 Utterance(route='politics', utterance="don't you just hate the president", function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='politics', utterance="don't you just love the president", function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='politics', utterance="isn't politics the best thing ever", function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='politics', utterance='they will save the country!', function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='politics', utterance="they're going to destroy this country!", function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='politics', utterance="why don't you tell me about your political opinions", function_schemas=None, metadata={}, diff_tag=' ')]
%% Cell type:markdown id: tags:
Each of our `Utterance` objects now contains a populated `diff_tag` attribute, where:

* `diff_tag='+'` means the utterance exists in the remote instance *only*
* `diff_tag='-'` means the utterance exists in the local instance *only*
* `diff_tag=' '` means the utterance exists in both remote and local instances

So, to collect utterances missing from our local instance we can run:
%% Cell type:code id: tags:
``` python
diff.get_tag("+")
```
%% Output
[Utterance(route='chitchat', utterance='how are things going?', function_schemas=None, metadata={}, diff_tag='+'),
 Utterance(route='chitchat', utterance="how's the weather today?", function_schemas=None, metadata={}, diff_tag='+'),
 Utterance(route='chitchat', utterance="let's go to the chippy", function_schemas=None, metadata={}, diff_tag='+'),
 Utterance(route='chitchat', utterance='lovely weather today', function_schemas=None, metadata={}, diff_tag='+'),
 Utterance(route='chitchat', utterance='the weather is horrendous', function_schemas=None, metadata={}, diff_tag='+')]
%% Cell type:markdown id: tags:
To collect utterances missing from our remote instance we can run:
%% Cell type:code id: tags:
``` python
diff.get_tag("-")
```
%% Output
[]
%% Cell type:markdown id: tags:
And, if needed, we can get all utterances that exist in both with:
%% Cell type:code id: tags:
``` python
diff.get_tag(" ")
```
%% Output
[Utterance(route='politics', utterance="don't you just hate the president", function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='politics', utterance="don't you just love the president", function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='politics', utterance="isn't politics the best thing ever", function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='politics', utterance='they will save the country!', function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='politics', utterance="they're going to destroy this country!", function_schemas=None, metadata={}, diff_tag=' '),
 Utterance(route='politics', utterance="why don't you tell me about your political opinions", function_schemas=None, metadata={}, diff_tag=' ')]
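%% Cell type:markdown id: tags:
Because these are plain `Utterance` objects, we could also act on them by hand before using any of the built-in sync logic. Below is a minimal, illustrative sketch that groups the remote-only utterances back into `Route` objects, which we could then append to our local `routes` list if we wanted to resolve the diff manually:
%% Cell type:code id: tags:
``` python
from collections import defaultdict

from semantic_router import Route

# group remote-only utterances by their route name
grouped = defaultdict(list)
for utt in diff.get_tag("+"):
    grouped[utt.route].append(utt.utterance)

# rebuild Route objects from the grouped utterances
recovered_routes = [
    Route(name=name, utterances=utterances) for name, utterances in grouped.items()
]
recovered_routes
```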
%% Cell type:markdown id: tags:
## Synchronization
%% Cell type:markdown id: tags:
There are six synchronization methods that we can use; these are:

* `error`: Raise an error if local and remote are not synchronized.
* `remote`: Take remote as the source of truth and update local to align.
* `local`: Take local as the source of truth and update remote to align.
* `merge-force-remote`: Merge both local and remote, keeping local as the priority. Remote utterances are only merged into local *if* a matching route for the utterance is found in local; all other route-utterances are dropped. Where a route exists in both local and remote, but each contains different `function_schemas` or `metadata` information, the local version takes priority and local `function_schemas` and `metadata` are propagated to all remote utterances belonging to the given route.
* `merge-force-local`: Merge both local and remote, keeping remote as the priority. Local utterances are only merged into remote *if* a matching route for the utterance is found in the remote; all other route-utterances are dropped. Where a route exists in both local and remote, but each contains different `function_schemas` or `metadata` information, the remote version takes priority and remote `function_schemas` and `metadata` are propagated to all local routes.
* `merge`: Merge both local and remote, also merging local and remote utterances when a route with the same route name is present both locally and remotely. If a route exists in both local and remote but contains different `function_schemas` or `metadata` information, the local version takes priority and local `function_schemas` and `metadata` are propagated to all remote routes.

We can get the synchronization strategy for each of these (with the exception of `error`) using the `diff.get_sync_strategy` method.
%% Cell type:code id: tags:
``` python
diff.get_sync_strategy("local")
```
%% Output
{'remote': {'upsert': [],
  'delete': [Utterance(route='chitchat', utterance='how are things going?', function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance="how's the weather today?", function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance="let's go to the chippy", function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance='lovely weather today', function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance='the weather is horrendous', function_schemas=None, metadata={}, diff_tag='+')]},
 'local': {'upsert': [], 'delete': []}}
%% Cell type:code id: tags:
``` python
diff.get_sync_strategy("remote")
```
%% Output
{'remote': {'upsert': [], 'delete': []},
 'local': {'upsert': [Utterance(route='chitchat', utterance='how are things going?', function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance="how's the weather today?", function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance="let's go to the chippy", function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance='lovely weather today', function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance='the weather is horrendous', function_schemas=None, metadata={}, diff_tag='+')],
  'delete': []}}
%% Cell type:code id: tags:
``` python
diff.get_sync_strategy("merge-force-remote")
```
%% Output
{'remote': {'upsert': [], 'delete': []},
 'local': {'upsert': [Utterance(route='chitchat', utterance='how are things going?', function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance="how's the weather today?", function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance="let's go to the chippy", function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance='lovely weather today', function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance='the weather is horrendous', function_schemas=None, metadata={}, diff_tag='+')],
  'delete': []}}
%% Cell type:code id: tags:
``` python
diff.get_sync_strategy("merge-force-local")
```
%% Output
2024-11-23 23:47:11 INFO semantic_router.utils.logger local_only_mapper: {}
{'remote': {'upsert': [],
  'delete': [Utterance(route='chitchat', utterance='how are things going?', function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance="how's the weather today?", function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance="let's go to the chippy", function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance='lovely weather today', function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance='the weather is horrendous', function_schemas=None, metadata={}, diff_tag='+')]},
 'local': {'upsert': [], 'delete': []}}
%% Cell type:code id: tags:
``` python
diff.get_sync_strategy("merge")
```
%% Output
{'remote': {'upsert': [Utterance(route='chitchat', utterance='how are things going?', function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance="how's the weather today?", function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance="let's go to the chippy", function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance='lovely weather today', function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance='the weather is horrendous', function_schemas=None, metadata={}, diff_tag='+')],
  'delete': []},
 'local': {'upsert': [Utterance(route='chitchat', utterance='how are things going?', function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance="how's the weather today?", function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance="let's go to the chippy", function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance='lovely weather today', function_schemas=None, metadata={}, diff_tag='+'),
   Utterance(route='chitchat', utterance='the weather is horrendous', function_schemas=None, metadata={}, diff_tag='+')],
  'delete': []}}
%% Cell type:markdown id: tags:
Each of these sync strategies can be fed to our route layer via the `rl._execute_sync_strategy` method:
%% Cell type:code id: tags:
``` python
strategy = diff.get_sync_strategy("local")
rl._execute_sync_strategy(strategy=strategy)
```
%% Output
2024-11-23 23:47:15 WARNING semantic_router.utils.logger TEMP | _remove_and_sync:
chitchat: ['how are things going?', "how's the weather today?", "let's go to the chippy", 'lovely weather today', 'the weather is horrendous']
%% Cell type:code id: tags:
``` python
time.sleep(3)
rl.is_synced()
```
%% Output
True
%% Cell type:markdown id: tags:
We can check our diff method to see what the `local` sync did:
%% Cell type:code id: tags:
``` python
rl.get_utterance_diff()
```
%% Output
[" politics: don't you just hate the president",
 " politics: don't you just love the president",
 " politics: isn't politics the best thing ever",
 ' politics: they will save the country!',
 " politics: they're going to destroy this country!",
 " politics: why don't you tell me about your political opinions"]
%% Cell type:markdown id: tags:
As expected, it took all local utterances and applied them to the remote instance, removing all utterances that were only present in the remote instance.

We can simplify this process significantly by running the `rl.sync` method with our chosen `sync_mode`:
%% Cell type:code id: tags:
``` python
rl.sync(sync_mode="local")
```
%% Output
2024-11-23 23:47:23 WARNING semantic_router.utils.logger Local and remote route layers are already synchronized.
[" politics: don't you just hate the president",
 " politics: don't you just love the president",
 " politics: isn't politics the best thing ever",
 ' politics: they will save the country!',
 " politics: they're going to destroy this country!",
 " politics: why don't you tell me about your political opinions"]
%% Cell type:markdown id: tags:
---
...
-from semantic_router.routers import LayerConfig, RouteLayer, HybridRouteLayer
+from semantic_router.routers import RouterConfig, RouteLayer, HybridRouter
 from semantic_router.route import Route

-__all__ = ["RouteLayer", "HybridRouteLayer", "Route", "LayerConfig"]
+__all__ = ["RouteLayer", "HybridRouter", "Route", "RouterConfig"]

 __version__ = "0.1.0.dev2"
@@ -5,7 +5,6 @@ from numpy.linalg import norm
 from semantic_router.schema import ConfigParameter, Utterance
 from semantic_router.index.local import LocalIndex
-from semantic_router.linear import similarity_matrix, top_scores
 from semantic_router.utils.logger import logger
 from typing import Any
@@ -104,7 +103,9 @@ class HybridLocalIndex(LocalIndex):
         # calculate sparse vec similarity
         sparse_norm = norm(self.sparse_index, axis=1)
         xq_s_norm = norm(xq_s)  # TODO: this used to be xq_s.T, should work without
-        sim_s = np.squeeze(np.dot(self.sparse_index, xq_s.T)) / (sparse_norm * xq_s_norm)
+        sim_s = np.squeeze(np.dot(self.sparse_index, xq_s.T)) / (
+            sparse_norm * xq_s_norm
+        )
         total_sim = sim_d + sim_s
         # get indices of top_k records
         top_k = min(top_k, total_sim.shape[0])
...
@@ -43,9 +43,7 @@ class LocalIndex(BaseIndex):
         self.utterances = np.concatenate([self.utterances, utterances_arr])

     def _remove_and_sync(self, routes_to_delete: dict):
-        logger.warning(
-            f"Sync remove is not implemented for {self.__class__.__name__}."
-        )
+        logger.warning(f"Sync remove is not implemented for {self.__class__.__name__}.")

     def get_utterances(self) -> List[Utterance]:
         """
...
-from semantic_router.routers.base import BaseRouteLayer, LayerConfig
+from semantic_router.routers.base import BaseRouter, RouterConfig
 from semantic_router.routers.semantic import RouteLayer
-from semantic_router.routers.hybrid import HybridRouteLayer
+from semantic_router.routers.hybrid import HybridRouter

 __all__ = [
-    "BaseRouteLayer",
-    "LayerConfig",
+    "BaseRouter",
+    "RouterConfig",
     "RouteLayer",
-    "HybridRouteLayer",
+    "HybridRouter",
 ]
@@ -57,10 +57,10 @@ def is_valid(layer_config: str) -> bool:
     return False

-class LayerConfig:
+class RouterConfig:
     """
-    Generates a LayerConfig object that can be used for initializing a
-    RouteLayer.
+    Generates a RouterConfig object that can be used for initializing a
+    Routers.
     """

     routes: List[Route] = []
@@ -80,7 +80,7 @@ class LayerConfig:
                 if encode_type.value == self.encoder_type:
                     if self.encoder_type == EncoderType.HUGGINGFACE.value:
                         raise NotImplementedError(
-                            "HuggingFace encoder not supported by LayerConfig yet."
+                            "HuggingFace encoder not supported by RouterConfig yet."
                         )
                     encoder_name = EncoderDefault[encode_type.name].value[
                         "embedding_model"
@@ -91,7 +91,7 @@ class LayerConfig:
         self.routes = routes

     @classmethod
-    def from_file(cls, path: str) -> "LayerConfig":
+    def from_file(cls, path: str) -> "RouterConfig":
         logger.info(f"Loading route config from {path}")
         _, ext = os.path.splitext(path)
         with open(path, "r") as f:
@@ -143,7 +143,7 @@ class LayerConfig:
         encoder_type: str = "openai",
         encoder_name: Optional[str] = None,
     ):
-        """Initialize a LayerConfig from a list of tuples of routes and
+        """Initialize a RouterConfig from a list of tuples of routes and
         utterances.

         :param route_tuples: A list of tuples, each containing a route name and an
@@ -182,9 +182,9 @@ class LayerConfig:
         encoder_type: str = "openai",
         encoder_name: Optional[str] = None,
     ):
-        """Initialize a LayerConfig from a BaseIndex object.
+        """Initialize a RouterConfig from a BaseIndex object.

-        :param index: The index to initialize the LayerConfig from.
+        :param index: The index to initialize the RouterConfig from.
         :type index: BaseIndex
         :param encoder_type: The type of encoder to use, defaults to "openai".
         :type encoder_type: str, optional
@@ -275,7 +275,7 @@ class LayerConfig:
         )

-class BaseRouteLayer(BaseModel):
+class BaseRouter(BaseModel):
     encoder: BaseEncoder
     index: BaseIndex = Field(default_factory=BaseIndex)
     score_threshold: Optional[float] = Field(default=None)
@@ -365,7 +365,7 @@ class BaseRouteLayer(BaseModel):
     def _set_score_threshold(self):
         """Set the score threshold for the layer based on the encoder
         score threshold.

         When no score threshold is used a default `None` value
         is used, which means that a route will always be returned when
         the layer is called."""
@@ -688,18 +688,18 @@ class BaseRouteLayer(BaseModel):
     @classmethod
     def from_json(cls, file_path: str):
-        config = LayerConfig.from_file(file_path)
+        config = RouterConfig.from_file(file_path)
         encoder = AutoEncoder(type=config.encoder_type, name=config.encoder_name).model
         return cls(encoder=encoder, routes=config.routes)

     @classmethod
     def from_yaml(cls, file_path: str):
-        config = LayerConfig.from_file(file_path)
+        config = RouterConfig.from_file(file_path)
         encoder = AutoEncoder(type=config.encoder_type, name=config.encoder_name).model
         return cls(encoder=encoder, routes=config.routes)

     @classmethod
-    def from_config(cls, config: LayerConfig, index: Optional[BaseIndex] = None):
+    def from_config(cls, config: RouterConfig, index: Optional[BaseIndex] = None):
         encoder = AutoEncoder(type=config.encoder_type, name=config.encoder_name).model
         return cls(encoder=encoder, routes=config.routes, index=index)
@@ -1115,8 +1115,8 @@ class BaseRouteLayer(BaseModel):
                 route.name, self.score_threshold
             )

-    def to_config(self) -> LayerConfig:
-        return LayerConfig(
+    def to_config(self) -> RouterConfig:
+        return RouterConfig(
             encoder_type=self.encoder.type,
             encoder_name=self.encoder.name,
             routes=self.routes,
@@ -1226,7 +1226,7 @@ class BaseRouteLayer(BaseModel):
 def threshold_random_search(
-    route_layer: BaseRouteLayer,
+    route_layer: BaseRouter,
     search_range: Union[int, float],
 ) -> Dict[str, float]:
     """Performs a random search iteration given a route layer and a search range."""
...
@@ -13,13 +13,13 @@ from semantic_router.route import Route
 from semantic_router.index.hybrid_local import HybridLocalIndex
 from semantic_router.schema import RouteChoice
 from semantic_router.utils.logger import logger
-from semantic_router.routers.base import BaseRouteLayer
+from semantic_router.routers.base import BaseRouter
 from semantic_router.llms import BaseLLM

-class HybridRouteLayer(BaseRouteLayer):
-    """A hybrid layer that uses both dense and sparse embeddings to classify routes.
-    """
+class HybridRouter(BaseRouter):
+    """A hybrid layer that uses both dense and sparse embeddings to classify routes."""

     # there are a few additional attributes for hybrid
     sparse_encoder: BM25Encoder = Field(default_factory=BM25Encoder)
     alpha: float = 0.3
@@ -74,7 +74,7 @@ class HybridRouteLayer(BaseRouteLayer):
     @validator("sparse_encoder", pre=True, always=True)
     def set_sparse_encoder(cls, v):
         return v if v is not None else BM25Encoder()

     @validator("index", pre=True, always=True)
     def set_index(cls, v):
         return v if v is not None else HybridLocalIndex()
@@ -87,10 +87,10 @@ class HybridRouteLayer(BaseRouteLayer):
         # TODO: add alpha as a parameter
         # create dense query vector
         xq_d = np.array(self.encoder(text))
-        #xq_d = np.squeeze(xq_d) # Reduce to 1d array.
+        # xq_d = np.squeeze(xq_d) # Reduce to 1d array.
         # create sparse query vector
         xq_s = np.array(self.sparse_encoder(text))
-        #xq_s = np.squeeze(xq_s)
+        # xq_s = np.squeeze(xq_s)
         # convex scaling
         xq_d, xq_s = self._convex_scaling(xq_d, xq_s)
         return xq_d, xq_s
@@ -107,10 +107,10 @@ class HybridRouteLayer(BaseRouteLayer):
         dense_vec, sparse_vec = await asyncio.gather(dense_coro, sparse_coro)
         # create dense query vector
         xq_d = np.array(dense_vec)
-        #xq_d = np.squeeze(xq_d) # reduce to 1d array
+        # xq_d = np.squeeze(xq_d) # reduce to 1d array
         # create sparse query vector
         xq_s = np.array(sparse_vec)
-        #xq_s = np.squeeze(xq_s)
+        # xq_s = np.squeeze(xq_s)
         # convex scaling
         xq_d, xq_s = self._convex_scaling(xq_d, xq_s)
         return xq_d, xq_s
@@ -137,15 +137,18 @@ class HybridRouteLayer(BaseRouteLayer):
             vector=np.array(vector) if isinstance(vector, list) else vector,
             top_k=self.top_k,
             route_filter=route_filter,
-            sparse_vector=np.array(sparse_vector) if isinstance(sparse_vector, list) else sparse_vector,
+            sparse_vector=(
+                np.array(sparse_vector)
+                if isinstance(sparse_vector, list)
+                else sparse_vector
+            ),
         )
-        top_class, top_class_scores = self._semantic_classify(list(zip(scores, route_names)))
+        top_class, top_class_scores = self._semantic_classify(
+            list(zip(scores, route_names))
+        )
         passed = self._pass_threshold(top_class_scores, self.score_threshold)
         if passed:
-            return RouteChoice(
-                name=top_class,
-                similarity_score=max(top_class_scores)
-            )
+            return RouteChoice(name=top_class, similarity_score=max(top_class_scores))
         else:
             return RouteChoice()
...
-import importlib
 import json
-import os
 import random
-import hashlib
 from typing import Any, Dict, List, Optional, Tuple, Union

-from pydantic.v1 import validator, BaseModel, Field
+from pydantic.v1 import validator, Field
 import numpy as np
-import yaml  # type: ignore
 from tqdm.auto import tqdm

 from semantic_router.encoders import AutoEncoder, BaseEncoder, OpenAIEncoder
@@ -16,15 +12,13 @@ from semantic_router.index.local import LocalIndex
 from semantic_router.index.pinecone import PineconeIndex
 from semantic_router.llms import BaseLLM, OpenAILLM
 from semantic_router.route import Route
-from semantic_router.routers.base import BaseRouteLayer
+from semantic_router.routers.base import BaseRouter, RouterConfig
 from semantic_router.schema import (
     ConfigParameter,
-    EncoderType,
     RouteChoice,
     Utterance,
     UtteranceDiff,
 )
-from semantic_router.utils.defaults import EncoderDefault
 from semantic_router.utils.logger import logger
@@ -58,222 +52,7 @@ def is_valid(layer_config: str) -> bool:
     return False
class LayerConfig: class RouteLayer(BaseRouter):
"""
Generates a LayerConfig object that can be used for initializing a
RouteLayer.
"""
routes: List[Route] = []
def __init__(
self,
routes: List[Route] = [],
encoder_type: str = "openai",
encoder_name: Optional[str] = None,
):
self.encoder_type = encoder_type
if encoder_name is None:
for encode_type in EncoderType:
if encode_type.value == self.encoder_type:
if self.encoder_type == EncoderType.HUGGINGFACE.value:
raise NotImplementedError(
"HuggingFace encoder not supported by LayerConfig yet."
)
encoder_name = EncoderDefault[encode_type.name].value[
"embedding_model"
]
break
logger.info(f"Using default {encoder_type} encoder: {encoder_name}")
self.encoder_name = encoder_name
self.routes = routes
@classmethod
def from_file(cls, path: str) -> "LayerConfig":
logger.info(f"Loading route config from {path}")
_, ext = os.path.splitext(path)
with open(path, "r") as f:
if ext == ".json":
layer = json.load(f)
elif ext in [".yaml", ".yml"]:
layer = yaml.safe_load(f)
else:
raise ValueError(
"Unsupported file type. Only .json and .yaml are supported"
)
if not is_valid(json.dumps(layer)):
raise Exception("Invalid config JSON or YAML")
encoder_type = layer["encoder_type"]
encoder_name = layer["encoder_name"]
routes = []
for route_data in layer["routes"]:
# Handle the 'llm' field specially if it exists
if "llm" in route_data and route_data["llm"] is not None:
llm_data = route_data.pop(
"llm"
) # Remove 'llm' from route_data and handle it separately
# Use the module path directly from llm_data without modification
llm_module_path = llm_data["module"]
# Dynamically import the module and then the class from that module
llm_module = importlib.import_module(llm_module_path)
llm_class = getattr(llm_module, llm_data["class"])
# Instantiate the LLM class with the provided model name
llm = llm_class(name=llm_data["model"])
# Reassign the instantiated llm object back to route_data
route_data["llm"] = llm
# Dynamically create the Route object using the remaining route_data
route = Route(**route_data)
routes.append(route)
return cls(
encoder_type=encoder_type, encoder_name=encoder_name, routes=routes
)
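For reference, a file accepted by `from_file` carries the same keys produced by `to_dict` below. A minimal sketch of writing and reloading such a config, assuming the renamed `RouterConfig` exported from `semantic_router.routers` (the file path and encoder name are illustrative):

```python
import json

from semantic_router.routers import RouterConfig

# Illustrative config with the keys from_file() expects.
layer_dict = {
    "encoder_type": "openai",
    "encoder_name": "text-embedding-3-small",  # assumed encoder name, illustration only
    "routes": [
        {"name": "greetings", "utterances": ["hello there", "hi, how are you?"]},
    ],
}
with open("router_config.json", "w") as f:
    json.dump(layer_dict, f, indent=4)

config = RouterConfig.from_file("router_config.json")
assert config.encoder_type == "openai"
assert [route.name for route in config.routes] == ["greetings"]
```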
@classmethod
def from_tuples(
cls,
route_tuples: List[
Tuple[str, str, Optional[List[Dict[str, Any]]], Dict[str, Any]]
],
encoder_type: str = "openai",
encoder_name: Optional[str] = None,
):
"""Initialize a LayerConfig from a list of tuples of routes and
utterances.
:param route_tuples: A list of tuples, each containing a route name, an
associated utterance, optional function schemas, and route metadata.
:type route_tuples: List[Tuple[str, str, Optional[List[Dict[str, Any]]], Dict[str, Any]]]
:param encoder_type: The type of encoder to use, defaults to "openai".
:type encoder_type: str, optional
:param encoder_name: The name of the encoder to use, defaults to None.
:type encoder_name: Optional[str], optional
"""
routes_dict: Dict[str, Route] = {}
# first create a dictionary of route names to Route objects
# TODO: duplicated code with BaseIndex.get_routes()
for route_name, utterance, function_schema, metadata in route_tuples:
# if the route is not in the dictionary, add it
if route_name not in routes_dict:
routes_dict[route_name] = Route(
name=route_name,
utterances=[utterance],
function_schemas=function_schema,
metadata=metadata,
)
else:
# otherwise, add the utterance to the route
routes_dict[route_name].utterances.append(utterance)
# then create a list of routes from the dictionary
routes: List[Route] = []
for route_name, route in routes_dict.items():
routes.append(route)
return cls(routes=routes, encoder_type=encoder_type, encoder_name=encoder_name)
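The loop above groups tuples of the form `(route_name, utterance, function_schemas, metadata)` into one `Route` per name. A short sketch of that behaviour, again using the renamed `RouterConfig` (route names and utterances are illustrative):

```python
from semantic_router.routers import RouterConfig

route_tuples = [
    ("greetings", "hello there", None, {}),
    ("greetings", "hi, how are you?", None, {}),
    ("farewell", "see you later", None, {}),
]
config = RouterConfig.from_tuples(route_tuples)

# Utterances sharing a route name are merged into a single Route.
assert [route.name for route in config.routes] == ["greetings", "farewell"]
assert config.routes[0].utterances == ["hello there", "hi, how are you?"]
```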
@classmethod
def from_index(
cls,
index: BaseIndex,
encoder_type: str = "openai",
encoder_name: Optional[str] = None,
):
"""Initialize a LayerConfig from a BaseIndex object.
:param index: The index to initialize the LayerConfig from.
:type index: BaseIndex
:param encoder_type: The type of encoder to use, defaults to "openai".
:type encoder_type: str, optional
:param encoder_name: The name of the encoder to use, defaults to None.
:type encoder_name: Optional[str], optional
"""
remote_routes = index.get_utterances()
return cls.from_tuples(
route_tuples=[utt.to_tuple() for utt in remote_routes],
encoder_type=encoder_type,
encoder_name=encoder_name,
)
def to_dict(self) -> Dict[str, Any]:
return {
"encoder_type": self.encoder_type,
"encoder_name": self.encoder_name,
"routes": [route.to_dict() for route in self.routes],
}
def to_file(self, path: str):
"""Save the routes to a file in JSON or YAML format"""
logger.info(f"Saving route config to {path}")
_, ext = os.path.splitext(path)
# Check file extension before creating directories or files
if ext not in [".json", ".yaml", ".yml"]:
raise ValueError(
"Unsupported file type. Only .json and .yaml are supported"
)
dir_name = os.path.dirname(path)
# Create the directory if it doesn't exist and dir_name is not an empty string
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name)
with open(path, "w") as f:
if ext == ".json":
json.dump(self.to_dict(), f, indent=4)
elif ext in [".yaml", ".yml"]:
yaml.safe_dump(self.to_dict(), f)
def to_utterances(self) -> List[Utterance]:
"""Convert the routes to a list of Utterance objects.
:return: A list of Utterance objects.
:rtype: List[Utterance]
"""
utterances = []
for route in self.routes:
utterances.extend(
[
Utterance(
route=route.name,
utterance=x,
function_schemas=route.function_schemas,
metadata=route.metadata or {},
)
for x in route.utterances
]
)
return utterances
def add(self, route: Route):
self.routes.append(route)
logger.info(f"Added route `{route.name}`")
def get(self, name: str) -> Optional[Route]:
for route in self.routes:
if route.name == name:
return route
logger.error(f"Route `{name}` not found")
return None
def remove(self, name: str):
if name not in [route.name for route in self.routes]:
logger.error(f"Route `{name}` not found")
else:
self.routes = [route for route in self.routes if route.name != name]
logger.info(f"Removed route `{name}`")
def get_hash(self) -> ConfigParameter:
layer = self.to_dict()
return ConfigParameter(
field="sr_hash",
value=hashlib.sha256(json.dumps(layer).encode()).hexdigest(),
)
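The `sr_hash` parameter produced by `get_hash` is what allows a locally held config to be compared against the copy kept with the remote index, which underpins the sync checks discussed at the start of this document. A minimal sketch of the idea (the remote comparison itself is assumed behaviour and only described here):

```python
import hashlib
import json

from semantic_router.route import Route
from semantic_router.routers import RouterConfig

config = RouterConfig(routes=[Route(name="greetings", utterances=["hello there"])])
hash_param = config.get_hash()  # ConfigParameter(field="sr_hash", value=<sha256 hex digest>)

# The value is just the SHA-256 of the serialised config, as implemented above.
expected = hashlib.sha256(json.dumps(config.to_dict()).encode()).hexdigest()
assert hash_param.value == expected
# If the sr_hash stored with the remote index differs from the local value,
# the local and remote route layers have drifted and need to be synced.
```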
-class RouteLayer(BaseRouteLayer):
+class RouteLayer(BaseRouter):
     index: BaseIndex = Field(default_factory=LocalIndex)
     @validator("index", pre=True, always=True)
@@ -655,18 +434,18 @@ class RouteLayer(BaseRouteLayer):
     @classmethod
     def from_json(cls, file_path: str):
-        config = LayerConfig.from_file(file_path)
+        config = RouterConfig.from_file(file_path)
         encoder = AutoEncoder(type=config.encoder_type, name=config.encoder_name).model
         return cls(encoder=encoder, routes=config.routes)

     @classmethod
     def from_yaml(cls, file_path: str):
-        config = LayerConfig.from_file(file_path)
+        config = RouterConfig.from_file(file_path)
         encoder = AutoEncoder(type=config.encoder_type, name=config.encoder_name).model
         return cls(encoder=encoder, routes=config.routes)

     @classmethod
-    def from_config(cls, config: LayerConfig, index: Optional[BaseIndex] = None):
+    def from_config(cls, config: RouterConfig, index: Optional[BaseIndex] = None):
         encoder = AutoEncoder(type=config.encoder_type, name=config.encoder_name).model
         return cls(encoder=encoder, routes=config.routes, index=index)
@@ -1069,8 +848,8 @@ class RouteLayer(BaseRouteLayer):
             route.name, self.score_threshold
         )

-    def to_config(self) -> LayerConfig:
-        return LayerConfig(
+    def to_config(self) -> RouterConfig:
+        return RouterConfig(
             encoder_type=self.encoder.type,
             encoder_name=self.encoder.name,
             routes=self.routes,
...
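Taken together, these renames mean the serialisation round trip now goes through `RouterConfig` rather than `LayerConfig`. A sketch of the intended flow, reusing the `rl` layer from the first example above (the file path is illustrative):

```python
from semantic_router.routers import RouteLayer, RouterConfig

# Serialise an existing layer, persist it, and rebuild it later.
config = rl.to_config()          # now returns a RouterConfig
config.to_file("router_config.yaml")

restored = RouteLayer.from_config(RouterConfig.from_file("router_config.yaml"))
# RouteLayer.from_yaml("router_config.yaml") is the one-line equivalent.
```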
@@ -8,7 +8,7 @@ from semantic_router.encoders import (
     OpenAIEncoder,
     TfidfEncoder,
 )
-from semantic_router.OLD_hybrid_layer import HybridRouteLayer
+from semantic_router.OLD_hybrid_layer import HybridRouter
 from semantic_router.route import Route
@@ -78,9 +78,9 @@ sparse_encoder = BM25Encoder(use_default_params=False)
 sparse_encoder.fit(["The quick brown fox", "jumps over the lazy dog", "Hello, world!"])

-class TestHybridRouteLayer:
+class TestHybridRouter:
     def test_initialization(self, openai_encoder, routes):
-        route_layer = HybridRouteLayer(
+        route_layer = HybridRouter(
             encoder=openai_encoder,
             sparse_encoder=sparse_encoder,
             routes=routes,
@@ -96,18 +96,18 @@ class TestHybridRouteLayer:
         assert len(set(route_layer.categories)) == 2

     def test_initialization_different_encoders(self, cohere_encoder, openai_encoder):
-        route_layer_cohere = HybridRouteLayer(
+        route_layer_cohere = HybridRouter(
             encoder=cohere_encoder, sparse_encoder=sparse_encoder
         )
         assert route_layer_cohere.score_threshold == 0.3
-        route_layer_openai = HybridRouteLayer(
+        route_layer_openai = HybridRouter(
             encoder=openai_encoder, sparse_encoder=sparse_encoder
         )
         assert route_layer_openai.score_threshold == 0.3

     def test_add_route(self, openai_encoder):
-        route_layer = HybridRouteLayer(
+        route_layer = HybridRouter(
             encoder=openai_encoder, sparse_encoder=sparse_encoder
         )
         route = Route(name="Route 3", utterances=["Yes", "No"])
@@ -117,7 +117,7 @@ class TestHybridRouteLayer:
         assert len(set(route_layer.categories)) == 1

     def test_add_multiple_routes(self, openai_encoder, routes):
-        route_layer = HybridRouteLayer(
+        route_layer = HybridRouter(
             encoder=openai_encoder, sparse_encoder=sparse_encoder
         )
         for route in routes:
@@ -127,20 +127,20 @@ class TestHybridRouteLayer:
         assert len(set(route_layer.categories)) == 2

     def test_query_and_classification(self, openai_encoder, routes):
-        route_layer = HybridRouteLayer(
+        route_layer = HybridRouter(
             encoder=openai_encoder, sparse_encoder=sparse_encoder, routes=routes
         )
         query_result = route_layer("Hello")
         assert query_result in ["Route 1", "Route 2"]

     def test_query_with_no_index(self, openai_encoder):
-        route_layer = HybridRouteLayer(
+        route_layer = HybridRouter(
             encoder=openai_encoder, sparse_encoder=sparse_encoder
         )
         assert route_layer("Anything") is None

     def test_semantic_classify(self, openai_encoder, routes):
-        route_layer = HybridRouteLayer(
+        route_layer = HybridRouter(
             encoder=openai_encoder, sparse_encoder=sparse_encoder, routes=routes
         )
         classification, score = route_layer._semantic_classify(
@@ -153,7 +153,7 @@ class TestHybridRouteLayer:
         assert score == [0.9]

     def test_semantic_classify_multiple_routes(self, openai_encoder, routes):
-        route_layer = HybridRouteLayer(
+        route_layer = HybridRouter(
             encoder=openai_encoder, sparse_encoder=sparse_encoder, routes=routes
         )
         classification, score = route_layer._semantic_classify(
@@ -167,21 +167,19 @@ class TestHybridRouteLayer:
         assert score == [0.9, 0.8]

     def test_pass_threshold(self, openai_encoder):
-        route_layer = HybridRouteLayer(
+        route_layer = HybridRouter(
             encoder=openai_encoder, sparse_encoder=sparse_encoder
         )
         assert not route_layer._pass_threshold([], 0.5)
         assert route_layer._pass_threshold([0.6, 0.7], 0.5)

     def test_failover_score_threshold(self, base_encoder):
-        route_layer = HybridRouteLayer(
-            encoder=base_encoder, sparse_encoder=sparse_encoder
-        )
+        route_layer = HybridRouter(encoder=base_encoder, sparse_encoder=sparse_encoder)
         assert base_encoder.score_threshold == 0.50
         assert route_layer.score_threshold == 0.50

     def test_add_route_tfidf(self, cohere_encoder, tfidf_encoder, routes):
-        hybrid_route_layer = HybridRouteLayer(
+        hybrid_route_layer = HybridRouter(
             encoder=cohere_encoder,
             sparse_encoder=tfidf_encoder,
             routes=routes[:-1],
@@ -195,7 +193,7 @@ class TestHybridRouteLayer:
     def test_setting_aggregation_methods(self, openai_encoder, routes):
         for agg in ["sum", "mean", "max"]:
-            route_layer = HybridRouteLayer(
+            route_layer = HybridRouter(
                 encoder=openai_encoder,
                 sparse_encoder=sparse_encoder,
                 routes=routes,
@@ -218,7 +216,7 @@ class TestHybridRouteLayer:
             {"route": "Route 3", "score": 1.0},
         ]
         for agg in ["sum", "mean", "max"]:
-            route_layer = HybridRouteLayer(
+            route_layer = HybridRouter(
                 encoder=openai_encoder,
                 sparse_encoder=sparse_encoder,
                 routes=routes,
...
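For orientation, the renamed `HybridRouter` is exercised exactly as the old `HybridRouteLayer` was. A minimal sketch mirroring the fixtures in this test file, with illustrative routes and utterances (the `OLD_hybrid_layer` import path is the one this diff uses, the `BM25Encoder` import location is assumed from the test module, and the dense encoder needs an OpenAI key):

```python
from semantic_router.encoders import BM25Encoder, OpenAIEncoder
from semantic_router.OLD_hybrid_layer import HybridRouter
from semantic_router.route import Route

# Sparse encoder fitted on the route utterances, as in the test module above.
sparse_encoder = BM25Encoder(use_default_params=False)
sparse_encoder.fit(["hello there", "goodbye for now"])

routes = [
    Route(name="greetings", utterances=["hello there"]),
    Route(name="farewell", utterances=["goodbye for now"]),
]
router = HybridRouter(
    encoder=OpenAIEncoder(), sparse_encoder=sparse_encoder, routes=routes
)
print(router("hi there"))  # expected to resolve to "greetings"
```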
@@ -10,7 +10,7 @@ from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder
 from semantic_router.index.local import LocalIndex
 from semantic_router.index.pinecone import PineconeIndex
 from semantic_router.index.qdrant import QdrantIndex
-from semantic_router.routers import LayerConfig, RouteLayer
+from semantic_router.routers import RouterConfig, RouteLayer
 from semantic_router.llms.base import BaseLLM
 from semantic_router.route import Route
 from platform import python_version
@@ -588,8 +588,8 @@ class TestRouteLayer:
             layer_json()
         )  # Assuming layer_json() returns a valid JSON string

-        # Load the LayerConfig from the temporary file
-        layer_config = LayerConfig.from_file(str(config_path))
+        # Load the RouterConfig from the temporary file
+        layer_config = RouterConfig.from_file(str(config_path))

         # Assertions to verify the loaded configuration
         assert layer_config.encoder_type == "cohere"
@@ -604,8 +604,8 @@ class TestRouteLayer:
             layer_yaml()
         )  # Assuming layer_yaml() returns a valid YAML string

-        # Load the LayerConfig from the temporary file
-        layer_config = LayerConfig.from_file(str(config_path))
+        # Load the RouterConfig from the temporary file
+        layer_config = RouterConfig.from_file(str(config_path))

         # Assertions to verify the loaded configuration
         assert layer_config.encoder_type == "cohere"
@@ -615,7 +615,7 @@ class TestRouteLayer:
     def test_from_file_invalid_path(self, index_cls):
         with pytest.raises(FileNotFoundError) as excinfo:
-            LayerConfig.from_file("nonexistent_path.json")
+            RouterConfig.from_file("nonexistent_path.json")
         assert "[Errno 2] No such file or directory: 'nonexistent_path.json'" in str(
             excinfo.value
         )
@@ -626,7 +626,7 @@ class TestRouteLayer:
         config_path.write_text(layer_json())
         with pytest.raises(ValueError) as excinfo:
-            LayerConfig.from_file(str(config_path))
+            RouterConfig.from_file(str(config_path))
         assert "Unsupported file type" in str(excinfo.value)

     def test_from_file_invalid_config(self, tmp_path, index_cls):
@@ -645,10 +645,10 @@ class TestRouteLayer:
         # Patch the is_valid function to return False for this test
         with patch("semantic_router.layer.is_valid", return_value=False):
-            # Attempt to load the LayerConfig from the temporary file
+            # Attempt to load the RouterConfig from the temporary file
             # and assert that it raises an exception due to invalid configuration
             with pytest.raises(Exception) as excinfo:
-                LayerConfig.from_file(str(config_path))
+                RouterConfig.from_file(str(config_path))
             assert "Invalid config JSON or YAML" in str(
                 excinfo.value
             ), "Loading an invalid configuration should raise an exception."
@@ -675,8 +675,8 @@ class TestRouteLayer:
         with open(config_path, "w") as file:
             file.write(llm_config_json)

-        # Load the LayerConfig from the temporary file
-        layer_config = LayerConfig.from_file(str(config_path))
+        # Load the RouterConfig from the temporary file
+        layer_config = RouterConfig.from_file(str(config_path))

         # Using BaseLLM because trying to create a usable Mock LLM is a nightmare.
         assert isinstance(
@@ -939,61 +939,61 @@ class TestLayerFit:
     # Add more tests for edge cases and error handling as needed.

-class TestLayerConfig:
+class TestRouterConfig:
     def test_init(self):
-        layer_config = LayerConfig()
+        layer_config = RouterConfig()
         assert layer_config.routes == []

     def test_to_file_json(self):
         route = Route(name="test", utterances=["utterance"])
-        layer_config = LayerConfig(routes=[route])
+        layer_config = RouterConfig(routes=[route])
         with patch("builtins.open", mock_open()) as mocked_open:
             layer_config.to_file("data/test_output.json")
             mocked_open.assert_called_once_with("data/test_output.json", "w")

     def test_to_file_yaml(self):
         route = Route(name="test", utterances=["utterance"])
-        layer_config = LayerConfig(routes=[route])
+        layer_config = RouterConfig(routes=[route])
         with patch("builtins.open", mock_open()) as mocked_open:
             layer_config.to_file("data/test_output.yaml")
             mocked_open.assert_called_once_with("data/test_output.yaml", "w")

     def test_to_file_invalid(self):
         route = Route(name="test", utterances=["utterance"])
-        layer_config = LayerConfig(routes=[route])
+        layer_config = RouterConfig(routes=[route])
         with pytest.raises(ValueError):
             layer_config.to_file("test_output.txt")

     def test_from_file_json(self):
         mock_json_data = layer_json()
         with patch("builtins.open", mock_open(read_data=mock_json_data)) as mocked_open:
-            layer_config = LayerConfig.from_file("data/test.json")
+            layer_config = RouterConfig.from_file("data/test.json")
             mocked_open.assert_called_once_with("data/test.json", "r")
-            assert isinstance(layer_config, LayerConfig)
+            assert isinstance(layer_config, RouterConfig)

     def test_from_file_yaml(self):
         mock_yaml_data = layer_yaml()
         with patch("builtins.open", mock_open(read_data=mock_yaml_data)) as mocked_open:
-            layer_config = LayerConfig.from_file("data/test.yaml")
+            layer_config = RouterConfig.from_file("data/test.yaml")
             mocked_open.assert_called_once_with("data/test.yaml", "r")
-            assert isinstance(layer_config, LayerConfig)
+            assert isinstance(layer_config, RouterConfig)

     def test_from_file_invalid(self):
         with open("test.txt", "w") as f:
             f.write("dummy content")
         with pytest.raises(ValueError):
-            LayerConfig.from_file("test.txt")
+            RouterConfig.from_file("test.txt")
         os.remove("test.txt")

     def test_to_dict(self):
         route = Route(name="test", utterances=["utterance"])
-        layer_config = LayerConfig(routes=[route])
+        layer_config = RouterConfig(routes=[route])
         assert layer_config.to_dict()["routes"] == [route.to_dict()]

     def test_add(self):
         route = Route(name="test", utterances=["utterance"])
         route2 = Route(name="test2", utterances=["utterance2"])
-        layer_config = LayerConfig()
+        layer_config = RouterConfig()
         layer_config.add(route)
         # confirm route added
         assert layer_config.routes == [route]
@@ -1003,17 +1003,17 @@ class TestLayerConfig:
     def test_get(self):
         route = Route(name="test", utterances=["utterance"])
-        layer_config = LayerConfig(routes=[route])
+        layer_config = RouterConfig(routes=[route])
         assert layer_config.get("test") == route

     def test_get_not_found(self):
         route = Route(name="test", utterances=["utterance"])
-        layer_config = LayerConfig(routes=[route])
+        layer_config = RouterConfig(routes=[route])
         assert layer_config.get("not_found") is None

     def test_remove(self):
         route = Route(name="test", utterances=["utterance"])
-        layer_config = LayerConfig(routes=[route])
+        layer_config = RouterConfig(routes=[route])
         layer_config.remove("test")
         assert layer_config.routes == []
...
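The renamed `TestRouterConfig` suite covers a small add/get/remove surface; a quick sketch of that API in use outside the tests (route name and utterance are illustrative):

```python
from semantic_router.route import Route
from semantic_router.routers import RouterConfig

config = RouterConfig()
config.add(Route(name="greetings", utterances=["hello there"]))

assert config.get("greetings") is not None
assert config.get("missing") is None  # logs an error and returns None

config.remove("greetings")
assert config.routes == []
```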