From 1e6a45d4b4602f4c9e38f4d4ec3b42b0203bdf94 Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Mon, 6 May 2024 16:37:18 +0400
Subject: [PATCH] Ensuring that we check the OpenAI function arguments generated
 from dynamic routes.

---
 docs/02-dynamic-routes.ipynb   | 124 +++++++++++++++++++++------------
 semantic_router/llms/base.py   | 110 ++++++++++++-----------------
 semantic_router/llms/openai.py |  64 ++++++++++++++++-
 3 files changed, 184 insertions(+), 114 deletions(-)

diff --git a/docs/02-dynamic-routes.ipynb b/docs/02-dynamic-routes.ipynb
index d122c965..d2dd0dea 100644
--- a/docs/02-dynamic-routes.ipynb
+++ b/docs/02-dynamic-routes.ipynb
@@ -86,6 +86,14 @@
           "name": "stderr",
           "output_type": "stream",
           "text": [
+            "WARNING: Ignoring invalid distribution ~illow (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n",
+            "WARNING: Ignoring invalid distribution ~lama-cpp-python (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n",
+            "WARNING: Ignoring invalid distribution ~illow (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n",
+            "WARNING: Ignoring invalid distribution ~lama-cpp-python (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n",
+            "WARNING: Ignoring invalid distribution ~illow (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n",
+            "WARNING: Ignoring invalid distribution ~lama-cpp-python (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n",
+            "WARNING: Ignoring invalid distribution ~illow (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n",
+            "WARNING: Ignoring invalid distribution ~lama-cpp-python (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n",
             "\n",
             "[notice] A new release of pip is available: 23.1.2 -> 24.0\n",
             "[notice] To update, run: python.exe -m pip install --upgrade pip\n"
@@ -94,7 +102,7 @@
       ],
       "source": [
         "!pip install tzdata\n",
-        "!pip install -qU semantic-router"
+        "# !pip install -qU semantic-router"
       ]
     },
     {
@@ -182,7 +190,7 @@
           "name": "stderr",
           "output_type": "stream",
           "text": [
-            "\u001b[32m2024-05-04 01:12:56 INFO semantic_router.utils.logger local\u001b[0m\n"
+            "\u001b[32m2024-05-06 16:01:19 INFO semantic_router.utils.logger local\u001b[0m\n"
           ]
         }
       ],
@@ -301,7 +309,7 @@
         {
           "data": {
             "text/plain": [
-              "'17:12'"
+              "'08:01'"
             ]
           },
           "execution_count": 6,
@@ -324,7 +332,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 8,
+      "execution_count": 7,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/"
@@ -345,7 +353,7 @@
               "    'required': ['timezone']}}}]"
             ]
           },
-          "execution_count": 8,
+          "execution_count": 7,
           "metadata": {},
           "output_type": "execute_result"
         }
@@ -368,7 +376,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 9,
+      "execution_count": 8,
       "metadata": {
         "id": "iesBG9P3ur0z"
       },
@@ -387,7 +395,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 10,
+      "execution_count": 9,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -405,7 +413,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 11,
+      "execution_count": 10,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/"
@@ -418,7 +426,7 @@
           "name": "stderr",
           "output_type": "stream",
           "text": [
-            "\u001b[32m2024-05-04 01:12:58 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n"
+            "\u001b[32m2024-05-06 16:01:20 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n"
           ]
         }
       ],
@@ -428,7 +436,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 12,
+      "execution_count": 11,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -446,7 +454,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 13,
+      "execution_count": 12,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/",
@@ -460,7 +468,8 @@
           "name": "stderr",
           "output_type": "stream",
           "text": [
-            "\u001b[33m2024-05-04 01:12:59 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n"
+            "\u001b[33m2024-05-06 16:01:21 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n",
+            "\u001b[32m2024-05-06 16:01:22 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time', 'arguments': {'timezone': 'America/New_York'}}]\u001b[0m\n"
           ]
         },
         {
@@ -469,7 +478,7 @@
               "RouteChoice(name='get_time', function_call=[{'function_name': 'get_time', 'arguments': {'timezone': 'America/New_York'}}], similarity_score=None)"
             ]
           },
-          "execution_count": 13,
+          "execution_count": 12,
           "metadata": {},
           "output_type": "execute_result"
         }
