diff --git a/docs/02-dynamic-routes.ipynb b/docs/02-dynamic-routes.ipynb index 7cb18bd26d6f851e20528a524b3f353708c53db2..40f198e1cc70f769cadfae761c03b7f1a2cffc26 100644 --- a/docs/02-dynamic-routes.ipynb +++ b/docs/02-dynamic-routes.ipynb @@ -185,27 +185,27 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-04-30 03:58:34 INFO semantic_router.utils.logger local\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:34 INFO semantic_router.utils.logger Document 1 length: 34\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:34 INFO semantic_router.utils.logger Document 1 trunc length: 34\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:34 INFO semantic_router.utils.logger Document 2 length: 51\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:34 INFO semantic_router.utils.logger Document 2 trunc length: 51\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:34 INFO semantic_router.utils.logger Document 3 length: 66\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:34 INFO semantic_router.utils.logger Document 3 trunc length: 66\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:34 INFO semantic_router.utils.logger Document 4 length: 38\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:34 INFO semantic_router.utils.logger Document 4 trunc length: 38\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:34 INFO semantic_router.utils.logger Document 5 length: 27\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:34 INFO semantic_router.utils.logger Document 5 trunc length: 27\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:34 INFO semantic_router.utils.logger Document 6 length: 24\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:34 INFO semantic_router.utils.logger Document 6 trunc length: 24\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:34 INFO semantic_router.utils.logger Document 7 length: 21\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:34 INFO semantic_router.utils.logger Document 7 trunc length: 21\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:34 INFO semantic_router.utils.logger Document 8 length: 20\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:34 INFO semantic_router.utils.logger Document 8 trunc length: 20\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:34 INFO semantic_router.utils.logger Document 9 length: 25\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:34 INFO semantic_router.utils.logger Document 9 trunc length: 25\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:34 INFO semantic_router.utils.logger Document 10 length: 22\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:34 INFO semantic_router.utils.logger Document 10 trunc length: 22\u001b[0m\n" + "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger local\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 1 length: 34\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 1 trunc length: 34\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 2 length: 51\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 2 trunc length: 51\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 3 length: 66\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 3 trunc length: 66\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 4 length: 38\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 4 trunc length: 38\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 5 length: 27\u001b[0m\n", + "\u001b[32m2024-05-01 
01:25:54 INFO semantic_router.utils.logger Document 5 trunc length: 27\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 6 length: 24\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 6 trunc length: 24\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 7 length: 21\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 7 trunc length: 21\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 8 length: 20\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 8 trunc length: 20\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 9 length: 25\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 9 trunc length: 25\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 10 length: 22\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:54 INFO semantic_router.utils.logger Document 10 trunc length: 22\u001b[0m\n" ] } ], @@ -254,8 +254,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-04-30 03:58:35 INFO semantic_router.utils.logger Document 1 length: 24\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:35 INFO semantic_router.utils.logger Document 1 trunc length: 24\u001b[0m\n" + "\u001b[32m2024-05-01 01:25:55 INFO semantic_router.utils.logger Document 1 length: 24\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:55 INFO semantic_router.utils.logger Document 1 trunc length: 24\u001b[0m\n" ] }, { @@ -288,7 +288,7 @@ "id": "ANAoEjxYur0y" }, "source": [ - "As with static routes, we must create a dynamic route before adding it to our route layer. To make a route dynamic, we need to provide a `function_schema`. The function schema provides instructions on what a function is, so that an LLM can decide how to use it correctly." + "As with static routes, we must create a dynamic route before adding it to our route layer. To make a route dynamic, we need to provide the `function_schemas` as a list. Each function schema provides instructions on what a function is, so that an LLM can decide how to use it correctly." ] }, { @@ -332,7 +332,7 @@ { "data": { "text/plain": [ - "'19:58'" + "'17:25'" ] }, "execution_count": 6, @@ -367,13 +367,13 @@ { "data": { "text/plain": [ - "{'type': 'function',\n", - " 'function': {'name': 'get_time',\n", - " 'description': 'Finds the current time in a specific timezone.\\n\\n:param timezone: The timezone to find the current time in, should\\n be a valid timezone from the IANA Time Zone Database like\\n \"America/New_York\" or \"Europe/London\". Do NOT put the place\\n name itself like \"rome\", or \"new york\", you must provide\\n the IANA format.\\n:type timezone: str\\n:return: The current time in the specified timezone.',\n", - " 'parameters': {'type': 'object',\n", - " 'properties': {'timezone': {'type': 'string',\n", - " 'description': 'The timezone to find the current time in, should\\n be a valid timezone from the IANA Time Zone Database like\\n \"America/New_York\" or \"Europe/London\". 
Do NOT put the place\\n name itself like \"rome\", or \"new york\", you must provide\\n the IANA format.'}},\n", - " 'required': ['timezone']}}}" + "[{'type': 'function',\n", + " 'function': {'name': 'get_time',\n", + " 'description': 'Finds the current time in a specific timezone.\\n\\n:param timezone: The timezone to find the current time in, should\\n be a valid timezone from the IANA Time Zone Database like\\n \"America/New_York\" or \"Europe/London\". Do NOT put the place\\n name itself like \"rome\", or \"new york\", you must provide\\n the IANA format.\\n:type timezone: str\\n:return: The current time in the specified timezone.',\n", + " 'parameters': {'type': 'object',\n", + " 'properties': {'timezone': {'type': 'string',\n", + " 'description': 'The timezone to find the current time in, should\\n be a valid timezone from the IANA Time Zone Database like\\n \"America/New_York\" or \"Europe/London\". Do NOT put the place\\n name itself like \"rome\", or \"new york\", you must provide\\n the IANA format.'}},\n", + " 'required': ['timezone']}}}]" ] }, "execution_count": 7, @@ -382,9 +382,9 @@ } ], "source": [ - "from semantic_router.utils.function_call import get_schema_openai\n", + "from semantic_router.utils.function_call import get_schemas_openai\n", "\n", - "schema = get_schema_openai(get_time)\n", + "schema = get_schemas_openai([get_time])\n", "schema" ] }, @@ -412,7 +412,7 @@ " \"what is the time in london?\",\n", " \"I live in Rome, what time is it?\",\n", " ],\n", - " function_schema=schema,\n", + " function_schemas=schema,\n", ")" ] }, @@ -449,13 +449,13 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-04-30 03:58:35 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:35 INFO semantic_router.utils.logger Document 1 length: 34\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:35 INFO semantic_router.utils.logger Document 1 trunc length: 34\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:35 INFO semantic_router.utils.logger Document 2 length: 27\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:35 INFO semantic_router.utils.logger Document 2 trunc length: 27\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:35 INFO semantic_router.utils.logger Document 3 length: 32\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:35 INFO semantic_router.utils.logger Document 3 trunc length: 32\u001b[0m\n" + "\u001b[32m2024-05-01 01:25:56 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:56 INFO semantic_router.utils.logger Document 1 length: 34\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:56 INFO semantic_router.utils.logger Document 1 trunc length: 34\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:56 INFO semantic_router.utils.logger Document 2 length: 27\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:56 INFO semantic_router.utils.logger Document 2 trunc length: 27\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:56 INFO semantic_router.utils.logger Document 3 length: 32\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:56 INFO semantic_router.utils.logger Document 3 trunc length: 32\u001b[0m\n" ] } ], @@ -497,15 +497,15 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-04-30 03:58:36 INFO semantic_router.utils.logger Document 1 length: 34\u001b[0m\n", - "\u001b[32m2024-04-30 03:58:36 INFO semantic_router.utils.logger Document 1 trunc length: 34\u001b[0m\n", - "\u001b[33m2024-04-30 03:58:36 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. 
Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n" + "\u001b[32m2024-05-01 01:25:56 INFO semantic_router.utils.logger Document 1 length: 34\u001b[0m\n", + "\u001b[32m2024-05-01 01:25:56 INFO semantic_router.utils.logger Document 1 trunc length: 34\u001b[0m\n", + "\u001b[33m2024-05-01 01:25:57 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n" ] }, { "data": { "text/plain": [ - "RouteChoice(name='get_time', function_call={'timezone': 'America/New_York'}, similarity_score=None)" + "RouteChoice(name='get_time', function_call=[{'function_name': 'get_time', 'arguments': '{\"timezone\":\"America/New_York\"}'}], similarity_score=None)" ] }, "execution_count": 12, @@ -524,18 +524,38 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "'19:58'" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "[{'function_name': 'get_time', 'arguments': '{\"timezone\":\"America/New_York\"}'}]\n" + ] } ], "source": [ - "get_time(**out.function_call)" + "print(out.function_call)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "17:25\n" + ] + } + ], + "source": [ + "import json\n", + "\n", + "for call in out.function_call:\n", + " if call['function_name'] == 'get_time':\n", + " args = json.loads(call['arguments'])\n", + " result = get_time(**args)\n", + "print(result)" ] }, { diff --git a/docs/05-local-execution.ipynb b/docs/05-local-execution.ipynb index 24f78c94101603fc9b45e27533df34294cd64ccf..c936c7773de0ba347f4bea444203520e409c8694 100644 --- a/docs/05-local-execution.ipynb +++ b/docs/05-local-execution.ipynb @@ -130,7 +130,7 @@ " \"what is the time in london?\",\n", " \"I live in Rome, what time is it?\",\n", " ],\n", - " function_schema=time_schema,\n", + " function_schemas=[time_schema],\n", ")\n", "\n", "politics = Route(\n", @@ -701,7 +701,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.11.4" } }, "nbformat": 4, diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb index 5392e54e537cadeb74ec3554c37d8f76c9e4a96c..1aae719728c3e5c594bea3986b9c41e821dd17fa 100644 --- a/docs/examples/function_calling.ipynb +++ b/docs/examples/function_calling.ipynb @@ -138,7 +138,7 @@ " \"what is the current temperature in London?\",\n", " \"tomorrow's weather in Paris?\",\n", " ],\n", - " function_schema=None,\n", + " function_schemas=None,\n", ")\n", "routes.append(get_weather_route)" ] diff --git a/docs/examples/ollama-local-execution.ipynb b/docs/examples/ollama-local-execution.ipynb index 3f326d54210a24571ac51eaf8211f555b8b77cfb..2ed7acdefb0927cb2dc74d4fab32887bb8f0e2b7 100644 --- a/docs/examples/ollama-local-execution.ipynb +++ b/docs/examples/ollama-local-execution.ipynb @@ -370,7 +370,7 @@ " \"what is the time in london?\",\n", " \"I live in Rome, what time is it?\",\n", " ],\n", - " function_schema=schema,\n", + " function_schemas=[schema],\n", ")" ] }, diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 02e626f4979b945258b0ef3223e547e05c0218bd..b7905cca30d69b57638baa9ecb93ec79504a43ca 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -244,11 +244,11 @@ class RouteLayer: passed = 
self._check_threshold(top_class_scores, route) if passed and route is not None and not simulate_static: - if route.function_schema and text is None: + if route.function_schemas and text is None: raise ValueError( "Route has a function schema, but no text was provided." ) - if route.function_schema and not isinstance(route.llm, BaseLLM): + if route.function_schemas and not isinstance(route.llm, BaseLLM): if not self.llm: logger.warning( "No LLM provided for dynamic route, will use OpenAI LLM " diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py index 963c754a1c49ec4c9d2308562dd6f6f17e090090..3d1d9b3737c80a3efa702a979de740c9d023ea4e 100644 --- a/semantic_router/llms/base.py +++ b/semantic_router/llms/base.py @@ -20,81 +20,139 @@ class BaseLLM(BaseModel): raise NotImplementedError("Subclasses must implement this method") def _is_valid_inputs( - self, inputs: dict[str, Any], function_schema: dict[str, Any] + self, inputs: list[dict[str, Any]], function_schemas: list[dict[str, Any]] ) -> bool: - """Validate the extracted inputs against the function schema""" + """Determine if the functions chosen by the LLM exist within the function_schemas, + and if the input arguments are valid for those functions.""" try: - # Extract parameter names and types from the signature string - signature = function_schema["signature"] - param_info = [param.strip() for param in signature[1:-1].split(",")] - param_names = [info.split(":")[0].strip() for info in param_info] - param_types = [ - info.split(":")[1].strip().split("=")[0].strip() for info in param_info - ] - for name, type_str in zip(param_names, param_types): - if name not in inputs: - logger.error(f"Input {name} missing from query") + for input_dict in inputs: + # Check if 'function_name' and 'arguments' keys exist in each input dictionary + if "function_name" not in input_dict or "arguments" not in input_dict: + logger.error("Missing 'function_name' or 'arguments' in inputs") return False + + function_name = input_dict["function_name"] + arguments = input_dict["arguments"] + + # Find the matching function schema based on function_name + matching_schema = next((schema for schema in function_schemas if schema["name"] == function_name), None) + if not matching_schema: + logger.error(f"No matching function schema found for function name: {function_name}") + return False + + # Extract parameter names and types from the signature string of the matching schema + param_names, param_types = self._extract_parameter_info(matching_schema["signature"]) + + # Validate that all required parameters are present in the arguments + for name, type_str in zip(param_names, param_types): + if name not in arguments: + logger.error(f"Input {name} missing from arguments") + return False + return True except Exception as e: logger.error(f"Input validation error: {str(e)}") return False + def _extract_parameter_info(self, signature: str) -> tuple[list[str], list[str]]: + """Extract parameter names and types from the function signature.""" + param_info = [param.strip() for param in signature[1:-1].split(",")] + param_names = [info.split(":")[0].strip() for info in param_info] + param_types = [ + info.split(":")[1].strip().split("=")[0].strip() for info in param_info + ] + return param_names, param_types + def extract_function_inputs( - self, query: str, function_schema: dict[str, Any] + self, query: str, function_schemas: list[dict[str, Any]] ) -> dict: logger.info("Extracting function input...") prompt = f""" You are an accurate and reliable computer program 
that only outputs valid JSON.
-Your task is to output JSON representing the input arguments of a Python function.
+Your task is to:
+    1) Pick the most relevant Python function schema(s) from FUNCTION_SCHEMAS below, based on the input QUERY. If only one schema is provided, choose that. If multiple schemas are relevant, output a list of JSON objects for each.
+    2) Output JSON representing the input arguments of the chosen function schema(s), including the function name, with argument values determined by information in the QUERY.
 
-This is the Python function's schema:
+These are the Python functions' schemas:
 
-### FUNCTION_SCHEMA Start ###
-	{function_schema}
-### FUNCTION_SCHEMA End ###
+### FUNCTION_SCHEMAS Start ###
+	{json.dumps(function_schemas, indent=4)}
+### FUNCTION_SCHEMAS End ###
 
 This is the input query.
 
 ### QUERY Start ###
-	{query}
+	{query}
 ### QUERY End ###
 
-The arguments that you need to provide values for, together with their datatypes, are stated in "signature" in the FUNCTION_SCHEMA.
-The values these arguments must take are made clear by the QUERY. Use the FUNCTION_SCHEMA "description" too, as this might provide helpful clues about the arguments and their values.
+The arguments that you need to provide values for, together with their datatypes, are stated in "signature" in the FUNCTION_SCHEMAS.
+The values these arguments must take are made clear by the QUERY. Use the FUNCTION_SCHEMAS "description" too, as this might provide helpful clues about the arguments and their values.
 
-Return only JSON, stating the argument names and their corresponding values.
+Include the function name in your JSON output.
+Return only JSON, stating the function name and the argument names with their corresponding values.
 
 ### FORMATTING_INSTRUCTIONS Start ###
-	Return a respones in valid JSON format. Do not return any other explanation or text, just the JSON.
-	The JSON-Keys are the names of the arguments, and JSON-values are the values those arguments should take.
+	Return a response in valid JSON format. Do not return any other explanation or text, just the JSON.
+	The JSON output should include a key 'function_name' with the value being the name of the function.
+	Under the key 'arguments', include a nested JSON object where the keys are the names of the arguments and the values are the values those arguments should take.
+	If multiple function schemas are relevant, return a list of JSON objects.
 ### FORMATTING_INSTRUCTIONS End ###
 
 ### EXAMPLE Start ###
-	=== EXAMPLE_INPUT_QUERY Start ===
-	"How is the weather in Hawaii right now in International units?"
-	=== EXAMPLE_INPUT_QUERY End ===
-	=== EXAMPLE_INPUT_SCHEMA Start ===
-	{{
-        "name": "get_weather",
-        "description": "Useful to get the weather in a specific location",
-        "signature": "(location: str, degree: str) -> str",
-        "output": "<class 'str'>",
-    }}
-	=== EXAMPLE_INPUT_QUERY End ===
-	=== EXAMPLE_OUTPUT Start ===
-	{{
-        "location": "Hawaii",
-        "degree": "Celsius",
-    }}
-	=== EXAMPLE_OUTPUT End ===
+                === EXAMPLE_INPUT_QUERY Start ===
+                "What is the temperature in Hawaii and New York right now in Celsius, and what is the humidity in Hawaii?"
+                === EXAMPLE_INPUT_QUERY End ===
+                === EXAMPLE_INPUT_SCHEMA Start ===
+                {{
+                    "name": "get_temperature",
+                    "description": "Useful to get the temperature in a specific location",
+                    "signature": "(location: str, degree: str) -> str",
+                    "output": "<class 'str'>",
+                }}
+                {{
+                    "name": "get_humidity",
+                    "description": "Useful to get the humidity level in a specific location",
+                    "signature": "(location: str) -> int",
+                    "output": "<class 'int'>",
+                }}
+                {{
+                    "name": "get_wind_speed",
+                    "description": "Useful to get the wind speed in a specific location",
+                    "signature": "(location: str) -> float",
+                    "output": "<class 'float'>",
+                }}
+                === EXAMPLE_INPUT_SCHEMA End ===
+                === EXAMPLE_OUTPUT Start ===
+                [
+                    {{
+                        "function_name": "get_temperature",
+                        "arguments": {{
+                            "location": "Hawaii",
+                            "degree": "Celsius"
+                        }}
+                    }},
+                    {{
+                        "function_name": "get_temperature",
+                        "arguments": {{
+                            "location": "New York",
+                            "degree": "Celsius"
+                        }}
+                    }},
+                    {{
+                        "function_name": "get_humidity",
+                        "arguments": {{
+                            "location": "Hawaii"
+                        }}
+                    }}
+                ]
+                === EXAMPLE_OUTPUT End ===
 ### EXAMPLE End ###
 
-Note: I will tip $500 for and accurate JSON output. You will be penalized for an inaccurate JSON output.
+Note: I will tip $500 for an accurate JSON output. You will be penalized for an inaccurate JSON output.
 
 Provide JSON output now:
-"""
+    """
         llm_input = [Message(role="user", content=prompt)]
         output = self(llm_input)
@@ -105,6 +163,6 @@ Provide JSON output now:
         logger.info(f"LLM output: {output}")
         function_inputs = json.loads(output)
         logger.info(f"Function inputs: {function_inputs}")
-        if not self._is_valid_inputs(function_inputs, function_schema):
+        if not self._is_valid_inputs(function_inputs, function_schemas):
             raise ValueError("Invalid inputs")
         return function_inputs
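The updated `extract_function_inputs` prompt asks the LLM for a list of `{"function_name": ..., "arguments": ...}` objects, and `_is_valid_inputs` then checks each entry against the `signature` of the matching schema. Below is a minimal sketch of that round trip, illustrative only and not part of this diff, with a simplified parser standing in for `_extract_parameter_info`:

```python
# Schemas in the shape produced by semantic_router.utils.function_call.get_schemas.
function_schemas = [
    {
        "name": "get_time",
        "description": "Finds the current time in a specific timezone.",
        "signature": "(timezone: str) -> str",
        "output": "<class 'str'>",
    }
]

# The shape the new prompt asks the LLM to return: one object per tool call.
llm_output = [{"function_name": "get_time", "arguments": {"timezone": "Europe/London"}}]

for call in llm_output:
    schema = next(
        (s for s in function_schemas if s["name"] == call["function_name"]), None
    )
    assert schema is not None, f"unknown function: {call['function_name']}"
    # Reduce "(timezone: str) -> str" to its parameter names, then require
    # every parameter to be present in the call's arguments.
    params = schema["signature"].split("->")[0].strip().strip("()")
    names = [p.split(":")[0].strip() for p in params.split(",") if p.strip()]
    assert all(n in call["arguments"] for n in names), "missing required arguments"
```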
diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py
index 0f561cc4330f0ed2e8f14f443148f0e6bca71e7c..db5da0014fc3eaf49dc1d4520b169c55907f6290 100644
--- a/semantic_router/llms/openai.py
+++ b/semantic_router/llms/openai.py
@@ -37,16 +37,29 @@ class OpenAILLM(BaseLLM):
         self.temperature = temperature
         self.max_tokens = max_tokens
 
+    def _extract_tool_calls_info(self, tool_calls: list[Any]) -> str:
+        tool_calls_info = []
+        for tool_call in tool_calls:
+            if tool_call.function.arguments is None:
+                raise ValueError(
+                    "Invalid output, expected arguments to be specified for each tool call."
+                )
+            tool_calls_info.append({
+                "function_name": tool_call.function.name,
+                "arguments": tool_call.function.arguments
+            })
+        return json.dumps(tool_calls_info)
+
     def __call__(
         self,
         messages: List[Message],
-        function_schema: Optional[dict[str, Any]] = None,
+        function_schemas: Optional[list[dict[str, Any]]] = None,
     ) -> str:
         if self.client is None:
             raise ValueError("OpenAI client is not initialized.")
         try:
-            if function_schema:
-                tools = [function_schema]
+            if function_schemas:
+                tools = function_schemas
             else:
                 tools = None
 
@@ -58,20 +71,17 @@ class OpenAILLM(BaseLLM):
                 tools=tools,  # type: ignore # MyPy expecting Iterable[ChatCompletionToolParam] | NotGiven, but dict is accepted by OpenAI.
             )
 
-            if function_schema:
+            if function_schemas:
                 tool_calls = completion.choices[0].message.tool_calls
                 if tool_calls is None:
                     raise ValueError("Invalid output, expected a tool call.")
-                if len(tool_calls) != 1:
+                if len(tool_calls) < 1:
                     raise ValueError(
-                        "Invalid output, expected a single tool to be specified."
+                        "Invalid output, expected at least one tool to be specified."
                     )
-                arguments = tool_calls[0].function.arguments
-                if arguments is None:
-                    raise ValueError(
-                        "Invalid output, expected arguments to be specified."
-                    )
-                output = str(arguments)  # str to keep MyPy happy.
+
+                # Serialize the name and arguments of every tool call the model made.
+                output = self._extract_tool_calls_info(tool_calls)
             else:
                 content = completion.choices[0].message.content
                 if content is None:
@@ -85,12 +95,12 @@ class OpenAILLM(BaseLLM):
             raise Exception(f"LLM error: {e}") from e
 
     def extract_function_inputs(
-        self, query: str, function_schema: dict[str, Any]
+        self, query: str, function_schemas: list[dict[str, Any]]
     ) -> dict:
         messages = []
         system_prompt = "You are an intelligent AI. Given a command or request from the user, call the function to complete the request."
         messages.append(Message(role="system", content=system_prompt))
         messages.append(Message(role="user", content=query))
-        function_inputs_str = self(messages=messages, function_schema=function_schema)
+        function_inputs_str = self(messages=messages, function_schemas=function_schemas)
         function_inputs = json.loads(function_inputs_str)
         return function_inputs
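`OpenAILLM.__call__` now serializes every tool call the model makes instead of insisting on exactly one. The sketch below mimics what `_extract_tool_calls_info` produces, using `SimpleNamespace` stand-ins for the SDK's `ChatCompletionMessageToolCall` objects (an assumed shape for illustration, not the real client types):

```python
import json
from types import SimpleNamespace

# Two fake tool calls in the attribute layout the code above reads:
# tool_call.function.name and tool_call.function.arguments (a JSON string).
tool_calls = [
    SimpleNamespace(function=SimpleNamespace(
        name="get_time", arguments='{"timezone": "America/New_York"}')),
    SimpleNamespace(function=SimpleNamespace(
        name="get_time", arguments='{"timezone": "Europe/London"}')),
]

tool_calls_info = [
    {"function_name": tc.function.name, "arguments": tc.function.arguments}
    for tc in tool_calls
]
output = json.dumps(tool_calls_info)  # what __call__ returns for dynamic routes
print(json.loads(output)[1]["arguments"])  # '{"timezone": "Europe/London"}'
```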
diff --git a/semantic_router/route.py b/semantic_router/route.py
index 3d46a8b4f4578ce90da8984d60a8a85956341ed3..74209058f9a79e363e963054e4b4e85d0b854708 100644
--- a/semantic_router/route.py
+++ b/semantic_router/route.py
@@ -47,7 +47,7 @@ class Route(BaseModel):
     name: str
     utterances: Union[List[str], List[Union[Any, "Image"]]]
     description: Optional[str] = None
-    function_schema: Optional[Dict[str, Any]] = None
+    function_schemas: Optional[list[Dict[str, Any]]] = None
     llm: Optional[BaseLLM] = None
     score_threshold: Optional[float] = None
 
@@ -55,7 +55,7 @@ class Route(BaseModel):
         arbitrary_types_allowed = True
 
     def __call__(self, query: Optional[str] = None) -> RouteChoice:
-        if self.function_schema:
+        if self.function_schemas:
             if not self.llm:
                 raise ValueError(
                     "LLM is required for dynamic routes. Please ensure the `llm` "
@@ -68,7 +68,7 @@ class Route(BaseModel):
             )
             # if a function schema is provided we generate the inputs
             extracted_inputs = self.llm.extract_function_inputs(
-                query=query, function_schema=self.function_schema
+                query=query, function_schemas=self.function_schemas
             )
             func_call = extracted_inputs
         else:
@@ -92,17 +92,17 @@ class Route(BaseModel):
     @classmethod
     def from_dict(cls, data: Dict[str, Any]):
         return cls(**data)
 
     @classmethod
-    def from_dynamic_route(cls, llm: BaseLLM, entity: Union[BaseModel, Callable]):
+    def from_dynamic_route(cls, llm: BaseLLM, entities: List[Union[BaseModel, Callable]], route_name: str):
         """
-        Generate a dynamic Route object from a function or Pydantic model using LLM
+        Generate a dynamic Route object from a list of functions or Pydantic models using LLM
         """
-        schema = function_call.get_schema(item=entity)
-        dynamic_route = cls._generate_dynamic_route(llm=llm, function_schema=schema)
-        dynamic_route.function_schema = schema
+        schemas = function_call.get_schemas(items=entities)
+        dynamic_route = cls._generate_dynamic_route(llm=llm, function_schemas=schemas, route_name=route_name)
+        dynamic_route.function_schemas = schemas
         return dynamic_route
-    
+
     @classmethod
     def _parse_route_config(cls, config: str) -> str:
         # Regular expression to match content inside <config></config>
@@ -116,16 +116,18 @@ class Route(BaseModel):
         raise ValueError("No <config></config> tags found in the output.")
 
     @classmethod
-    def _generate_dynamic_route(cls, llm: BaseLLM, function_schema: Dict[str, Any]):
+    def _generate_dynamic_route(cls, llm: BaseLLM, function_schemas: List[Dict[str, Any]], route_name: str):
         logger.info("Generating dynamic route...")
 
+        formatted_schemas = "\n".join([json.dumps(schema, indent=4) for schema in function_schemas])
         prompt = f"""
-        You are tasked to generate a JSON configuration based on the provided
-        function schema. Please follow the template below, no other tokens allowed:
+        You are tasked to generate a single JSON configuration for multiple function schemas.
+        Each function schema should contribute five example utterances.
+        Please follow the template below, no other tokens allowed:
 
        <config>
        {{
-            "name": "<function_name>",
+            "name": "{route_name}",
            "utterances": [
                "<example_utterance_1>",
                "<example_utterance_2>",
                ...
            ]
        }}
        </config>
 
        Only include the "name" and "utterances" keys in your answer.
-        The "name" should match the function name and the "utterances"
-        should comprise a list of 5 example phrases that could be used to invoke
-        the function. Use real values instead of placeholders.
+        The "name" should match the provided route name and the "utterances"
+        should comprise a list of 5 example phrases for each function schema that could be used to invoke
+        the functions. Use real values instead of placeholders.
- Input schema: - {function_schema} + Input schemas: + {formatted_schemas} """ llm_input = [Message(role="user", content=prompt)] diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 60f61536e08840e2c023fd12867cc960a14d3608..75aa3b7535f01ab6fcf64a98fcecd3e8249857af 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -25,7 +25,7 @@ class EncoderInfo(BaseModel): class RouteChoice(BaseModel): name: Optional[str] = None - function_call: Optional[dict] = None + function_call: Optional[list[dict]] = None similarity_score: Optional[float] = None diff --git a/semantic_router/utils/function_call.py b/semantic_router/utils/function_call.py index 416191bfaf7cc1f7919cd92b042a60b8806be74e..e3700e12cc902894f4b834863fea9716d26a1a0f 100644 --- a/semantic_router/utils/function_call.py +++ b/semantic_router/utils/function_call.py @@ -9,36 +9,39 @@ from semantic_router.utils.logger import logger import re -def get_schema(item: Union[BaseModel, Callable]) -> Dict[str, Any]: - if isinstance(item, BaseModel): - signature_parts = [] - for field_name, field_model in item.__annotations__.items(): - field_info = item.__fields__[field_name] - default_value = field_info.default - - if default_value: - default_repr = repr(default_value) - signature_part = ( - f"{field_name}: {field_model.__name__} = {default_repr}" - ) - else: - signature_part = f"{field_name}: {field_model.__name__}" - - signature_parts.append(signature_part) - signature = f"({', '.join(signature_parts)}) -> str" - schema = { - "name": item.__class__.__name__, - "description": item.__doc__, - "signature": signature, - } - else: - schema = { - "name": item.__name__, - "description": str(inspect.getdoc(item)), - "signature": str(inspect.signature(item)), - "output": str(inspect.signature(item).return_annotation), - } - return schema +def get_schemas(items: List[Union[BaseModel, Callable]]) -> List[Dict[str, Any]]: + schemas = [] + for item in items: + if isinstance(item, BaseModel): + signature_parts = [] + for field_name, field_model in item.__annotations__.items(): + field_info = item.__fields__[field_name] + default_value = field_info.default + + if default_value: + default_repr = repr(default_value) + signature_part = ( + f"{field_name}: {field_model.__name__} = {default_repr}" + ) + else: + signature_part = f"{field_name}: {field_model.__name__}" + + signature_parts.append(signature_part) + signature = f"({', '.join(signature_parts)}) -> str" + schema = { + "name": item.__class__.__name__, + "description": item.__doc__, + "signature": signature, + } + else: + schema = { + "name": item.__name__, + "description": str(inspect.getdoc(item)), + "signature": str(inspect.signature(item)), + "output": str(inspect.signature(item).return_annotation), + } + schemas.append(schema) + return schemas def convert_param_type_to_json_type(param_type: str) -> str: @@ -58,47 +61,51 @@ def convert_param_type_to_json_type(param_type: str) -> str: return "object" -def get_schema_openai(item: Callable) -> Dict[str, Any]: - if not callable(item): - raise ValueError("Provided item must be a callable function.") - - docstring = inspect.getdoc(item) - signature = inspect.signature(item) - - schema = { - "type": "function", - "function": { - "name": item.__name__, - "description": docstring if docstring else "No description available.", - "parameters": {"type": "object", "properties": {}, "required": []}, - }, - } - - for param_name, param in signature.parameters.items(): - param_type = ( - param.annotation.__name__ - if 
param.annotation != inspect.Parameter.empty - else "Any" - ) - param_description = "No description available." - param_required = param.default is inspect.Parameter.empty - - # Attempt to extract the parameter description from the docstring - if docstring: - param_doc_regex = re.compile(rf":param {param_name}:(.*?)\n(?=:\w|$)", re.S) - match = param_doc_regex.search(docstring) - if match: - param_description = match.group(1).strip() - - schema["function"]["parameters"]["properties"][param_name] = { # type: ignore - "type": convert_param_type_to_json_type(param_type), - "description": param_description, - } +def get_schemas_openai(items: List[Callable]) -> List[Dict[str, Any]]: + schemas = [] + for item in items: + if not callable(item): + raise ValueError("Provided item must be a callable function.") + + docstring = inspect.getdoc(item) + signature = inspect.signature(item) - if param_required: - schema["function"]["parameters"]["required"].append(param_name) # type: ignore + schema = { + "type": "function", + "function": { + "name": item.__name__, + "description": docstring if docstring else "No description available.", + "parameters": {"type": "object", "properties": {}, "required": []}, + }, + } - return schema + for param_name, param in signature.parameters.items(): + param_type = ( + param.annotation.__name__ + if param.annotation != inspect.Parameter.empty + else "Any" + ) + param_description = "No description available." + param_required = param.default is inspect.Parameter.empty + + # Attempt to extract the parameter description from the docstring + if docstring: + param_doc_regex = re.compile(rf":param {param_name}:(.*?)\n(?=:\w|$)", re.S) + match = param_doc_regex.search(docstring) + if match: + param_description = match.group(1).strip() + + schema["function"]["parameters"]["properties"][param_name] = { + "type": convert_param_type_to_json_type(param_type), + "description": param_description, + } + + if param_required: + schema["function"]["parameters"]["required"].append(param_name) + + schemas.append(schema) + + return schemas # TODO: Add route layer object to the input, solve circular import issue diff --git a/tests/unit/llms/test_llm_openai.py b/tests/unit/llms/test_llm_openai.py index c8fdde4a43b7fffa83500ec586e7b127d1190fe8..45e9280ccc3e51c1ff2b64c754f33edfbb93060e 100644 --- a/tests/unit/llms/test_llm_openai.py +++ b/tests/unit/llms/test_llm_openai.py @@ -102,8 +102,8 @@ class TestOpenAILLM: openai_llm.client.chat.completions, "create", return_value=mock_completion ) llm_input = [Message(role="user", content="test")] - function_schema = {"type": "function", "name": "sample_function"} - output = openai_llm(llm_input, function_schema) + function_schemas = [{"type": "function", "name": "sample_function"}] + output = openai_llm(llm_input, function_schemas) assert ( output == "result" ), "Output did not match expected result with function schema" @@ -115,10 +115,10 @@ class TestOpenAILLM: openai_llm.client.chat.completions, "create", return_value=mock_completion ) llm_input = [Message(role="user", content="test")] - function_schema = {"type": "function", "name": "sample_function"} + function_schemas = [{"type": "function", "name": "sample_function"}] with pytest.raises(Exception) as exc_info: - openai_llm(llm_input, function_schema) + openai_llm(llm_input, function_schemas) expected_error_message = "LLM error: Invalid output, expected a tool call." 
        actual_error_message = str(exc_info.value)
@@ -135,10 +135,10 @@
             openai_llm.client.chat.completions, "create", return_value=mock_completion
         )
         llm_input = [Message(role="user", content="test")]
-        function_schema = {"type": "function", "name": "sample_function"}
+        function_schemas = [{"type": "function", "name": "sample_function"}]
 
         with pytest.raises(Exception) as exc_info:
-            openai_llm(llm_input, function_schema)
+            openai_llm(llm_input, function_schemas)
 
-        expected_error_message = (
-            "LLM error: Invalid output, expected arguments to be specified."
-        )
+        expected_error_message = (
+            "LLM error: Invalid output, expected arguments to be specified for each tool call."
+        )
@@ -158,10 +158,10 @@
             openai_llm.client.chat.completions, "create", return_value=mock_completion
         )
         llm_input = [Message(role="user", content="test")]
-        function_schema = {"type": "function", "name": "sample_function"}
+        function_schemas = [{"type": "function", "name": "sample_function"}]
 
         with pytest.raises(Exception) as exc_info:
-            openai_llm(llm_input, function_schema)
+            openai_llm(llm_input, function_schemas)
 
         expected_error_message = (
             "LLM error: Invalid output, expected a single tool to be specified."
         )
@@ -184,11 +184,11 @@
 
     def test_extract_function_inputs(self, openai_llm, mocker):
         query = "fetch user data"
-        function_schema = {"function": "get_user_data", "args": ["user_id"]}
+        function_schemas = [{"function": "get_user_data", "args": ["user_id"]}]
 
         # Mock the __call__ method to return a JSON string as expected
         mocker.patch.object(OpenAILLM, "__call__", return_value='{"user_id": "123"}')
-        result = openai_llm.extract_function_inputs(query, function_schema)
+        result = openai_llm.extract_function_inputs(query, function_schemas)
 
         # Ensure the __call__ method is called with the correct parameters
         expected_messages = [
             Message(role="user", content=query),
         ]
         openai_llm.__call__.assert_called_once_with(
-            messages=expected_messages, function_schema=function_schema
+            messages=expected_messages, function_schemas=function_schemas
         )
 
         # Check if the result is as expected
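Taken together, the changes compose as in the updated `02-dynamic-routes` notebook above. A condensed end-to-end sketch (assumes `OPENAI_API_KEY` is set; the utterance list is shortened relative to the notebook):

```python
import json
from datetime import datetime
from zoneinfo import ZoneInfo

from semantic_router import Route
from semantic_router.encoders import OpenAIEncoder
from semantic_router.layer import RouteLayer
from semantic_router.utils.function_call import get_schemas_openai


def get_time(timezone: str) -> str:
    """Finds the current time in a specific timezone.

    :param timezone: A valid IANA timezone, e.g. "America/New_York".
    """
    return datetime.now(ZoneInfo(timezone)).strftime("%H:%M")


time_route = Route(
    name="get_time",
    utterances=["what is the time in new york city?", "what is the time in london?"],
    function_schemas=get_schemas_openai([get_time]),  # list of callables in, list of schemas out
)
rl = RouteLayer(encoder=OpenAIEncoder(), routes=[time_route])

out = rl("what is the time in new york city?")
# RouteChoice.function_call is now a list; each entry carries the function
# name plus its arguments as a JSON string.
for call in out.function_call or []:
    if call["function_name"] == "get_time":
        print(get_time(**json.loads(call["arguments"])))
```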