From b27b8ddf251b16c01f749743413bdd2643ecb307 Mon Sep 17 00:00:00 2001 From: Simonas <20096648+simjak@users.noreply.github.com> Date: Mon, 18 Dec 2023 17:00:44 +0200 Subject: [PATCH] wip --- docs/examples/function_calling.ipynb | 232 ++++++++++----------------- docs/examples/router.json | 24 +++ poetry.lock | 51 +++++- pyproject.toml | 1 + semantic_router/layer.py | 31 +++- semantic_router/schema.py | 11 ++ tests/unit/test_layer.py | 2 +- 7 files changed, 196 insertions(+), 156 deletions(-) create mode 100644 docs/examples/router.json diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb index 5d3be2fb..c41a8a2b 100644 --- a/docs/examples/function_calling.ipynb +++ b/docs/examples/function_calling.ipynb @@ -9,9 +9,19 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/jakit/customers/aurelio/semantic-router/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n" + ] + } + ], "source": [ "# OpenAI\n", "import openai\n", @@ -39,7 +49,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -91,7 +101,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -113,7 +123,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -130,16 +140,18 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import json\n", + "from typing import Callable\n", "\n", "from semantic_router.utils.logger import logger\n", + "from semantic_router.layer import Route\n", "\n", "\n", - "def generate_route(function) -> dict:\n", + "def generate_route(function: Callable) -> Route:\n", " logger.info(\"Generating config...\")\n", "\n", " function_schema = get_function_schema(function)\n", @@ -196,11 +208,10 @@ " try:\n", " route_config = json.loads(ai_message)\n", " logger.info(f\"Generated config: {route_config}\")\n", - " return route_config\n", + " return Route(**route_config)\n", " except json.JSONDecodeError as json_error:\n", " logger.error(f\"JSON parsing error {json_error}\")\n", - " print(f\"AI message: {ai_message}\")\n", - " return {\"error\": \"Failed to generate config\"}" + " raise Exception(f\"Failed to generate a valid Route {json_error}\")" ] }, { @@ -212,7 +223,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -228,7 +239,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -299,40 +310,6 @@ " return {\"error\": \"Failed to validate parameters\"}" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Set up the routing layer" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "from semantic_router.schema import Route\n", - "from semantic_router.encoders import CohereEncoder\n", - "from semantic_router.layer import RouteLayer\n", - "from semantic_router.utils.logger import logger\n", - "\n", - "\n", - "def create_router(routes: list[dict]) -> RouteLayer:\n", - " logger.info(\"Creating route layer...\")\n", - " encoder = CohereEncoder()\n", - "\n", - " route_list: list[Route] = []\n", - " for route in routes:\n", - " if \"name\" in route and \"utterances\" in route:\n", - " print(f\"Route: {route}\")\n", - " route_list.append(Route(name=route[\"name\"], utterances=route[\"utterances\"]))\n", - " else:\n", - " logger.warning(f\"Misconfigured route: {route}\")\n", - "\n", - " return RouteLayer(encoder=encoder, routes=route_list)" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -342,7 +319,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -387,57 +364,18 @@ "### Workflow" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Functions as a tool" + ] + }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 9, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-12-18 12:17:58 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 12:17:58 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[31m2023-12-18 12:18:00 ERROR semantic_router.utils.logger Fall back to OpenAI failed with error ('Failed to call HuggingFace API', '{\"error\":\"Bad Gateway\"}')\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:00 INFO semantic_router.utils.logger Calling gpt-4 model\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger AI message: {\n", - " \"name\": \"get_time\",\n", - " \"utterances\": [\n", - " \"what is the time in new york\",\n", - " \"can you tell me the time in london\",\n", - " \"get me the current time in tokyo\",\n", - " \"i need to know the time in sydney\",\n", - " \"please tell me the current time in paris\"\n", - " ]\n", - "}\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': ['what is the time in new york', 'can you tell me the time in london', 'get me the current time in tokyo', 'i need to know the time in sydney', 'please tell me the current time in paris']}\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[31m2023-12-18 12:18:07 ERROR semantic_router.utils.logger Fall back to OpenAI failed with error ('Failed to call HuggingFace API', '{\"error\":\"Bad Gateway\"}')\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:07 INFO semantic_router.utils.logger Calling gpt-4 model\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger AI message: {\n", - " \"name\": \"get_news\",\n", - " \"utterances\": [\n", - " \"Can I get the latest news in Canada?\",\n", - " \"Show me the recent news in the US\",\n", - " \"I would like to know about the sports news in England\",\n", - " \"Let's check the technology news in Japan\",\n", - " \"Show me the health related news in Germany\"\n", - " ]\n", - "}\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Can I get the latest news in Canada?', 'Show me the recent news in the US', 'I would like to know about the sports news in England', \"Let's check the technology news in Japan\", 'Show me the health related news in Germany']}\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Route: {'name': 'get_time', 'utterances': ['what is the time in new york', 'can you tell me the time in london', 'get me the current time in tokyo', 'i need to know the time in sydney', 'please tell me the current time in paris']}\n", - "Route: {'name': 'get_news', 'utterances': ['Can I get the latest news in Canada?', 'Show me the recent news in the US', 'I would like to know about the sports news in England', \"Let's check the technology news in Japan\", 'Show me the health related news in Germany']}\n" - ] - } - ], + "outputs": [], "source": [ "def get_time(location: str) -> str:\n", " \"\"\"Useful to get the time in a specific location\"\"\"\n", @@ -450,32 +388,21 @@ " print(\n", " f\"Calling `get_news` function with category: {category} and country: {country}\"\n", " )\n", - " return \"get_news\"\n", - "\n", - "\n", - "# Registering functions to the router\n", - "route_get_time = generate_route(get_time)\n", - "route_get_news = generate_route(get_news)\n", - "\n", - "routes = [route_get_time, route_get_news]\n", - "router = create_router(routes)\n", - "\n", - "# Tools\n", - "tools = [get_time, get_news]" + " return \"get_news\"" ] }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-18 12:20:12 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:12 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger AI message: \n", + "\u001b[32m2023-12-18 16:58:00 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:00 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:04 INFO semantic_router.utils.logger AI message: \n", " Example output:\n", " {\n", " \"name\": \"get_time\",\n", @@ -487,10 +414,10 @@ " \"Can you tell me the time in Berlin?\"\n", " ]\n", " }\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger AI message: \n", + "\u001b[32m2023-12-18 16:58:04 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:04 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:04 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:07 INFO semantic_router.utils.logger AI message: \n", " Example output:\n", " {\n", " \"name\": \"get_news\",\n", @@ -502,61 +429,62 @@ " \"What's the latest news from Germany?\"\n", " ]\n", " }\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n" + "\u001b[32m2023-12-18 16:58:07 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Route: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\n", - "Route: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\n" + "Generated routes: [Route(name='get_time', utterances=[\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?'], description=None), Route(name='get_news', utterances=['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"], description=None)]\n" ] } ], "source": [ - "def get_time(location: str) -> str:\n", - " \"\"\"Useful to get the time in a specific location\"\"\"\n", - " print(f\"Calling `get_time` function with location: {location}\")\n", - " return \"get_time\"\n", - "\n", - "\n", - "def get_news(category: str, country: str) -> str:\n", - " \"\"\"Useful to get the news in a specific country\"\"\"\n", - " print(\n", - " f\"Calling `get_news` function with category: {category} and country: {country}\"\n", - " )\n", - " return \"get_news\"\n", - "\n", + "from semantic_router.layer import RouteLayer\n", "\n", "# Registering functions to the router\n", - "route_get_time = generate_route(get_time)\n", - "route_get_news = generate_route(get_news)\n", + "def from_functions(functions: list[Callable]) -> RouteLayer:\n", + " routes = []\n", + " for function in functions:\n", + " route = generate_route(function)\n", + " routes.append(route)\n", "\n", - "routes = [route_get_time, route_get_news]\n", - "router = create_router(routes)\n", + " print(f\"Generated routes: {routes}\")\n", + " return RouteLayer(routes=routes)\n", "\n", - "# Tools\n", - "tools = [get_time, get_news]" + "router = from_functions([get_time, get_news])\n", + "\n", + "# Saving the router configuration\n", + "router.to_json(\"router.json\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# Loading configuration from file\n", + "router = RouteLayer.from_json(\"router.json\")" ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-18 12:20:02 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:02 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger AI message: \n", + "\u001b[32m2023-12-18 16:58:09 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:09 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:10 INFO semantic_router.utils.logger AI message: \n", " {\n", " \"location\": \"Stockholm\"\n", " }\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n" + "\u001b[32m2023-12-18 16:58:10 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n" ] }, { @@ -571,14 +499,14 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger AI message: \n", + "\u001b[32m2023-12-18 16:58:10 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:10 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:12 INFO semantic_router.utils.logger AI message: \n", " {\n", " \"category\": \"tech\",\n", " \"country\": \"Lithuania\"\n", " }\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n" + "\u001b[32m2023-12-18 16:58:12 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n" ] }, { @@ -593,9 +521,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[33m2023-12-18 12:20:05 WARNING semantic_router.utils.logger No function found\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:06 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n" + "\u001b[33m2023-12-18 16:58:12 WARNING semantic_router.utils.logger No function found\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:12 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:13 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n" ] }, { @@ -604,12 +532,14 @@ "' How can I help you today?'" ] }, - "execution_count": 26, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "tools = [get_time, get_news]\n", + "\n", "call(query=\"What is the time in Stockholm?\", functions=tools, router=router)\n", "call(query=\"What is the tech news in the Lithuania?\", functions=tools, router=router)\n", "call(query=\"Hi!\", functions=tools, router=router)" diff --git a/docs/examples/router.json b/docs/examples/router.json new file mode 100644 index 00000000..d82eaf6b --- /dev/null +++ b/docs/examples/router.json @@ -0,0 +1,24 @@ +[ + { + "name": "get_time", + "utterances": [ + "What's the time in New York?", + "Tell me the time in Tokyo.", + "Can you give me the time in London?", + "What's the current time in Sydney?", + "Can you tell me the time in Berlin?" + ], + "description": null + }, + { + "name": "get_news", + "utterances": [ + "Tell me the latest news from the US", + "What's happening in India today?", + "Get me the top stories from Japan", + "Can you give me the breaking news from Brazil?", + "What's the latest news from Germany?" + ], + "description": null + } +] diff --git a/poetry.lock b/poetry.lock index 216d298d..81101378 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1686,6 +1686,55 @@ files = [ {file = "pywin32-306-cp39-cp39-win_amd64.whl", hash = "sha256:39b61c15272833b5c329a2989999dcae836b1eed650252ab1b7bfbe1d59f30f4"}, ] +[[package]] +name = "pyyaml" +version = "6.0.1" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, + {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, + {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, + {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, + {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, + {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, + {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, + {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, +] + [[package]] name = "pyzmq" version = "25.1.2" @@ -2222,4 +2271,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "f2735c243faa3d788c0f6268d6cb550648ed0d1fffec27a084344dafa4590a80" +content-hash = "f9717f2fd983029796c2c6162081f4b195555453f23f8e5d784ca7a7c1034034" diff --git a/pyproject.toml b/pyproject.toml index e45e5f17..32cb1fe3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ cohere = "^4.32" numpy = "^1.25.2" pinecone-text = "^0.7.0" colorlog = "^6.8.0" +pyyaml = "^6.0.1" [tool.poetry.group.dev.dependencies] diff --git a/semantic_router/layer.py b/semantic_router/layer.py index cb408c5c..ad9f73fe 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -1,4 +1,7 @@ +import json + import numpy as np +import yaml from semantic_router.encoders import ( BaseEncoder, @@ -15,7 +18,10 @@ class RouteLayer: categories = None score_threshold = 0.82 - def __init__(self, encoder: BaseEncoder, routes: list[Route] = []): + def __init__( + self, encoder: BaseEncoder = CohereEncoder(), routes: list[Route] = [] + ): + self.routes: list[Route] = routes self.encoder = encoder # decide on default threshold based on encoder if isinstance(encoder, OpenAIEncoder): @@ -27,7 +33,7 @@ class RouteLayer: # if routes list has been passed, we initialize index now if routes: # initialize index now - self.add_routes(routes=routes) + self._add_routes(routes=routes) def __call__(self, text: str) -> str | None: results = self._query(text) @@ -38,6 +44,20 @@ class RouteLayer: else: return None + @classmethod + def from_json(cls, file_path: str): + with open(file_path, "r") as f: + routes_data = json.load(f) + routes = [Route.from_dict(route_data) for route_data in routes_data] + return cls(routes=routes) + + @classmethod + def from_yaml(cls, file_path: str): + with open(file_path, "r") as f: + routes_data = yaml.load(f, Loader=yaml.FullLoader) + routes = [Route.from_dict(route_data) for route_data in routes_data] + return cls(routes=routes) + def add_route(self, route: Route): # create embeddings embeds = self.encoder(route.utterances) @@ -55,7 +75,7 @@ class RouteLayer: embed_arr = np.array(embeds) self.index = np.concatenate([self.index, embed_arr]) - def add_routes(self, routes: list[Route]): + def _add_routes(self, routes: list[Route]): # create embeddings for all routes all_utterances = [ utterance for route in routes for utterance in route.utterances @@ -124,3 +144,8 @@ class RouteLayer: return max(scores) > threshold else: return False + + def to_json(self, file_path: str): + routes = [route.to_dict() for route in self.routes] + with open(file_path, "w") as f: + json.dump(routes, f, indent=4) diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 007cddcb..1bb2ad00 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -1,5 +1,6 @@ from enum import Enum +import yaml from pydantic import BaseModel from pydantic.dataclasses import dataclass @@ -15,6 +16,16 @@ class Route(BaseModel): utterances: list[str] description: str | None = None + def to_dict(self): + return self.dict() + + def to_yaml(self): + return yaml.dump(self.dict()) + + @classmethod + def from_dict(cls, data: dict): + return cls(**data) + class EncoderType(Enum): HUGGINGFACE = "huggingface" diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py index 66e0d53b..1d9536a7 100644 --- a/tests/unit/test_layer.py +++ b/tests/unit/test_layer.py @@ -78,7 +78,7 @@ class TestRouteLayer: def test_add_multiple_routes(self, openai_encoder, routes): route_layer = RouteLayer(encoder=openai_encoder) - route_layer.add_routes(routes=routes) + route_layer._add_routes(routes=routes) assert route_layer.index is not None and route_layer.categories is not None assert len(route_layer.index) == 5 assert len(set(route_layer.categories)) == 2 -- GitLab