@@ -481,7 +490,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 14,
+      "execution_count": 13,
       "metadata": {},
       "outputs": [
         {
@@ -498,14 +507,14 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 15,
+      "execution_count": 14,
       "metadata": {},
       "outputs": [
         {
           "name": "stdout",
           "output_type": "stream",
           "text": [
-            "17:13\n"
+            "08:01\n"
           ]
         }
       ],
@@ -560,7 +569,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 16,
+      "execution_count": 15,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -639,7 +648,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 17,
+      "execution_count": 16,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -648,7 +657,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 18,
+      "execution_count": 17,
       "metadata": {},
       "outputs": [
         {
@@ -683,7 +692,7 @@
               "    'required': ['time', 'from_timezone', 'to_timezone']}}}]"
             ]
           },
-          "execution_count": 18,
+          "execution_count": 17,
           "metadata": {},
           "output_type": "execute_result"
         }
@@ -697,7 +706,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 19,
+      "execution_count": 18,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -735,7 +744,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 20,
+      "execution_count": 19,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -744,14 +753,14 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 21,
+      "execution_count": 20,
       "metadata": {},
       "outputs": [
         {
           "name": "stderr",
           "output_type": "stream",
           "text": [
-            "\u001b[32m2024-05-04 01:13:00 INFO semantic_router.utils.logger local\u001b[0m\n"
+            "\u001b[32m2024-05-06 16:01:22 INFO semantic_router.utils.logger local\u001b[0m\n"
           ]
         }
       ],
@@ -768,7 +777,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 22,
+      "execution_count": 21,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -796,7 +805,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 23,
+      "execution_count": 22,
       "metadata": {},
       "outputs": [
         {
@@ -805,7 +814,7 @@
               "RouteChoice(name='politics', function_call=None, similarity_score=None)"
             ]
           },
-          "execution_count": 23,
+          "execution_count": 22,
           "metadata": {},
           "output_type": "execute_result"
         }
@@ -824,7 +833,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 24,
+      "execution_count": 23,
       "metadata": {},
       "outputs": [
         {
@@ -833,7 +842,7 @@
               "RouteChoice(name='chitchat', function_call=None, similarity_score=None)"
             ]
           },
-          "execution_count": 24,
+          "execution_count": 23,
           "metadata": {},
           "output_type": "execute_result"
         }
@@ -852,14 +861,15 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 25,
+      "execution_count": 24,
       "metadata": {},
       "outputs": [
         {
           "name": "stderr",
           "output_type": "stream",
           "text": [
-            "\u001b[33m2024-05-04 01:13:02 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n"
+            "\u001b[33m2024-05-06 16:01:24 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n",
+            "\u001b[32m2024-05-06 16:01:25 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time', 'arguments': {'timezone': 'America/New_York'}}]\u001b[0m\n"
           ]
         },
         {
@@ -868,7 +878,7 @@
               "RouteChoice(name='timezone_management', function_call=[{'function_name': 'get_time', 'arguments': {'timezone': 'America/New_York'}}], similarity_score=None)"
             ]
           },
-          "execution_count": 25,
+          "execution_count": 24,
           "metadata": {},
           "output_type": "execute_result"
         }
@@ -880,14 +890,14 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 26,
+      "execution_count": 25,
       "metadata": {},
       "outputs": [
         {
           "name": "stdout",
           "output_type": "stream",
           "text": [
-            "17:13\n"
+            "08:01\n"
           ]
         }
       ],
@@ -904,16 +914,23 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 27,
+      "execution_count": 26,
       "metadata": {},
       "outputs": [
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "\u001b[32m2024-05-06 16:01:26 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time_difference', 'arguments': {'timezone1': 'America/Los_Angeles', 'timezone2': 'Europe/Istanbul'}}]\u001b[0m\n"
+          ]
+        },
         {
           "data": {
             "text/plain": [
               "RouteChoice(name='timezone_management', function_call=[{'function_name': 'get_time_difference', 'arguments': {'timezone1': 'America/Los_Angeles', 'timezone2': 'Europe/Istanbul'}}], similarity_score=None)"
             ]
           },
-          "execution_count": 27,
+          "execution_count": 26,
           "metadata": {},
           "output_type": "execute_result"
         }
