diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb index 5d3be2fb9072d3e25d50a6c7456500f7d72083bc..c41a8a2b535432b0b935ebe39681b6f91df2e0ca 100644 --- a/docs/examples/function_calling.ipynb +++ b/docs/examples/function_calling.ipynb @@ -9,9 +9,19 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/jakit/customers/aurelio/semantic-router/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n" + ] + } + ], "source": [ "# OpenAI\n", "import openai\n", @@ -39,7 +49,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -91,7 +101,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -113,7 +123,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -130,16 +140,18 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import json\n", + "from typing import Callable\n", "\n", "from semantic_router.utils.logger import logger\n", + "from semantic_router.layer import Route\n", "\n", "\n", - "def generate_route(function) -> dict:\n", + "def generate_route(function: Callable) -> Route:\n", " logger.info(\"Generating config...\")\n", "\n", " function_schema = get_function_schema(function)\n", @@ -196,11 +208,10 @@ " try:\n", " route_config = json.loads(ai_message)\n", " logger.info(f\"Generated config: {route_config}\")\n", - " return route_config\n", + " return Route(**route_config)\n", " except json.JSONDecodeError as json_error:\n", " logger.error(f\"JSON parsing error {json_error}\")\n", - " print(f\"AI message: {ai_message}\")\n", - " return {\"error\": \"Failed to generate config\"}" + " raise Exception(f\"Failed to generate a valid Route {json_error}\")" ] }, { @@ -212,7 +223,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -228,7 +239,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -299,40 +310,6 @@ " return {\"error\": \"Failed to validate parameters\"}" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Set up the routing layer" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "from semantic_router.schema import Route\n", - "from semantic_router.encoders import CohereEncoder\n", - "from semantic_router.layer import RouteLayer\n", - "from semantic_router.utils.logger import logger\n", - "\n", - "\n", - "def create_router(routes: list[dict]) -> RouteLayer:\n", - " logger.info(\"Creating route layer...\")\n", - " encoder = CohereEncoder()\n", - "\n", - " route_list: list[Route] = []\n", - " for route in routes:\n", - " if \"name\" in route and \"utterances\" in route:\n", - " print(f\"Route: {route}\")\n", - " route_list.append(Route(name=route[\"name\"], utterances=route[\"utterances\"]))\n", - " else:\n", - " logger.warning(f\"Misconfigured route: {route}\")\n", - "\n", - " return RouteLayer(encoder=encoder, routes=route_list)" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -342,7 +319,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -387,57 +364,18 @@ "### Workflow" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Functions as a tool" + ] + }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 9, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-12-18 12:17:58 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 12:17:58 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[31m2023-12-18 12:18:00 ERROR semantic_router.utils.logger Fall back to OpenAI failed with error ('Failed to call HuggingFace API', '{\"error\":\"Bad Gateway\"}')\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:00 INFO semantic_router.utils.logger Calling gpt-4 model\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger AI message: {\n", - " \"name\": \"get_time\",\n", - " \"utterances\": [\n", - " \"what is the time in new york\",\n", - " \"can you tell me the time in london\",\n", - " \"get me the current time in tokyo\",\n", - " \"i need to know the time in sydney\",\n", - " \"please tell me the current time in paris\"\n", - " ]\n", - "}\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': ['what is the time in new york', 'can you tell me the time in london', 'get me the current time in tokyo', 'i need to know the time in sydney', 'please tell me the current time in paris']}\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[31m2023-12-18 12:18:07 ERROR semantic_router.utils.logger Fall back to OpenAI failed with error ('Failed to call HuggingFace API', '{\"error\":\"Bad Gateway\"}')\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:07 INFO semantic_router.utils.logger Calling gpt-4 model\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger AI message: {\n", - " \"name\": \"get_news\",\n", - " \"utterances\": [\n", - " \"Can I get the latest news in Canada?\",\n", - " \"Show me the recent news in the US\",\n", - " \"I would like to know about the sports news in England\",\n", - " \"Let's check the technology news in Japan\",\n", - " \"Show me the health related news in Germany\"\n", - " ]\n", - "}\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Can I get the latest news in Canada?', 'Show me the recent news in the US', 'I would like to know about the sports news in England', \"Let's check the technology news in Japan\", 'Show me the health related news in Germany']}\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Route: {'name': 'get_time', 'utterances': ['what is the time in new york', 'can you tell me the time in london', 'get me the current time in tokyo', 'i need to know the time in sydney', 'please tell me the current time in paris']}\n", - "Route: {'name': 'get_news', 'utterances': ['Can I get the latest news in Canada?', 'Show me the recent news in the US', 'I would like to know about the sports news in England', \"Let's check the technology news in Japan\", 'Show me the health related news in Germany']}\n" - ] - } - ], + "outputs": [], "source": [ "def get_time(location: str) -> str:\n", " \"\"\"Useful to get the time in a specific location\"\"\"\n", @@ -450,32 +388,21 @@ " print(\n", " f\"Calling `get_news` function with category: {category} and country: {country}\"\n", " )\n", - " return \"get_news\"\n", - "\n", - "\n", - "# Registering functions to the router\n", - "route_get_time = generate_route(get_time)\n", - "route_get_news = generate_route(get_news)\n", - "\n", - "routes = [route_get_time, route_get_news]\n", - "router = create_router(routes)\n", - "\n", - "# Tools\n", - "tools = [get_time, get_news]" + " return \"get_news\"" ] }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-18 12:20:12 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:12 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger AI message: \n", + "\u001b[32m2023-12-18 16:58:00 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:00 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:04 INFO semantic_router.utils.logger AI message: \n", " Example output:\n", " {\n", " \"name\": \"get_time\",\n", @@ -487,10 +414,10 @@ " \"Can you tell me the time in Berlin?\"\n", " ]\n", " }\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger AI message: \n", + "\u001b[32m2023-12-18 16:58:04 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:04 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:04 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:07 INFO semantic_router.utils.logger AI message: \n", " Example output:\n", " {\n", " \"name\": \"get_news\",\n", @@ -502,61 +429,62 @@ " \"What's the latest news from Germany?\"\n", " ]\n", " }\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n" + "\u001b[32m2023-12-18 16:58:07 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Route: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\n", - "Route: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\n" + "Generated routes: [Route(name='get_time', utterances=[\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?'], description=None), Route(name='get_news', utterances=['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"], description=None)]\n" ] } ], "source": [ - "def get_time(location: str) -> str:\n", - " \"\"\"Useful to get the time in a specific location\"\"\"\n", - " print(f\"Calling `get_time` function with location: {location}\")\n", - " return \"get_time\"\n", - "\n", - "\n", - "def get_news(category: str, country: str) -> str:\n", - " \"\"\"Useful to get the news in a specific country\"\"\"\n", - " print(\n", - " f\"Calling `get_news` function with category: {category} and country: {country}\"\n", - " )\n", - " return \"get_news\"\n", - "\n", + "from semantic_router.layer import RouteLayer\n", "\n", "# Registering functions to the router\n", - "route_get_time = generate_route(get_time)\n", - "route_get_news = generate_route(get_news)\n", + "def from_functions(functions: list[Callable]) -> RouteLayer:\n", + " routes = []\n", + " for function in functions:\n", + " route = generate_route(function)\n", + " routes.append(route)\n", "\n", - "routes = [route_get_time, route_get_news]\n", - "router = create_router(routes)\n", + " print(f\"Generated routes: {routes}\")\n", + " return RouteLayer(routes=routes)\n", "\n", - "# Tools\n", - "tools = [get_time, get_news]" + "router = from_functions([get_time, get_news])\n", + "\n", + "# Saving the router configuration\n", + "router.to_json(\"router.json\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# Loading configuration from file\n", + "router = RouteLayer.from_json(\"router.json\")" ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-18 12:20:02 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:02 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger AI message: \n", + "\u001b[32m2023-12-18 16:58:09 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:09 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:10 INFO semantic_router.utils.logger AI message: \n", " {\n", " \"location\": \"Stockholm\"\n", " }\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n" + "\u001b[32m2023-12-18 16:58:10 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n" ] }, { @@ -571,14 +499,14 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger AI message: \n", + "\u001b[32m2023-12-18 16:58:10 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:10 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:12 INFO semantic_router.utils.logger AI message: \n", " {\n", " \"category\": \"tech\",\n", " \"country\": \"Lithuania\"\n", " }\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n" + "\u001b[32m2023-12-18 16:58:12 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n" ] }, { @@ -593,9 +521,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[33m2023-12-18 12:20:05 WARNING semantic_router.utils.logger No function found\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:06 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n" + "\u001b[33m2023-12-18 16:58:12 WARNING semantic_router.utils.logger No function found\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:12 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:13 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n" ] }, { @@ -604,12 +532,14 @@ "' How can I help you today?'" ] }, - "execution_count": 26, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "tools = [get_time, get_news]\n", + "\n", "call(query=\"What is the time in Stockholm?\", functions=tools, router=router)\n", "call(query=\"What is the tech news in the Lithuania?\", functions=tools, router=router)\n", "call(query=\"Hi!\", functions=tools, router=router)" diff --git a/docs/examples/router.json b/docs/examples/router.json new file mode 100644 index 0000000000000000000000000000000000000000..d82eaf6b71a43fe95b86bae160c40872febb0afd --- /dev/null +++ b/docs/examples/router.json @@ -0,0 +1,24 @@ +[ + { + "name": "get_time", + "utterances": [ + "What's the time in New York?", + "Tell me the time in Tokyo.", + "Can you give me the time in London?", + "What's the current time in Sydney?", + "Can you tell me the time in Berlin?" + ], + "description": null + }, + { + "name": "get_news", + "utterances": [ + "Tell me the latest news from the US", + "What's happening in India today?", + "Get me the top stories from Japan", + "Can you give me the breaking news from Brazil?", + "What's the latest news from Germany?" + ], + "description": null + } +] diff --git a/poetry.lock b/poetry.lock index 216d298ddcb9dbc839733929399a6990cbd8584b..81101378f593e43dfdc6da36b3f9d943425f4629 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1686,6 +1686,55 @@ files = [ {file = "pywin32-306-cp39-cp39-win_amd64.whl", hash = "sha256:39b61c15272833b5c329a2989999dcae836b1eed650252ab1b7bfbe1d59f30f4"}, ] +[[package]] +name = "pyyaml" +version = "6.0.1" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, + {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, + {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, + {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, + {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, + {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, + {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, + {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, +] + [[package]] name = "pyzmq" version = "25.1.2" @@ -2222,4 +2271,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "f2735c243faa3d788c0f6268d6cb550648ed0d1fffec27a084344dafa4590a80" +content-hash = "f9717f2fd983029796c2c6162081f4b195555453f23f8e5d784ca7a7c1034034" diff --git a/pyproject.toml b/pyproject.toml index e45e5f17d0356cce8a2cfe5a33d9fa0529c170c5..32cb1fe35b645e89c11421ba0c4f91556f7e8428 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ cohere = "^4.32" numpy = "^1.25.2" pinecone-text = "^0.7.0" colorlog = "^6.8.0" +pyyaml = "^6.0.1" [tool.poetry.group.dev.dependencies] diff --git a/semantic_router/layer.py b/semantic_router/layer.py index cb408c5c5f452b78e9750976d7c669a01028450a..ad9f73febf0df63b13f719a190bb7a5b4e64f0e7 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -1,4 +1,7 @@ +import json + import numpy as np +import yaml from semantic_router.encoders import ( BaseEncoder, @@ -15,7 +18,10 @@ class RouteLayer: categories = None score_threshold = 0.82 - def __init__(self, encoder: BaseEncoder, routes: list[Route] = []): + def __init__( + self, encoder: BaseEncoder = CohereEncoder(), routes: list[Route] = [] + ): + self.routes: list[Route] = routes self.encoder = encoder # decide on default threshold based on encoder if isinstance(encoder, OpenAIEncoder): @@ -27,7 +33,7 @@ class RouteLayer: # if routes list has been passed, we initialize index now if routes: # initialize index now - self.add_routes(routes=routes) + self._add_routes(routes=routes) def __call__(self, text: str) -> str | None: results = self._query(text) @@ -38,6 +44,20 @@ class RouteLayer: else: return None + @classmethod + def from_json(cls, file_path: str): + with open(file_path, "r") as f: + routes_data = json.load(f) + routes = [Route.from_dict(route_data) for route_data in routes_data] + return cls(routes=routes) + + @classmethod + def from_yaml(cls, file_path: str): + with open(file_path, "r") as f: + routes_data = yaml.load(f, Loader=yaml.FullLoader) + routes = [Route.from_dict(route_data) for route_data in routes_data] + return cls(routes=routes) + def add_route(self, route: Route): # create embeddings embeds = self.encoder(route.utterances) @@ -55,7 +75,7 @@ class RouteLayer: embed_arr = np.array(embeds) self.index = np.concatenate([self.index, embed_arr]) - def add_routes(self, routes: list[Route]): + def _add_routes(self, routes: list[Route]): # create embeddings for all routes all_utterances = [ utterance for route in routes for utterance in route.utterances @@ -124,3 +144,8 @@ class RouteLayer: return max(scores) > threshold else: return False + + def to_json(self, file_path: str): + routes = [route.to_dict() for route in self.routes] + with open(file_path, "w") as f: + json.dump(routes, f, indent=4) diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 007cddcbeb2c9e464e02a6c7f6cd12d2e9769cbc..1bb2ad006c37a5f8bc4a21a2a507a9d3179effac 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -1,5 +1,6 @@ from enum import Enum +import yaml from pydantic import BaseModel from pydantic.dataclasses import dataclass @@ -15,6 +16,16 @@ class Route(BaseModel): utterances: list[str] description: str | None = None + def to_dict(self): + return self.dict() + + def to_yaml(self): + return yaml.dump(self.dict()) + + @classmethod + def from_dict(cls, data: dict): + return cls(**data) + class EncoderType(Enum): HUGGINGFACE = "huggingface" diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py index 66e0d53bb9350c77578682f9ea0742b1d3dfe0b2..1d9536a7910f33639fcd3c396e8df6896f4f2e64 100644 --- a/tests/unit/test_layer.py +++ b/tests/unit/test_layer.py @@ -78,7 +78,7 @@ class TestRouteLayer: def test_add_multiple_routes(self, openai_encoder, routes): route_layer = RouteLayer(encoder=openai_encoder) - route_layer.add_routes(routes=routes) + route_layer._add_routes(routes=routes) assert route_layer.index is not None and route_layer.categories is not None assert len(route_layer.index) == 5 assert len(set(route_layer.categories)) == 2