From d59189fc128d57302050fe7964ce429c126aed13 Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood <siraj@aurelio.ai> Date: Wed, 1 May 2024 23:20:35 +0400 Subject: [PATCH] Applying James' suggested fixes and testing. --- docs/02-dynamic-routes.ipynb | 686 +++++++++++++++++++++++-- semantic_router/llms/openai.py | 54 +- semantic_router/utils/function_call.py | 110 ++-- 3 files changed, 724 insertions(+), 126 deletions(-) diff --git a/docs/02-dynamic-routes.ipynb b/docs/02-dynamic-routes.ipynb index 40f198e1..59f1366e 100644 --- a/docs/02-dynamic-routes.ipynb +++ b/docs/02-dynamic-routes.ipynb @@ -185,27 +185,27 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger local\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 1 length: 34\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 1 trunc length: 34\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 2 length: 51\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 2 trunc length: 51\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 3 length: 66\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 3 trunc length: 66\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 4 length: 38\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 4 trunc length: 38\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 5 length: 27\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 5 trunc length: 27\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 6 length: 24\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 6 trunc length: 24\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 7 length: 21\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 7 trunc length: 21\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 8 length: 20\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 8 trunc length: 20\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 9 length: 25\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 9 trunc length: 25\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 10 length: 22\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 10 trunc length: 22\u001b[0m\n" + "\u001b[32m2024-05-01 23:07:13 INFO semantic_router.utils.logger local\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:13 INFO semantic_router.utils.logger Document 1 length: 34\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:13 INFO semantic_router.utils.logger Document 1 trunc length: 34\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:13 INFO semantic_router.utils.logger Document 2 length: 51\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:13 INFO semantic_router.utils.logger Document 2 trunc length: 51\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:13 INFO semantic_router.utils.logger Document 3 length: 66\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:13 INFO semantic_router.utils.logger Document 3 trunc length: 66\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:13 INFO semantic_router.utils.logger Document 4 length: 38\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:13 INFO semantic_router.utils.logger Document 4 trunc length: 38\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:13 INFO semantic_router.utils.logger Document 5 length: 27\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:13 INFO semantic_router.utils.logger Document 5 trunc length: 27\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:13 INFO semantic_router.utils.logger Document 6 length: 24\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:13 INFO semantic_router.utils.logger Document 6 trunc length: 24\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:13 INFO semantic_router.utils.logger Document 7 length: 21\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:13 INFO semantic_router.utils.logger Document 7 trunc length: 21\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:13 INFO semantic_router.utils.logger Document 8 length: 20\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:13 INFO semantic_router.utils.logger Document 8 trunc length: 20\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:13 INFO semantic_router.utils.logger Document 9 length: 25\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:13 INFO semantic_router.utils.logger Document 9 trunc length: 25\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:13 INFO semantic_router.utils.logger Document 10 length: 22\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:13 INFO semantic_router.utils.logger Document 10 trunc length: 22\u001b[0m\n" ] } ], @@ -254,8 +254,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-05-01 01:25:55 INFO semantic_router.utils.logger Document 1 length: 24\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:55 INFO semantic_router.utils.logger Document 1 trunc length: 24\u001b[0m\n" + "\u001b[32m2024-05-01 23:07:14 INFO semantic_router.utils.logger Document 1 length: 24\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:14 INFO semantic_router.utils.logger Document 1 trunc length: 24\u001b[0m\n" ] }, { @@ -332,7 +332,7 @@ { "data": { "text/plain": [ - "'17:25'" + "'15:07'" ] }, "execution_count": 6, @@ -356,6 +356,32 @@ { "cell_type": "code", "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'name': 'get_time',\n", + " 'description': 'Finds the current time in a specific timezone.\\n\\n:param timezone: The timezone to find the current time in, should\\n be a valid timezone from the IANA Time Zone Database like\\n \"America/New_York\" or \"Europe/London\". Do NOT put the place\\n name itself like \"rome\", or \"new york\", you must provide\\n the IANA format.\\n:type timezone: str\\n:return: The current time in the specified timezone.',\n", + " 'signature': '(timezone: str) -> str',\n", + " 'output': \"<class 'str'>\"}]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from semantic_router.utils.function_call import get_schema_list\n", + "\n", + "schema = get_schema_list([get_time])\n", + "schema" + ] + }, + { + "cell_type": "code", + "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -376,13 +402,13 @@ " 'required': ['timezone']}}}]" ] }, - "execution_count": 7, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "from semantic_router.utils.function_call import get_schemas_openai\n", + "from semantic_router.llms.openai import get_schemas_openai\n", "\n", "schema = get_schemas_openai([get_time])\n", "schema" @@ -399,7 +425,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": { "id": "iesBG9P3ur0z" }, @@ -418,7 +444,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -436,7 +462,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -449,13 +475,13 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-05-01 01:25:56 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:56 INFO semantic_router.utils.logger Document 1 length: 34\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:56 INFO semantic_router.utils.logger Document 1 trunc length: 34\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:56 INFO semantic_router.utils.logger Document 2 length: 27\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:56 INFO semantic_router.utils.logger Document 2 trunc length: 27\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:56 INFO semantic_router.utils.logger Document 3 length: 32\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:56 INFO semantic_router.utils.logger Document 3 trunc length: 32\u001b[0m\n" + "\u001b[32m2024-05-01 23:07:14 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:14 INFO semantic_router.utils.logger Document 1 length: 34\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:14 INFO semantic_router.utils.logger Document 1 trunc length: 34\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:14 INFO semantic_router.utils.logger Document 2 length: 27\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:14 INFO semantic_router.utils.logger Document 2 trunc length: 27\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:14 INFO semantic_router.utils.logger Document 3 length: 32\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:14 INFO semantic_router.utils.logger Document 3 trunc length: 32\u001b[0m\n" ] } ], @@ -465,7 +491,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -483,7 +509,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -497,9 +523,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-05-01 01:25:56 INFO semantic_router.utils.logger Document 1 length: 34\u001b[0m\n", - "\u001b[32m2024-05-01 01:25:56 INFO semantic_router.utils.logger Document 1 trunc length: 34\u001b[0m\n", - "\u001b[33m2024-05-01 01:25:57 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n" + "\u001b[32m2024-05-01 23:07:15 INFO semantic_router.utils.logger Document 1 length: 34\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:15 INFO semantic_router.utils.logger Document 1 trunc length: 34\u001b[0m\n", + "\u001b[33m2024-05-01 23:07:15 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n" ] }, { @@ -508,19 +534,19 @@ "RouteChoice(name='get_time', function_call=[{'function_name': 'get_time', 'arguments': '{\"timezone\":\"America/New_York\"}'}], similarity_score=None)" ] }, - "execution_count": 12, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "out = rl(\"what is the time in new york city?\")\n", - "out\n" + "response = rl(\"what is the time in new york city?\")\n", + "response\n" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -532,26 +558,26 @@ } ], "source": [ - "print(out.function_call)" + "print(response.function_call)" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "17:25\n" + "15:07\n" ] } ], "source": [ "import json\n", "\n", - "for call in out.function_call:\n", + "for call in response.function_call:\n", " if call['function_name'] == 'get_time':\n", " args = json.loads(call['arguments'])\n", " result = get_time(**args)\n", @@ -567,6 +593,13 @@ "Our dynamic route provides both the route itself _and_ the input parameters required to use the route." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dynamic Routes with Multiple Functions" + ] + }, { "cell_type": "markdown", "metadata": { @@ -575,6 +608,561 @@ "source": [ "---" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Routes can be assigned multiple functions. Then, when that particular Route is selected by the Route Layer, a number of those functions might be invoked due to the users utterance containing relevant information that fits their arguments. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's define a Route that has multiple functions." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "from datetime import datetime, timedelta\n", + "from zoneinfo import ZoneInfo\n", + "\n", + "# Function with one argument\n", + "def get_time(timezone: str) -> str:\n", + " \"\"\"Finds the current time in a specific timezone.\n", + "\n", + " :param timezone: The timezone to find the current time in, should\n", + " be a valid timezone from the IANA Time Zone Database like\n", + " \"America/New_York\" or \"Europe/London\". Do NOT put the place\n", + " name itself like \"rome\", or \"new york\", you must provide\n", + " the IANA format.\n", + " :type timezone: str\n", + " :return: The current time in the specified timezone.\"\"\"\n", + " now = datetime.now(ZoneInfo(timezone))\n", + " return now.strftime(\"%H:%M\")\n", + "\n", + "# # Function with two arguments\n", + "# def get_time_difference(timezone1: str, timezone2: str) -> str:\n", + "# \"\"\"Calculates the time difference between two timezones.\n", + "# :param timezone1: The first timezone.\n", + "# :param timezone2: The second timezone.\n", + "# :type timezone1: str\n", + "# :type timezone2: str\n", + "# :return: The time difference in hours between the two timezones.\"\"\"\n", + "# now = datetime.now()\n", + "# tz1 = now.astimezone(ZoneInfo(timezone1))\n", + "# tz2 = now.astimezone(ZoneInfo(timezone2))\n", + "# difference = tz2 - tz1\n", + "# total_seconds = difference.total_seconds()\n", + "# hours_difference = total_seconds / 3600\n", + "# return f\"The time difference between {timezone1} and {timezone2} is {hours_difference} hours.\"\n", + "\n", + "\n", + "\n", + "def get_time_difference(timezone1: str, timezone2: str) -> str:\n", + " \"\"\"Calculates the time difference between two timezones.\n", + " :param timezone1: The first timezone.\n", + " :param timezone2: The second timezone.\n", + " :type timezone1: str\n", + " :type timezone2: str\n", + " :return: The time difference in hours between the two timezones.\"\"\"\n", + " # Get the current time in UTC\n", + " now_utc = datetime.utcnow().replace(tzinfo=ZoneInfo('UTC'))\n", + " \n", + " # Convert the UTC time to the specified timezones\n", + " tz1_time = now_utc.astimezone(ZoneInfo(timezone1))\n", + " tz2_time = now_utc.astimezone(ZoneInfo(timezone2))\n", + " \n", + " # Calculate the difference in offsets from UTC\n", + " tz1_offset = tz1_time.utcoffset().total_seconds()\n", + " tz2_offset = tz2_time.utcoffset().total_seconds()\n", + " \n", + " # Calculate the difference in hours\n", + " hours_difference = (tz2_offset - tz1_offset) / 3600\n", + " \n", + " return f\"The time difference between {timezone1} and {timezone2} is {hours_difference} hours.\"\n", + "\n", + "# Function with three arguments\n", + "def convert_time(time: str, from_timezone: str, to_timezone: str) -> str:\n", + " \"\"\"Converts a specific time from one timezone to another.\n", + " :param time: The time to convert in HH:MM format.\n", + " :param from_timezone: The original timezone of the time, should be a valid IANA timezone.\n", + " :param to_timezone: The target timezone for the time, should be a valid IANA timezone.\n", + " :type time: str\n", + " :type from_timezone: str\n", + " :type to_timezone: str\n", + " :return: The converted time in the target timezone.\n", + " :raises ValueError: If the time format or timezone strings are invalid.\n", + " \n", + " Example:\n", + " convert_time(\"12:30\", \"America/New_York\", \"Asia/Tokyo\") -> \"03:30\"\n", + " \"\"\"\n", + " try:\n", + " print(f\"Attempting to parse the time '{time}' and apply timezone '{from_timezone}'\")\n", + " # Use today's date to avoid historical timezone issues\n", + " today = datetime.now().date()\n", + " datetime_string = f\"{today} {time}\"\n", + " time_obj = datetime.strptime(datetime_string, \"%Y-%m-%d %H:%M\").replace(tzinfo=ZoneInfo(from_timezone))\n", + " print(f\"Time parsed successfully: {time_obj}\")\n", + " \n", + " print(f\"Converting time from '{from_timezone}' to '{to_timezone}'\")\n", + " converted_time = time_obj.astimezone(ZoneInfo(to_timezone))\n", + " print(f\"Time conversion successful: {converted_time}\")\n", + " \n", + " formatted_time = converted_time.strftime(\"%H:%M\")\n", + " print(f\"Formatted converted time: {formatted_time}\")\n", + " return formatted_time\n", + " except Exception as e:\n", + " print(f\"Error encountered: {e}\")\n", + " raise ValueError(f\"Error converting time: {e}\")\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "functions = [get_time, get_time_difference, convert_time]" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'type': 'function',\n", + " 'function': {'name': 'get_time',\n", + " 'description': 'Finds the current time in a specific timezone.\\n\\n:param timezone: The timezone to find the current time in, should\\n be a valid timezone from the IANA Time Zone Database like\\n \"America/New_York\" or \"Europe/London\". Do NOT put the place\\n name itself like \"rome\", or \"new york\", you must provide\\n the IANA format.\\n:type timezone: str\\n:return: The current time in the specified timezone.',\n", + " 'parameters': {'type': 'object',\n", + " 'properties': {'timezone': {'type': 'string',\n", + " 'description': 'The timezone to find the current time in, should\\n be a valid timezone from the IANA Time Zone Database like\\n \"America/New_York\" or \"Europe/London\". Do NOT put the place\\n name itself like \"rome\", or \"new york\", you must provide\\n the IANA format.'}},\n", + " 'required': ['timezone']}}},\n", + " {'type': 'function',\n", + " 'function': {'name': 'get_time_difference',\n", + " 'description': 'Calculates the time difference between two timezones.\\n:param timezone1: The first timezone.\\n:param timezone2: The second timezone.\\n:type timezone1: str\\n:type timezone2: str\\n:return: The time difference in hours between the two timezones.',\n", + " 'parameters': {'type': 'object',\n", + " 'properties': {'timezone1': {'type': 'string',\n", + " 'description': 'The first timezone.'},\n", + " 'timezone2': {'type': 'string', 'description': 'The second timezone.'}},\n", + " 'required': ['timezone1', 'timezone2']}}},\n", + " {'type': 'function',\n", + " 'function': {'name': 'convert_time',\n", + " 'description': 'Converts a specific time from one timezone to another.\\n:param time: The time to convert in HH:MM format.\\n:param from_timezone: The original timezone of the time, should be a valid IANA timezone.\\n:param to_timezone: The target timezone for the time, should be a valid IANA timezone.\\n:type time: str\\n:type from_timezone: str\\n:type to_timezone: str\\n:return: The converted time in the target timezone.\\n:raises ValueError: If the time format or timezone strings are invalid.\\n\\nExample:\\n convert_time(\"12:30\", \"America/New_York\", \"Asia/Tokyo\") -> \"03:30\"',\n", + " 'parameters': {'type': 'object',\n", + " 'properties': {'time': {'type': 'string',\n", + " 'description': 'The time to convert in HH:MM format.'},\n", + " 'from_timezone': {'type': 'string',\n", + " 'description': 'The original timezone of the time, should be a valid IANA timezone.'},\n", + " 'to_timezone': {'type': 'string',\n", + " 'description': 'The target timezone for the time, should be a valid IANA timezone.'}},\n", + " 'required': ['time', 'from_timezone', 'to_timezone']}}}]" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Generate schemas for all functions\n", + "from semantic_router.llms.openai import get_schemas_openai\n", + "schemas = get_schemas_openai(functions)\n", + "schemas\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the dynamic route with multiple functions\n", + "multi_function_route = Route(\n", + " name=\"timezone_management\",\n", + " utterances=[\n", + " # Utterances for get_time function\n", + " \"what is the time in New York?\",\n", + " \"current time in Berlin?\",\n", + " \"tell me the time in Moscow right now\",\n", + " \"can you show me the current time in Tokyo?\",\n", + " \"please provide the current time in London\",\n", + "\n", + " # Utterances for get_time_difference function\n", + " \"how many hours ahead is Tokyo from London?\",\n", + " \"time difference between Sydney and Cairo\",\n", + " \"what's the time gap between Los Angeles and New York?\",\n", + " \"how much time difference is there between Paris and Sydney?\",\n", + " \"calculate the time difference between Dubai and Toronto\",\n", + "\n", + " # Utterances for convert_time function\n", + " \"convert 15:00 from New York time to Berlin time\",\n", + " \"change 09:00 from Paris time to Moscow time\",\n", + " \"adjust 20:00 from Rome time to London time\",\n", + " \"convert 12:00 from Madrid time to Chicago time\",\n", + " \"change 18:00 from Beijing time to Los Angeles time\"\n", + " ],\n", + " function_schemas=schemas\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "routes = [politics, chitchat, multi_function_route]" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger local\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 1 length: 34\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 1 trunc length: 34\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 2 length: 51\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 2 trunc length: 51\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 3 length: 66\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 3 trunc length: 66\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 4 length: 38\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 4 trunc length: 38\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 5 length: 27\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 5 trunc length: 27\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 6 length: 24\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 6 trunc length: 24\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 7 length: 21\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 7 trunc length: 21\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 8 length: 20\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 8 trunc length: 20\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 9 length: 25\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 9 trunc length: 25\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 10 length: 22\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 10 trunc length: 22\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 11 length: 29\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 11 trunc length: 29\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 12 length: 23\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 12 trunc length: 23\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 13 length: 36\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 13 trunc length: 36\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 14 length: 42\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 14 trunc length: 42\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 15 length: 41\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 15 trunc length: 41\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 16 length: 42\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 16 trunc length: 42\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 17 length: 40\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 17 trunc length: 40\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 18 length: 53\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 18 trunc length: 53\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 19 length: 59\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 19 trunc length: 59\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 20 length: 55\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 20 trunc length: 55\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 21 length: 47\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 21 trunc length: 47\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 22 length: 43\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 22 trunc length: 43\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 23 length: 42\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 23 trunc length: 42\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 24 length: 46\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 24 trunc length: 46\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 25 length: 50\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:16 INFO semantic_router.utils.logger Document 25 trunc length: 50\u001b[0m\n" + ] + } + ], + "source": [ + "rl2 = RouteLayer(encoder=encoder, routes=routes)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Function to Parse Route Layer Responses" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "def parse_response(response: str):\n", + "\n", + " for call in response.function_call:\n", + " args = json.loads(call['arguments'])\n", + " if call['function_name'] == 'get_time':\n", + " result = get_time(**args)\n", + " print(result)\n", + " if call['function_name'] == 'get_time_difference':\n", + " result = get_time_difference(**args)\n", + " print(result)\n", + " if call['function_name'] == 'convert_time':\n", + " result = convert_time(**args)\n", + " print(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Checking that Politics Non-Dynamic Route Still Works" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-05-01 23:07:17 INFO semantic_router.utils.logger Document 1 length: 31\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:17 INFO semantic_router.utils.logger Document 1 trunc length: 31\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "RouteChoice(name='politics', function_call=None, similarity_score=None)" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "response = rl2(\"What is your political leaning?\")\n", + "response" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Checking that Chitchat Non-Dynamic Route Still Works" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-05-01 23:07:17 INFO semantic_router.utils.logger Document 1 length: 29\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:17 INFO semantic_router.utils.logger Document 1 trunc length: 29\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "RouteChoice(name='chitchat', function_call=None, similarity_score=None)" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "response = rl2(\"Hello bot, how are you today?\")\n", + "response" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Testing the `multi_function_route` - The `get_time` Function" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-05-01 23:07:18 INFO semantic_router.utils.logger Document 1 length: 29\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:18 INFO semantic_router.utils.logger Document 1 trunc length: 29\u001b[0m\n", + "\u001b[33m2024-05-01 23:07:18 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "RouteChoice(name='timezone_management', function_call=[{'function_name': 'get_time', 'arguments': '{\"timezone\":\"America/New_York\"}'}], similarity_score=None)" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "response = rl2(\"what is the time in New York?\")\n", + "response" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "15:07\n" + ] + } + ], + "source": [ + "parse_response(response)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Testing the `multi_function_route` - The `get_time_difference` Function" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-05-01 23:07:19 INFO semantic_router.utils.logger Document 1 length: 61\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:19 INFO semantic_router.utils.logger Document 1 trunc length: 61\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "RouteChoice(name='timezone_management', function_call=[{'function_name': 'get_time_difference', 'arguments': '{\"timezone1\":\"America/Los_Angeles\",\"timezone2\":\"Europe/Istanbul\"}'}], similarity_score=None)" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "response = rl2(\"What is the time difference between Los Angeles and Istanbul?\")\n", + "response" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The time difference between America/Los_Angeles and Europe/Istanbul is 10.0 hours.\n" + ] + } + ], + "source": [ + "parse_response(response)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Testing the `multi_function_route` - The `convert_time` Function" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-05-01 23:07:21 INFO semantic_router.utils.logger Document 1 length: 72\u001b[0m\n", + "\u001b[32m2024-05-01 23:07:21 INFO semantic_router.utils.logger Document 1 trunc length: 72\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "RouteChoice(name='timezone_management', function_call=[{'function_name': 'convert_time', 'arguments': '{\"time\":\"23:02\",\"from_timezone\":\"Asia/Dubai\",\"to_timezone\":\"Asia/Tokyo\"}'}], similarity_score=None)" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "response = rl2(\"What is 23:02 Dubai time in Tokyo time. Please and thank you.\")\n", + "response" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Attempting to parse the time '23:02' and apply timezone 'Asia/Dubai'\n", + "Time parsed successfully: 1900-01-01 23:02:00+03:41:12\n", + "Converting time from 'Asia/Dubai' to 'Asia/Tokyo'\n", + "Time conversion successful: 1900-01-02 04:20:48+09:00\n", + "Formatted converted time: 04:20\n", + "04:20\n" + ] + } + ], + "source": [ + "parse_response(response)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py index db5da001..c3cb347b 100644 --- a/semantic_router/llms/openai.py +++ b/semantic_router/llms/openai.py @@ -8,7 +8,10 @@ from semantic_router.schema import Message from semantic_router.utils.defaults import EncoderDefault from semantic_router.utils.logger import logger import json - +from semantic_router.utils.function_call import get_schema, convert_python_type_to_json_type +import inspect +from typing import Callable, Dict +import re class OpenAILLM(BaseLLM): client: Optional[openai.OpenAI] @@ -104,3 +107,52 @@ class OpenAILLM(BaseLLM): function_inputs_str = self(messages=messages, function_schemas=function_schemas) function_inputs = json.loads(function_inputs_str) return function_inputs + +def get_schemas_openai(items: List[Callable]) -> List[Dict[str, Any]]: + schemas = [] + for item in items: + if not callable(item): + raise ValueError("Provided item must be a callable function.") + + # Use the existing get_schema function to get the basic schema + basic_schema = get_schema(item) + + # Initialize the function schema with basic details + function_schema = { + "name": basic_schema['name'], + "description": basic_schema['description'], + "parameters": {"type": "object", "properties": {}, "required": []} + } + + # Extract parameter details from the signature + signature = inspect.signature(item) + docstring = inspect.getdoc(item) + param_doc_regex = re.compile(r":param (\w+):(.*?)\n(?=:\w|$)", re.S) + doc_params = param_doc_regex.findall(docstring) if docstring else [] + + for param_name, param in signature.parameters.items(): + param_type = param.annotation.__name__ if param.annotation != inspect.Parameter.empty else "Any" + param_description = "No description available." + param_required = param.default is inspect.Parameter.empty + + # Find the parameter description in the docstring + for doc_param_name, doc_param_desc in doc_params: + if doc_param_name == param_name: + param_description = doc_param_desc.strip() + break + + function_schema["parameters"]["properties"][param_name] = { + "type": convert_python_type_to_json_type(param_type), + "description": param_description + } + + if param_required: + function_schema["parameters"]["required"].append(param_name) + + schemas.append({ + "type": "function", + "function": function_schema + }) + + return schemas + diff --git a/semantic_router/utils/function_call.py b/semantic_router/utils/function_call.py index e3700e12..57e4c467 100644 --- a/semantic_router/utils/function_call.py +++ b/semantic_router/utils/function_call.py @@ -9,42 +9,46 @@ from semantic_router.utils.logger import logger import re -def get_schemas(items: List[Union[BaseModel, Callable]]) -> List[Dict[str, Any]]: +def get_schema_list(items: List[Union[BaseModel, Callable]]) -> List[Dict[str, Any]]: schemas = [] for item in items: - if isinstance(item, BaseModel): - signature_parts = [] - for field_name, field_model in item.__annotations__.items(): - field_info = item.__fields__[field_name] - default_value = field_info.default - - if default_value: - default_repr = repr(default_value) - signature_part = ( - f"{field_name}: {field_model.__name__} = {default_repr}" - ) - else: - signature_part = f"{field_name}: {field_model.__name__}" - - signature_parts.append(signature_part) - signature = f"({', '.join(signature_parts)}) -> str" - schema = { - "name": item.__class__.__name__, - "description": item.__doc__, - "signature": signature, - } - else: - schema = { - "name": item.__name__, - "description": str(inspect.getdoc(item)), - "signature": str(inspect.signature(item)), - "output": str(inspect.signature(item).return_annotation), - } + schema = get_schema(item) schemas.append(schema) return schemas +def get_schema(item: Union[BaseModel, Callable]) -> Dict[str, Any]: + if isinstance(item, BaseModel): + signature_parts = [] + for field_name, field_model in item.__annotations__.items(): + field_info = item.__fields__[field_name] + default_value = field_info.default + + if default_value: + default_repr = repr(default_value) + signature_part = ( + f"{field_name}: {field_model.__name__} = {default_repr}" + ) + else: + signature_part = f"{field_name}: {field_model.__name__}" + + signature_parts.append(signature_part) + signature = f"({', '.join(signature_parts)}) -> str" + schema = { + "name": item.__class__.__name__, + "description": item.__doc__, + "signature": signature, + } + else: + schema = { + "name": item.__name__, + "description": str(inspect.getdoc(item)), + "signature": str(inspect.signature(item)), + "output": str(inspect.signature(item).return_annotation), + } + return schema + -def convert_param_type_to_json_type(param_type: str) -> str: +def convert_python_type_to_json_type(param_type: str) -> str: if param_type == "int": return "number" if param_type == "float": @@ -61,52 +65,6 @@ def convert_param_type_to_json_type(param_type: str) -> str: return "object" -def get_schemas_openai(items: List[Callable]) -> List[Dict[str, Any]]: - schemas = [] - for item in items: - if not callable(item): - raise ValueError("Provided item must be a callable function.") - - docstring = inspect.getdoc(item) - signature = inspect.signature(item) - - schema = { - "type": "function", - "function": { - "name": item.__name__, - "description": docstring if docstring else "No description available.", - "parameters": {"type": "object", "properties": {}, "required": []}, - }, - } - - for param_name, param in signature.parameters.items(): - param_type = ( - param.annotation.__name__ - if param.annotation != inspect.Parameter.empty - else "Any" - ) - param_description = "No description available." - param_required = param.default is inspect.Parameter.empty - - # Attempt to extract the parameter description from the docstring - if docstring: - param_doc_regex = re.compile(rf":param {param_name}:(.*?)\n(?=:\w|$)", re.S) - match = param_doc_regex.search(docstring) - if match: - param_description = match.group(1).strip() - - schema["function"]["parameters"]["properties"][param_name] = { - "type": convert_param_type_to_json_type(param_type), - "description": param_description, - } - - if param_required: - schema["function"]["parameters"]["required"].append(param_name) - - schemas.append(schema) - - return schemas - # TODO: Add route layer object to the input, solve circular import issue async def route_and_execute( -- GitLab