@@ -925,7 +942,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 28,
+      "execution_count": 27,
       "metadata": {},
       "outputs": [
         {
@@ -949,16 +966,23 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 29,
+      "execution_count": 28,
       "metadata": {},
       "outputs": [
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "\u001b[32m2024-05-06 16:01:28 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'convert_time', 'arguments': {'time': '23:02', 'from_timezone': 'Asia/Dubai', 'to_timezone': 'Asia/Tokyo'}}]\u001b[0m\n"
+          ]
+        },
         {
           "data": {
             "text/plain": [
               "RouteChoice(name='timezone_management', function_call=[{'function_name': 'convert_time', 'arguments': {'time': '23:02', 'from_timezone': 'Asia/Dubai', 'to_timezone': 'Asia/Tokyo'}}], similarity_score=None)"
             ]
           },
-          "execution_count": 29,
+          "execution_count": 28,
           "metadata": {},
           "output_type": "execute_result"
         }
@@ -970,7 +994,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 30,
+      "execution_count": 29,
       "metadata": {},
       "outputs": [
         {
@@ -994,9 +1018,17 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 31,
+      "execution_count": 30,
       "metadata": {},
-      "outputs": [],
+      "outputs": [
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "\u001b[32m2024-05-06 16:01:31 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time', 'arguments': {'timezone': 'Europe/Prague'}}, {'function_name': 'get_time_difference', 'arguments': {'timezone1': 'Europe/Berlin', 'timezone2': 'Asia/Shanghai'}}, {'function_name': 'convert_time', 'arguments': {'time': '05:53', 'from_timezone': 'Europe/Lisbon', 'to_timezone': 'Asia/Bangkok'}}]\u001b[0m\n"
+          ]
+        }
+      ],
       "source": [
         "response = rl2(\"\"\"\n",
         "    What is the time in Prague?\n",
@@ -1008,7 +1040,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 32,
+      "execution_count": 31,
       "metadata": {},
       "outputs": [
         {
@@ -1017,7 +1049,7 @@
               "RouteChoice(name='timezone_management', function_call=[{'function_name': 'get_time', 'arguments': {'timezone': 'Europe/Prague'}}, {'function_name': 'get_time_difference', 'arguments': {'timezone1': 'Europe/Berlin', 'timezone2': 'Asia/Shanghai'}}, {'function_name': 'convert_time', 'arguments': {'time': '05:53', 'from_timezone': 'Europe/Lisbon', 'to_timezone': 'Asia/Bangkok'}}], similarity_score=None)"
             ]
           },
-          "execution_count": 32,
+          "execution_count": 31,
           "metadata": {},
           "output_type": "execute_result"
         }
@@ -1028,14 +1060,14 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 33,
+      "execution_count": 32,
       "metadata": {},
       "outputs": [
         {
           "name": "stdout",
           "output_type": "stream",
           "text": [
-            "23:13\n",
+            "14:01\n",
             "The time difference between Europe/Berlin and Asia/Shanghai is 6.0 hours.\n",
             "11:53\n"
           ]
diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py
index 2bd7aa36..6d8e95b0 100644
--- a/semantic_router/llms/base.py
+++ b/semantic_router/llms/base.py
@@ -24,6 +24,18 @@ class BaseLLM(BaseModel):
     ) -> bool:
         """Determine if the functions chosen by the LLM exist within the function_schemas, 
         and if the input arguments are valid for those functions."""
+        # DEBUGGING: Start.
+        print('#'*50)
+        print('inputs')
+        print(inputs)
+        print('#'*50)
+        # DEBUGGING: End.
+        # DEBUGGING: Start.
+        print('#'*50)
+        print('function_schemas')
+        print(function_schemas)
+        print('#'*50)
+        # DEBUGGING: End.
         try:
             for input_dict in inputs:
                 # Check if 'function_name' and 'arguments' keys exist in each input dictionary
@@ -78,106 +90,70 @@ class BaseLLM(BaseModel):
         return param_names, param_types
 
     def extract_function_inputs(
-        self, query: str, function_schemas: List[Dict[str, Any]]
-    ) -> Dict:
+        self, query: str, function_schema: Dict[str, Any]
+    ) -> List[Dict[str, Any]]:
         logger.info("Extracting function input...")
 
         prompt = f"""
 You are an accurate and reliable computer program that only outputs valid JSON. 
-Your task is to:
-    1) Pick the most relevant Python function schema(s) from FUNCTION_SCHEMAS below, based on the input QUERY. If only one schema is provided, choose that. If multiple schemas are relevant, output a list of JSON objects for each.
-    2) Output JSON representing the input arguments of the chosen function schema(s), including the function name, with argument values determined by information in the QUERY.
+Your task is to output JSON representing the input arguments of a Python function.
 
-These are the Python functions' schema:
+This is the Python function's schema:
 
-### FUNCTION_SCHEMAS Start ###
-    {json.dumps(function_schemas, indent=4)}
-### FUNCTION_SCHEMAS End ###
+### FUNCTION_SCHEMA Start ###
+	{function_schema}
+### FUNCTION_SCHEMA End ###
 
 This is the input query.
 
 ### QUERY Start ###
-    {query}
+	{query}
 ### QUERY End ###
 
 The arguments that you need to provide values for, together with their datatypes, are stated in "signature" in the FUNCTION_SCHEMA.
 The values these arguments must take are made clear by the QUERY.
 Use the FUNCTION_SCHEMA "description" too, as this might provide helpful clues about the arguments and their values.
-Include the function name in your JSON output.
-Return only JSON, stating the function name and the argument names with their corresponding values.
+Return only JSON, stating the argument names and their corresponding values.
 
 ### FORMATTING_INSTRUCTIONS Start ###
-    Return a response in valid JSON format. Do not return any other explanation or text, just the JSON.
-    The JSON output should always be an array of JSON objects. If only one function is relevant, return an array with a single JSON object.
-    Each JSON object should include a key 'function_name' with the value being the name of the function.
-    Under the key 'arguments', include a nested JSON object where the keys are the names of the arguments and the values are the values those arguments should take.
+	Return a response in valid JSON format. Do not return any other explanation or text, just the JSON.
+	The JSON-Keys are the names of the arguments, and JSON-values are the values those arguments should take.
 ### FORMATTING_INSTRUCTIONS End ###
 
 ### EXAMPLE Start ###
-    === EXAMPLE_INPUT_QUERY Start ===
-        "What is the temperature in Hawaii and New York right now in Celsius, and what is the humidity in Hawaii?"
-    === EXAMPLE_INPUT_QUERY End ===
-    === EXAMPLE_INPUT_SCHEMA Start ===
-        {{
-            "name": "get_temperature",
-            "description": "Useful to get the temperature in a specific location",
-            "signature": "(location: str, degree: str) -> str",
-            "output": "<class 'str'>",
-        }}
-        {{
-            "name": "get_humidity",
-            "description": "Useful to get the humidity level in a specific location",
-            "signature": "(location: str) -> int",
-            "output": "<class 'int'>",
-        }}
-        {{
-            "name": "get_wind_speed",
-            "description": "Useful to get the wind speed in a specific location",
-            "signature": "(location: str) -> float",
-            "output": "<class 'float'>",
-        }}
-    === EXAMPLE_INPUT_SCHEMA End ===
-    === EXAMPLE_OUTPUT Start ===
-        [
-            {{
-                "function_name": "get_temperature",
-                "arguments": {{
-                    "location": "Hawaii",
-                    "degree": "Celsius"
-                }}
-            }},
-            {{
-                "function_name": "get_temperature",
-                "arguments": {{
-                    "location": "New York",
-                    "degree": "Celsius"
-                }}
-            }},
-            {{
-                "function_name": "get_humidity",
-                "arguments": {{
-                    "location": "Hawaii"
-                }}
-            }}
-        ]
-    === EXAMPLE_OUTPUT End ===
+	=== EXAMPLE_INPUT_QUERY Start ===
+		"How is the weather in Hawaii right now in International units?"
+	=== EXAMPLE_INPUT_QUERY End ===
+	=== EXAMPLE_INPUT_SCHEMA Start ===
+		{{
+			"name": "get_weather",
+			"description": "Useful to get the weather in a specific location",
+			"signature": "(location: str, degree: str) -> str",
+			"output": "<class 'str'>",
+		}}
+	=== EXAMPLE_INPUT_SCHEMA End ===
+	=== EXAMPLE_OUTPUT Start ===
+		{{
+			"location": "Hawaii",
+			"degree": "Celsius"
+		}}
+	=== EXAMPLE_OUTPUT End ===
 ### EXAMPLE End ###
 
-Note: I will tip $500 for an accurate JSON output. You will be penalized for an inaccurate JSON output.
+Note: I will tip $500 for an accurate JSON output. You will be penalized for an inaccurate JSON output.
 
 Provide JSON output now:
-    """
+"""
         llm_input = [Message(role="user", content=prompt)]
         output = self(llm_input)
         if not output:
             raise Exception("No output generated for extract function input")
-
         output = output.replace("'", '"').strip().rstrip(",")
         logger.info(f"LLM output: {output}")
         function_inputs = json.loads(output)
         if not isinstance(function_inputs, list): # Local LLMs return a single JSON object that isn't in an array sometimes.
             function_inputs = [function_inputs]
         logger.info(f"Function inputs: {function_inputs}")
-        if not self._is_valid_inputs(function_inputs, function_schemas):
+        if not self._is_valid_inputs(function_inputs, [function_schema]):
             raise ValueError("Invalid inputs")
         return function_inputs
\ No newline at end of file
diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py
index 1e121a3e..be9a10dd 100644
--- a/semantic_router/llms/openai.py
+++ b/semantic_router/llms/openai.py
@@ -104,7 +104,69 @@ class OpenAILLM(BaseLLM):
         system_prompt = "You are an intelligent AI. Given a command or request from the user, call the function to complete the request."
         messages.append(Message(role="system", content=system_prompt))
         messages.append(Message(role="user", content=query))
-        return self(messages=messages, function_schemas=function_schemas)
+        function_inputs = self(messages=messages, function_schemas=function_schemas)
+        logger.info(f"Function inputs: {function_inputs}")
+        if not self._is_valid_inputs(function_inputs, function_schemas):
+            raise ValueError("Invalid inputs")
+        return function_inputs
+    
+    def _is_valid_inputs(
+        self, inputs: List[Dict[str, Any]], function_schemas: List[Dict[str, Any]]
+    ) -> bool:
+        """Determine if the functions chosen by the LLM exist within the function_schemas, 
+        and if the input arguments are valid for those functions."""
+        try:
+            for input_dict in inputs:
+                # Check if 'function_name' and 'arguments' keys exist in each input dictionary
+                if "function_name" not in input_dict or "arguments" not in input_dict:
+                    logger.error("Missing 'function_name' or 'arguments' in inputs")
+                    return False
+
+                function_name = input_dict["function_name"]
+                arguments = input_dict["arguments"]
+
+                # Find the matching function schema based on function_name
+                matching_schema = next((schema['function'] for schema in function_schemas if schema['function']['name'] == function_name), None)
+                if not matching_schema:
+                    logger.error(f"No matching function schema found for function name: {function_name}")
+                    return False
+
+                # Validate the inputs against the function schema
+                if not self._validate_single_function_inputs(arguments, matching_schema):
+                    logger.error(f"Validation failed for function name: {function_name}")
+                    return False
+
+            return True
+        except Exception as e:
+            logger.error(f"Input validation error: {str(e)}")
+            return False
+
+    def _validate_single_function_inputs(self, inputs: Dict[str, Any], function_schema: Dict[str, Any]) -> bool:
+        """Validate the extracted inputs against the function schema"""
+        try:
+            # Access the parameters and their properties from the function schema directly
+            parameters = function_schema['parameters']['properties']
+            required_params = function_schema['parameters'].get('required', [])
+
+            # Check if all required parameters are present in the inputs
+            for param_name in required_params:
+                if param_name not in inputs:
+                    logger.error(f"Required input '{param_name}' missing from query")
+                    return False
+
+            # Check if the types of the inputs match the expected types (if type checking is needed)
+            for param_name, param_info in parameters.items():
+                if param_name in inputs:
+                    expected_type = param_info['type']
+                    # This is a simple type check, consider expanding it based on your needs
+                    if expected_type == 'string' and not isinstance(inputs[param_name], str):
+                        logger.error(f"Input type for '{param_name}' is not {expected_type}")
+                        return False
+
+            return True
+        except Exception as e:
+            logger.error(f"Single input validation error: {str(e)}")
+            return False
 
 def get_schemas_openai(items: List[Callable]) -> List[Dict[str, Any]]:
     schemas = []
-- 
GitLab