diff --git a/coverage.xml b/coverage.xml index 628f2950738068f9cb79abdcab95eeeb56d55f2b..27e175c33ef0897bd44b76a15c97bbb33f9afecc 100644 --- a/coverage.xml +++ b/coverage.xml @@ -22,8 +22,8 @@ <lines> <line number="1" hits="1"/> <line number="2" hits="1"/> - <line number="3" hits="1"/> - <line number="5" hits="1"/> + <line number="4" hits="1"/> + <line number="10" hits="1"/> <line number="11" hits="1"/> <line number="13" hits="1"/> <line number="16" hits="1"/> @@ -93,9 +93,7 @@ <line number="118" hits="1"/> <line number="119" hits="1"/> <line number="121" hits="1"/> - <line number="122" hits="1"/> <line number="123" hits="1"/> - <line number="124" hits="1"/> <line number="125" hits="1"/> <line number="126" hits="1"/> <line number="127" hits="1"/> @@ -619,19 +617,18 @@ <line number="8" hits="1"/> <line number="23" hits="1"/> <line number="24" hits="1"/> + <line number="25" hits="1"/> <line number="26" hits="1"/> <line number="27" hits="1"/> - <line number="29" hits="1"/> + <line number="28" hits="1"/> + <line number="31" hits="1"/> + <line number="32" hits="1"/> + <line number="33" hits="1"/> <line number="35" hits="1"/> <line number="37" hits="1"/> + <line number="38" hits="1"/> <line number="40" hits="1"/> - <line number="41" hits="1"/> - <line number="42" hits="1"/> - <line number="44" hits="1"/> - <line number="46" hits="1"/> - <line number="47" hits="1"/> - <line number="49" hits="1"/> - <line number="52" hits="1"/> + <line number="43" hits="1"/> </lines> </class> </classes> diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb index 14581286c87668c415e49a51b6081bbea2d1cab5..d082468b997726bc51e12f1e1d9f6b32ffe52697 100644 --- a/docs/examples/function_calling.ipynb +++ b/docs/examples/function_calling.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -37,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 6, "metadata": {}, "outputs": [ { diff --git a/docs/examples/hybrid-layer.ipynb b/docs/examples/hybrid-layer.ipynb index 8b1da5ae75f0a8a9572996b8e416a282d2c48f1b..5d0cb4525a5b4b50d6ae564827044a9899a5172f 100644 --- a/docs/examples/hybrid-layer.ipynb +++ b/docs/examples/hybrid-layer.ipynb @@ -34,7 +34,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install -qU semantic-router==0.0.6" + "!pip install -qU semantic-router==0.0.11" ] }, { @@ -46,21 +46,9 @@ }, { "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "ename": "ImportError", - "evalue": "cannot import name 'Route' from 'semantic_router.schema' (/Users/jakit/customers/aurelio/semantic-router/.venv/lib/python3.11/site-packages/semantic_router/schema.py)", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m/Users/jakit/customers/aurelio/semantic-router/docs/examples/hybrid-layer.ipynb Cell 7\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> <a href='vscode-notebook-cell:/Users/jakit/customers/aurelio/semantic-router/docs/examples/hybrid-layer.ipynb#X10sZmlsZQ%3D%3D?line=0'>1</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39msemantic_router\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mschema\u001b[39;00m \u001b[39mimport\u001b[39;00m Route\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/jakit/customers/aurelio/semantic-router/docs/examples/hybrid-layer.ipynb#X10sZmlsZQ%3D%3D?line=2'>3</a>\u001b[0m politics \u001b[39m=\u001b[39m Route(\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/jakit/customers/aurelio/semantic-router/docs/examples/hybrid-layer.ipynb#X10sZmlsZQ%3D%3D?line=3'>4</a>\u001b[0m name\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mpolitics\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/jakit/customers/aurelio/semantic-router/docs/examples/hybrid-layer.ipynb#X10sZmlsZQ%3D%3D?line=4'>5</a>\u001b[0m utterances\u001b[39m=\u001b[39m[\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/jakit/customers/aurelio/semantic-router/docs/examples/hybrid-layer.ipynb#X10sZmlsZQ%3D%3D?line=11'>12</a>\u001b[0m ],\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/jakit/customers/aurelio/semantic-router/docs/examples/hybrid-layer.ipynb#X10sZmlsZQ%3D%3D?line=12'>13</a>\u001b[0m )\n", - "\u001b[0;31mImportError\u001b[0m: cannot import name 'Route' from 'semantic_router.schema' (/Users/jakit/customers/aurelio/semantic-router/.venv/lib/python3.11/site-packages/semantic_router/schema.py)" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "from semantic_router.schema import Route\n", "\n", @@ -84,6 +72,13 @@ "Let's define another for good measure:" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index ff75369529fd0b116ef7900b3be1463934cdb460..475a12f09b4bcac340d2ceabf77199ceee9cb071 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -1,6 +1,5 @@ import numpy as np from numpy.linalg import norm -from tqdm.auto import tqdm from semantic_router.encoders import ( BaseEncoder, @@ -35,8 +34,9 @@ class HybridRouteLayer: # if routes list has been passed, we initialize index now if routes: # initialize index now - for route in tqdm(routes): - self._add_route(route=route) + # for route in tqdm(routes): + # self._add_route(route=route) + self._add_routes(routes) def __call__(self, text: str) -> str | None: results = self._query(text) @@ -78,6 +78,38 @@ class HybridRouteLayer: else: self.sparse_index = np.concatenate([self.sparse_index, sparse_embeds]) + def _add_routes(self, routes: list[Route]): + # create embeddings for all routes + logger.info("Creating embeddings for all routes...") + all_utterances = [ + utterance for route in routes for utterance in route.utterances + ] + dense_embeds = np.array(self.encoder(all_utterances)) + sparse_embeds = np.array(self.sparse_encoder(all_utterances)) + + # create route array + route_names = [route.name for route in routes for _ in route.utterances] + route_array = np.array(route_names) + self.categories = ( + np.concatenate([self.categories, route_array]) + if self.categories is not None + else route_array + ) + + # create utterance array (the dense index) + self.index = ( + np.concatenate([self.index, dense_embeds]) + if self.index is not None + else dense_embeds + ) + + # create sparse utterance array + self.sparse_index = ( + np.concatenate([self.sparse_index, sparse_embeds]) + if self.sparse_index is not None + else sparse_embeds + ) + def _query(self, text: str, top_k: int = 5): """Given some text, encodes and searches the index vector space to retrieve the top_k most similar records. diff --git a/semantic_router/utils/logger.py b/semantic_router/utils/logger.py index a001623a9c1eae5cc5a632f6afd69858f0319e32..00c83693435487016f819c4716900fc09f8b8b92 100644 --- a/semantic_router/utils/logger.py +++ b/semantic_router/utils/logger.py @@ -22,18 +22,9 @@ class CustomFormatter(colorlog.ColoredFormatter): def add_coloured_handler(logger): formatter = CustomFormatter() - console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) - - logging.basicConfig( - datefmt="%Y-%m-%d %H:%M:%S", - format="%(log_color)s%(asctime)s %(levelname)s %(name)s %(message)s", - force=True, - ) - logger.addHandler(console_handler) - return logger diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py index 06b5d733de6b3120486e373f7105a779877404c3..f87cb1d281b2884a10a9817b9a838c21e64a9881 100644 --- a/tests/unit/test_hybrid_layer.py +++ b/tests/unit/test_hybrid_layer.py @@ -60,7 +60,7 @@ class TestHybridRouteLayer: def test_add_route(self, openai_encoder): route_layer = HybridRouteLayer(encoder=openai_encoder) route = Route(name="Route 3", utterances=["Yes", "No"]) - route_layer.add(route) + route_layer._add_routes([route]) assert route_layer.index is not None and route_layer.categories is not None assert len(route_layer.index) == 2 assert len(set(route_layer.categories)) == 1