diff --git a/docs/02-dynamic-routes.ipynb b/docs/02-dynamic-routes.ipynb index d122c9653d731f032a75d3eb44465ae029e9fbcf..d2dd0dea3ab36120f28fcad6e7282e949bcd23a2 100644 --- a/docs/02-dynamic-routes.ipynb +++ b/docs/02-dynamic-routes.ipynb @@ -86,6 +86,14 @@ "name": "stderr", "output_type": "stream", "text": [ + "WARNING: Ignoring invalid distribution ~illow (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution ~lama-cpp-python (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution ~illow (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution ~lama-cpp-python (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution ~illow (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution ~lama-cpp-python (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution ~illow (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution ~lama-cpp-python (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n", "\n", "[notice] A new release of pip is available: 23.1.2 -> 24.0\n", "[notice] To update, run: python.exe -m pip install --upgrade pip\n" @@ -94,7 +102,7 @@ ], "source": [ "!pip install tzdata\n", - "!pip install -qU semantic-router" + "# !pip install -qU semantic-router" ] }, { @@ -182,7 
+190,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-05-04 01:12:56 INFO semantic_router.utils.logger local\u001b[0m\n" + "\u001b[32m2024-05-06 16:01:19 INFO semantic_router.utils.logger local\u001b[0m\n" ] } ], @@ -301,7 +309,7 @@ { "data": { "text/plain": [ - "'17:12'" + "'08:01'" ] }, "execution_count": 6, @@ -324,7 +332,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -345,7 +353,7 @@ " 'required': ['timezone']}}}]" ] }, - "execution_count": 8, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -368,7 +376,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": { "id": "iesBG9P3ur0z" }, @@ -387,7 +395,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -405,7 +413,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -418,7 +426,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-05-04 01:12:58 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n" + "\u001b[32m2024-05-06 16:01:20 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n" ] } ], @@ -428,7 +436,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -446,7 +454,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -460,7 +468,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[33m2024-05-04 01:12:59 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. 
Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n" + "\u001b[33m2024-05-06 16:01:21 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n", + "\u001b[32m2024-05-06 16:01:22 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time', 'arguments': {'timezone': 'America/New_York'}}]\u001b[0m\n" ] }, { @@ -469,7 +478,7 @@ "RouteChoice(name='get_time', function_call=[{'function_name': 'get_time', 'arguments': {'timezone': 'America/New_York'}}], similarity_score=None)" ] }, - "execution_count": 13, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -481,7 +490,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -498,14 +507,14 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "17:13\n" + "08:01\n" ] } ], @@ -560,7 +569,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -639,7 +648,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -648,7 +657,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -683,7 +692,7 @@ " 'required': ['time', 'from_timezone', 'to_timezone']}}}]" ] }, - "execution_count": 18, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -697,7 +706,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -735,7 +744,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -744,14 +753,14 @@ }, { "cell_type": "code", - 
"execution_count": 21, + "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-05-04 01:13:00 INFO semantic_router.utils.logger local\u001b[0m\n" + "\u001b[32m2024-05-06 16:01:22 INFO semantic_router.utils.logger local\u001b[0m\n" ] } ], @@ -768,7 +777,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -796,7 +805,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -805,7 +814,7 @@ "RouteChoice(name='politics', function_call=None, similarity_score=None)" ] }, - "execution_count": 23, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -824,7 +833,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -833,7 +842,7 @@ "RouteChoice(name='chitchat', function_call=None, similarity_score=None)" ] }, - "execution_count": 24, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -852,14 +861,15 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[33m2024-05-04 01:13:02 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n" + "\u001b[33m2024-05-06 16:01:24 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. 
Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n", + "\u001b[32m2024-05-06 16:01:25 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time', 'arguments': {'timezone': 'America/New_York'}}]\u001b[0m\n" ] }, { @@ -868,7 +878,7 @@ "RouteChoice(name='timezone_management', function_call=[{'function_name': 'get_time', 'arguments': {'timezone': 'America/New_York'}}], similarity_score=None)" ] }, - "execution_count": 25, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -880,14 +890,14 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "17:13\n" + "08:01\n" ] } ], @@ -904,16 +914,23 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 26, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-05-06 16:01:26 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time_difference', 'arguments': {'timezone1': 'America/Los_Angeles', 'timezone2': 'Europe/Istanbul'}}]\u001b[0m\n" + ] + }, { "data": { "text/plain": [ "RouteChoice(name='timezone_management', function_call=[{'function_name': 'get_time_difference', 'arguments': {'timezone1': 'America/Los_Angeles', 'timezone2': 'Europe/Istanbul'}}], similarity_score=None)" ] }, - "execution_count": 27, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -925,7 +942,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 27, "metadata": {}, "outputs": [ { @@ -949,16 +966,23 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 28, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-05-06 16:01:28 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'convert_time', 'arguments': {'time': '23:02', 'from_timezone': 
'Asia/Dubai', 'to_timezone': 'Asia/Tokyo'}}]\u001b[0m\n" + ] + }, { "data": { "text/plain": [ "RouteChoice(name='timezone_management', function_call=[{'function_name': 'convert_time', 'arguments': {'time': '23:02', 'from_timezone': 'Asia/Dubai', 'to_timezone': 'Asia/Tokyo'}}], similarity_score=None)" ] }, - "execution_count": 29, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -970,7 +994,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -994,9 +1018,17 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 30, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-05-06 16:01:31 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time', 'arguments': {'timezone': 'Europe/Prague'}}, {'function_name': 'get_time_difference', 'arguments': {'timezone1': 'Europe/Berlin', 'timezone2': 'Asia/Shanghai'}}, {'function_name': 'convert_time', 'arguments': {'time': '05:53', 'from_timezone': 'Europe/Lisbon', 'to_timezone': 'Asia/Bangkok'}}]\u001b[0m\n" + ] + } + ], "source": [ "response = rl2(\"\"\"\n", " What is the time in Prague?\n", @@ -1008,7 +1040,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 31, "metadata": {}, "outputs": [ { @@ -1017,7 +1049,7 @@ "RouteChoice(name='timezone_management', function_call=[{'function_name': 'get_time', 'arguments': {'timezone': 'Europe/Prague'}}, {'function_name': 'get_time_difference', 'arguments': {'timezone1': 'Europe/Berlin', 'timezone2': 'Asia/Shanghai'}}, {'function_name': 'convert_time', 'arguments': {'time': '05:53', 'from_timezone': 'Europe/Lisbon', 'to_timezone': 'Asia/Bangkok'}}], similarity_score=None)" ] }, - "execution_count": 32, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" } @@ -1028,14 +1060,14 @@ }, { "cell_type": "code", - "execution_count": 33, + 
"execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "23:13\n", + "14:01\n", "The time difference between Europe/Berlin and Asia/Shanghai is 6.0 hours.\n", "11:53\n" ] diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py index 2bd7aa3601aa54eadc138100a9c2a1518d342f3b..6d8e95b095414e87bbd75fd4d489c4029140cd9f 100644 --- a/semantic_router/llms/base.py +++ b/semantic_router/llms/base.py @@ -24,6 +24,18 @@ class BaseLLM(BaseModel): ) -> bool: """Determine if the functions chosen by the LLM exist within the function_schemas, and if the input arguments are valid for those functions.""" + # DEBUGGING: Start. + print('#'*50) + print('inputs') + print(inputs) + print('#'*50) + # DEBUGGING: End. + # DEBUGGING: Start. + print('#'*50) + print('function_schemas') + print(function_schemas) + print('#'*50) + # DEBUGGING: End. try: for input_dict in inputs: # Check if 'function_name' and 'arguments' keys exist in each input dictionary @@ -78,106 +90,70 @@ class BaseLLM(BaseModel): return param_names, param_types def extract_function_inputs( - self, query: str, function_schemas: List[Dict[str, Any]] - ) -> Dict: + self, query: str, function_schema: Dict[str, Any] + ) -> List[Dict[str, Any]]: logger.info("Extracting function input...") prompt = f""" You are an accurate and reliable computer program that only outputs valid JSON. -Your task is to: - 1) Pick the most relevant Python function schema(s) from FUNCTION_SCHEMAS below, based on the input QUERY. If only one schema is provided, choose that. If multiple schemas are relevant, output a list of JSON objects for each. - 2) Output JSON representing the input arguments of the chosen function schema(s), including the function name, with argument values determined by information in the QUERY. +Your task is to output JSON representing the input arguments of a Python function. 
-These are the Python functions' schema: +This is the Python function's schema: -### FUNCTION_SCHEMAS Start ### - {json.dumps(function_schemas, indent=4)} -### FUNCTION_SCHEMAS End ### +### FUNCTION_SCHEMA Start ### + {function_schema} +### FUNCTION_SCHEMA End ### This is the input query. ### QUERY Start ### - {query} + {query} ### QUERY End ### The arguments that you need to provide values for, together with their datatypes, are stated in "signature" in the FUNCTION_SCHEMA. The values these arguments must take are made clear by the QUERY. Use the FUNCTION_SCHEMA "description" too, as this might provide helpful clues about the arguments and their values. -Include the function name in your JSON output. -Return only JSON, stating the function name and the argument names with their corresponding values. +Return only JSON, stating the argument names and their corresponding values. ### FORMATTING_INSTRUCTIONS Start ### - Return a response in valid JSON format. Do not return any other explanation or text, just the JSON. - The JSON output should always be an array of JSON objects. If only one function is relevant, return an array with a single JSON object. - Each JSON object should include a key 'function_name' with the value being the name of the function. - Under the key 'arguments', include a nested JSON object where the keys are the names of the arguments and the values are the values those arguments should take. + Return a response in valid JSON format. Do not return any other explanation or text, just the JSON. + The JSON-Keys are the names of the arguments, and JSON-values are the values those arguments should take. ### FORMATTING_INSTRUCTIONS End ### ### EXAMPLE Start ### - === EXAMPLE_INPUT_QUERY Start === - "What is the temperature in Hawaii and New York right now in Celsius, and what is the humidity in Hawaii?" 
- === EXAMPLE_INPUT_QUERY End === - === EXAMPLE_INPUT_SCHEMA Start === - {{ - "name": "get_temperature", - "description": "Useful to get the temperature in a specific location", - "signature": "(location: str, degree: str) -> str", - "output": "<class 'str'>", - }} - {{ - "name": "get_humidity", - "description": "Useful to get the humidity level in a specific location", - "signature": "(location: str) -> int", - "output": "<class 'int'>", - }} - {{ - "name": "get_wind_speed", - "description": "Useful to get the wind speed in a specific location", - "signature": "(location: str) -> float", - "output": "<class 'float'>", - }} - === EXAMPLE_INPUT_SCHEMA End === - === EXAMPLE_OUTPUT Start === - [ - {{ - "function_name": "get_temperature", - "arguments": {{ - "location": "Hawaii", - "degree": "Celsius" - }} - }}, - {{ - "function_name": "get_temperature", - "arguments": {{ - "location": "New York", - "degree": "Celsius" - }} - }}, - {{ - "function_name": "get_humidity", - "arguments": {{ - "location": "Hawaii" - }} - }} - ] - === EXAMPLE_OUTPUT End === + === EXAMPLE_INPUT_QUERY Start === + "How is the weather in Hawaii right now in International units?" + === EXAMPLE_INPUT_QUERY End === + === EXAMPLE_INPUT_SCHEMA Start === + {{ + "name": "get_weather", + "description": "Useful to get the weather in a specific location", + "signature": "(location: str, degree: str) -> str", + "output": "<class 'str'>", + }} + === EXAMPLE_INPUT_QUERY End === + === EXAMPLE_OUTPUT Start === + {{ + "location": "Hawaii", + "degree": "Celsius", + }} + === EXAMPLE_OUTPUT End === ### EXAMPLE End ### -Note: I will tip $500 for an accurate JSON output. You will be penalized for an inaccurate JSON output. +Note: I will tip $500 for and accurate JSON output. You will be penalized for an inaccurate JSON output. 
Provide JSON output now: - """ +""" llm_input = [Message(role="user", content=prompt)] output = self(llm_input) if not output: raise Exception("No output generated for extract function input") - output = output.replace("'", '"').strip().rstrip(",") logger.info(f"LLM output: {output}") function_inputs = json.loads(output) if not isinstance(function_inputs, list): # Local LLMs return a single JSON object that isn't in an array sometimes. function_inputs = [function_inputs] logger.info(f"Function inputs: {function_inputs}") - if not self._is_valid_inputs(function_inputs, function_schemas): + if not self._is_valid_inputs(function_inputs, [function_schema]): raise ValueError("Invalid inputs") return function_inputs \ No newline at end of file diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py index 1e121a3e9952b71bb1f4ca812266cd9db24d584a..be9a10dd47a6684ed8281a10f360c12e9235cac6 100644 --- a/semantic_router/llms/openai.py +++ b/semantic_router/llms/openai.py @@ -104,7 +104,69 @@ class OpenAILLM(BaseLLM): system_prompt = "You are an intelligent AI. Given a command or request from the user, call the function to complete the request." 
messages.append(Message(role="system", content=system_prompt)) messages.append(Message(role="user", content=query)) - return self(messages=messages, function_schemas=function_schemas) + function_inputs = self(messages=messages, function_schemas=function_schemas) + logger.info(f"Function inputs: {function_inputs}") + if not self._is_valid_inputs(function_inputs, function_schemas): + raise ValueError("Invalid inputs") + return function_inputs + + def _is_valid_inputs( + self, inputs: List[Dict[str, Any]], function_schemas: List[Dict[str, Any]] + ) -> bool: + """Determine if the functions chosen by the LLM exist within the function_schemas, + and if the input arguments are valid for those functions.""" + try: + for input_dict in inputs: + # Check if 'function_name' and 'arguments' keys exist in each input dictionary + if "function_name" not in input_dict or "arguments" not in input_dict: + logger.error("Missing 'function_name' or 'arguments' in inputs") + return False + + function_name = input_dict["function_name"] + arguments = input_dict["arguments"] + + # Find the matching function schema based on function_name + matching_schema = next((schema['function'] for schema in function_schemas if schema['function']['name'] == function_name), None) + if not matching_schema: + logger.error(f"No matching function schema found for function name: {function_name}") + return False + + # Validate the inputs against the function schema + if not self._validate_single_function_inputs(arguments, matching_schema): + logger.error(f"Validation failed for function name: {function_name}") + return False + + return True + except Exception as e: + logger.error(f"Input validation error: {str(e)}") + return False + + def _validate_single_function_inputs(self, inputs: Dict[str, Any], function_schema: Dict[str, Any]) -> bool: + """Validate the extracted inputs against the function schema""" + try: + # Access the parameters and their properties from the function schema directly + parameters = 
function_schema['parameters']['properties'] + required_params = function_schema['parameters'].get('required', []) + + # Check if all required parameters are present in the inputs + for param_name in required_params: + if param_name not in inputs: + logger.error(f"Required input '{param_name}' missing from query") + return False + + # Check if the types of the inputs match the expected types (if type checking is needed) + for param_name, param_info in parameters.items(): + if param_name in inputs: + expected_type = param_info['type'] + # This is a simple type check, consider expanding it based on your needs + if expected_type == 'string' and not isinstance(inputs[param_name], str): + logger.error(f"Input type for '{param_name}' is not {expected_type}") + return False + + return True + except Exception as e: + logger.error(f"Single input validation error: {str(e)}") + return False def get_schemas_openai(items: List[Callable]) -> List[Dict[str, Any]]: schemas = []