diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb index deb1c32f18e9c4509074d0156cd8f891c9ce3745..d86eba846d39cc5f4eff681dba80180f0c5e5944 100644 --- a/docs/examples/function_calling.ipynb +++ b/docs/examples/function_calling.ipynb @@ -9,20 +9,21 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def get_time(location: str) -> str:\n", " \"\"\"Useful to get the time in a specific location\"\"\"\n", - " print(f\"Calling `get_time` function with location: {location}\")\n", + " print(f\"Result from: `get_time` function with location: `{location}`\")\n", " return \"get_time\"\n", "\n", "\n", "def get_news(category: str, country: str) -> str:\n", " \"\"\"Useful to get the news in a specific country\"\"\"\n", " print(\n", - " f\"Calling `get_news` function with category: {category} and country: {country}\"\n", + " f\"Result from: `get_news` function with category: `{category}` \"\n", + " f\"and country: `{country}`\"\n", " )\n", " return \"get_news\"" ] @@ -36,37 +37,34 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/Users/jakit/customers/aurelio/semantic-router/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n", - "\u001b[32m2023-12-19 16:06:38 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", - "\u001b[32m2023-12-19 16:06:44 INFO semantic_router.utils.logger Generated route config:\n", + "\u001b[32m2023-12-19 17:46:30 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", + "\u001b[32m2023-12-19 17:46:40 INFO semantic_router.utils.logger Generated route config:\n", "{\n", " \"name\": \"get_time\",\n", " \"utterances\": [\n", - " \"What's the time in [location]?\",\n", - " \"Can you tell me the time in [location]?\",\n", - " \"I need to know the time in [location].\",\n", - " \"What time is it in [location]?\",\n", - " \"Can you give me the time in [location]?\"\n", + " \"What's the time in New York?\",\n", + " \"Can you tell me the time in Tokyo?\",\n", + " \"What's the current time in London?\",\n", + " \"Can you give me the time in Sydney?\",\n", + " \"What's the time in Paris?\"\n", " ]\n", "}\u001b[0m\n", - "\u001b[32m2023-12-19 16:06:44 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", - "\u001b[32m2023-12-19 16:06:50 INFO semantic_router.utils.logger Generated route config:\n", + "\u001b[32m2023-12-19 17:46:40 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", + "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Generated route config:\n", "{\n", " \"name\": \"get_news\",\n", " \"utterances\": [\n", - " \"Tell me the latest news from the US\",\n", + " \"Tell me the latest news from the United States\",\n", " \"What's happening in India today?\",\n", - " \"Get me the top stories from Japan\",\n", - " \"Can you give me the breaking news from Brazil?\",\n", + " \"Can you give me the top stories from Japan\",\n", + " \"Get me the breaking news from the UK\",\n", " \"What's the latest in Germany?\"\n", " ]\n", "}\u001b[0m\n" @@ -88,37 +86,37 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-19 16:07:10 INFO semantic_router.utils.logger Added route `get_weather`\u001b[0m\n", - "\u001b[32m2023-12-19 16:07:10 INFO semantic_router.utils.logger Removed route `get_weather`\u001b[0m\n" + "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Added route `get_weather`\u001b[0m\n", + "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Removed route `get_weather`\u001b[0m\n" ] }, { "data": { "text/plain": [ "[{'name': 'get_time',\n", - " 'utterances': [\"What's the time in [location]?\",\n", - " 'Can you tell me the time in [location]?',\n", - " 'I need to know the time in [location].',\n", - " 'What time is it in [location]?',\n", - " 'Can you give me the time in [location]?'],\n", + " 'utterances': [\"What's the time in New York?\",\n", + " 'Can you tell me the time in Tokyo?',\n", + " \"What's the current time in London?\",\n", + " 'Can you give me the time in Sydney?',\n", + " \"What's the time in Paris?\"],\n", " 'description': None},\n", " {'name': 'get_news',\n", - " 'utterances': ['Tell me the latest news from the US',\n", + " 'utterances': ['Tell me the latest news from the United States',\n", " \"What's happening in India today?\",\n", - " 'Get me the top stories from Japan',\n", - " 'Can you give me the breaking news from Brazil?',\n", + " 'Can you give me the top stories from Japan',\n", + " 'Get me the breaking news from the UK',\n", " \"What's the latest in Germany?\"],\n", " 'description': None}]" ] }, - "execution_count": 4, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -143,16 +141,16 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Route(name='get_time', utterances=[\"What's the time in [location]?\", 'Can you tell me the time in [location]?', 'I need to know the time in [location].', 'What time is it in [location]?', 'Can you give me the time in [location]?'], description=None)" + "Route(name='get_time', utterances=[\"What's the time in New York?\", 'Can you tell me the time in Tokyo?', \"What's the current time in London?\", 'Can you give me the time in Sydney?', \"What's the time in Paris?\"], description=None)" ] }, - "execution_count": 5, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -171,14 +169,14 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-19 16:04:24 INFO semantic_router.utils.logger Saving route config to route_config.json\u001b[0m\n" + "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Saving route config to route_config.json\u001b[0m\n" ] } ], @@ -186,6 +184,13 @@ "route_config.to_file(\"route_config.json\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define routing layer" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -195,31 +200,26 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-19 16:07:16 INFO semantic_router.utils.logger Loading route config from route_config.json\u001b[0m\n" + "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Loading route config from route_config.json\u001b[0m\n" ] } ], "source": [ + "from semantic_router.route import RouteConfig\n", + "\n", "route_config = RouteConfig.from_file(\"route_config.json\")" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Define routing layer" - ] - }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -237,50 +237,55 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-19 16:07:25 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n" + "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Calling function: get_time\n" + "Calling function: get_time\n", + "Result from: `get_time` function with location: `Stockholm`\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[31m2023-12-19 16:07:27 ERROR semantic_router.utils.logger Input name missing from query\u001b[0m\n" + "\u001b[32m2023-12-19 17:46:49 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Inputs: {'location': 'Stockholm'}\n", - "Schema: {'name': 'get_time', 'description': 'Useful to get the time in a specific location', 'signature': '(location: str) -> str', 'output': \"<class 'str'>\"}\n" + "Calling function: get_news\n", + "Result from: `get_news` function with category: `tech` and country: `Lithuania`\n" ] }, { - "ename": "ValueError", - "evalue": "Invalid inputs", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m/Users/jakit/customers/aurelio/semantic-router/docs/examples/function_calling.ipynb Cell 14\u001b[0m line \u001b[0;36m5\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/jakit/customers/aurelio/semantic-router/docs/examples/function_calling.ipynb#Y115sZmlsZQ%3D%3D?line=0'>1</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39msemantic_router\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mutils\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mfunction_call\u001b[39;00m \u001b[39mimport\u001b[39;00m route_and_execute\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/jakit/customers/aurelio/semantic-router/docs/examples/function_calling.ipynb#Y115sZmlsZQ%3D%3D?line=2'>3</a>\u001b[0m tools \u001b[39m=\u001b[39m [get_time, get_news]\n\u001b[0;32m----> <a href='vscode-notebook-cell:/Users/jakit/customers/aurelio/semantic-router/docs/examples/function_calling.ipynb#Y115sZmlsZQ%3D%3D?line=4'>5</a>\u001b[0m \u001b[39mawait\u001b[39;00m route_and_execute(\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/jakit/customers/aurelio/semantic-router/docs/examples/function_calling.ipynb#Y115sZmlsZQ%3D%3D?line=5'>6</a>\u001b[0m query\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mWhat is the time in Stockholm?\u001b[39m\u001b[39m\"\u001b[39m, functions\u001b[39m=\u001b[39mtools, route_layer\u001b[39m=\u001b[39mroute_layer\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/jakit/customers/aurelio/semantic-router/docs/examples/function_calling.ipynb#Y115sZmlsZQ%3D%3D?line=6'>7</a>\u001b[0m )\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/jakit/customers/aurelio/semantic-router/docs/examples/function_calling.ipynb#Y115sZmlsZQ%3D%3D?line=7'>8</a>\u001b[0m \u001b[39mawait\u001b[39;00m route_and_execute(\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/jakit/customers/aurelio/semantic-router/docs/examples/function_calling.ipynb#Y115sZmlsZQ%3D%3D?line=8'>9</a>\u001b[0m query\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mWhat is the tech news in the Lithuania?\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/jakit/customers/aurelio/semantic-router/docs/examples/function_calling.ipynb#Y115sZmlsZQ%3D%3D?line=9'>10</a>\u001b[0m functions\u001b[39m=\u001b[39mtools,\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/jakit/customers/aurelio/semantic-router/docs/examples/function_calling.ipynb#Y115sZmlsZQ%3D%3D?line=10'>11</a>\u001b[0m route_layer\u001b[39m=\u001b[39mroute_layer,\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/jakit/customers/aurelio/semantic-router/docs/examples/function_calling.ipynb#Y115sZmlsZQ%3D%3D?line=11'>12</a>\u001b[0m )\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/jakit/customers/aurelio/semantic-router/docs/examples/function_calling.ipynb#Y115sZmlsZQ%3D%3D?line=12'>13</a>\u001b[0m \u001b[39mawait\u001b[39;00m route_and_execute(query\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mHi!\u001b[39m\u001b[39m\"\u001b[39m, functions\u001b[39m=\u001b[39mtools, route_layer\u001b[39m=\u001b[39mroute_layer)\n", - "File \u001b[0;32m~/customers/aurelio/semantic-router/semantic_router/utils/function_call.py:125\u001b[0m, in \u001b[0;36mroute_and_execute\u001b[0;34m(query, functions, route_layer)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mCalling function: \u001b[39m\u001b[39m{\u001b[39;00mfunction\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 124\u001b[0m schema \u001b[39m=\u001b[39m get_schema(function)\n\u001b[0;32m--> 125\u001b[0m inputs \u001b[39m=\u001b[39m \u001b[39mawait\u001b[39;00m extract_function_inputs(query, schema)\n\u001b[1;32m 126\u001b[0m call_function(function, inputs)\n", - "File \u001b[0;32m~/customers/aurelio/semantic-router/semantic_router/utils/function_call.py:83\u001b[0m, in \u001b[0;36mextract_function_inputs\u001b[0;34m(query, function_schema)\u001b[0m\n\u001b[1;32m 81\u001b[0m function_inputs \u001b[39m=\u001b[39m json\u001b[39m.\u001b[39mloads(output)\n\u001b[1;32m 82\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m is_valid_inputs(function_inputs, function_schema):\n\u001b[0;32m---> 83\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mInvalid inputs\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 84\u001b[0m \u001b[39mreturn\u001b[39;00m function_inputs\n", - "\u001b[0;31mValueError\u001b[0m: Invalid inputs" + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[33m2023-12-19 17:46:52 WARNING semantic_router.utils.logger No function found, calling LLM...\u001b[0m\n" ] + }, + { + "data": { + "text/plain": [ + "'Hello! How can I assist you today?'" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -298,13 +303,6 @@ ")\n", "await route_and_execute(query=\"Hi!\", functions=tools, route_layer=route_layer)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/docs/examples/route_config.json b/docs/examples/route_config.json index 0a02d850653613f2fd4f2a7fc6ca8b13b9f9ae86..f76a73859e4c3534f37bc99c0aec17627e0d2ee6 100644 --- a/docs/examples/route_config.json +++ b/docs/examples/route_config.json @@ -1 +1 @@ -[{"name": "get_time", "utterances": ["What's the time in [location]?", "Can you tell me the time in [location]?", "I need to know the time in [location].", "What time is it in [location]?", "Can you give me the time in [location]?"], "description": null}, {"name": "get_news", "utterances": ["Tell me the latest news from the US", "What's happening in India today?", "Get me the top stories from Japan", "Can you give me the breaking news from Brazil?", "What's the latest news from Germany?"], "description": null}] +[{"name": "get_time", "utterances": ["What's the time in New York?", "Can you tell me the time in Tokyo?", "What's the current time in London?", "Can you give me the time in Sydney?", "What's the time in Paris?"], "description": null}, {"name": "get_news", "utterances": ["Tell me the latest news from the United States", "What's happening in India today?", "Can you give me the top stories from Japan", "Get me the breaking news from the UK", "What's the latest in Germany?"], "description": null}] diff --git a/semantic_router/route.py b/semantic_router/route.py index c1ec8fc360080d41bdf69a61d3acd8a7a5fcb79e..f46c005cc4d5021391acc9196b63e96113edd5b7 100644 --- a/semantic_router/route.py +++ b/semantic_router/route.py @@ -98,7 +98,7 @@ class Route(BaseModel): Only include the "name" and "utterances" keys in your answer. The "name" should match the function name and the "utterances" should comprise a list of 5 example phrases that could be used to invoke - the function. + the function. Use real values instead of placeholders. Input schema: {function_schema} diff --git a/semantic_router/utils/function_call.py b/semantic_router/utils/function_call.py index 3c0c9a42c9e581eba151984786ff487e3c7fa918..c1b4fceed0492124fc98401f0f1789e89dfddb21 100644 --- a/semantic_router/utils/function_call.py +++ b/semantic_router/utils/function_call.py @@ -86,19 +86,19 @@ async def extract_function_inputs(query: str, function_schema: dict[str, Any]) - def is_valid_inputs(inputs: dict[str, Any], function_schema: dict[str, Any]) -> bool: """Validate the extracted inputs against the function schema""" - - print(f"Inputs: {inputs}") - - print(f"Schema: {function_schema}") - try: - for name, param in function_schema.items(): + # Extract parameter names and types from the signature string + signature = function_schema["signature"] + param_info = [param.strip() for param in signature[1:-1].split(",")] + param_names = [info.split(":")[0].strip() for info in param_info] + param_types = [ + info.split(":")[1].strip().split("=")[0].strip() for info in param_info + ] + + for name, type_str in zip(param_names, param_types): if name not in inputs: logger.error(f"Input {name} missing from query") return False - if not isinstance(inputs[name], param["type"]): - logger.error(f"Input {name} is not of type {param['type']}") - return False return True except Exception as e: logger.error(f"Input validation error: {str(e)}") @@ -117,7 +117,7 @@ async def route_and_execute(query: str, functions: list[Callable], route_layer): function_name = route_layer(query) if not function_name: logger.warning("No function found, calling LLM...") - return llm(query) + return await llm(query) for function in functions: if function.__name__ == function_name: