From 8392895cd6ce8c3c064df0590a90ab0a3f7f8cb7 Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood <siraj@aurelio.ai> Date: Mon, 13 May 2024 17:15:11 +0400 Subject: [PATCH] _is_valid_inputs() fixes. Now no longer requires typehints in the function signature (it wasn't using these anyway, and would break when they weren't included). Also, we now only check if mandatory arguments have been provided in input. Non-mandatory arguments don't need to be present. Finally, added a check to ensure that, if there are extra arguments provided in input not present in the signature, then these result in False being returned. --- docs/10-debugging-discord-issue.ipynb | 163 ++++++++++++++++++++++++++ semantic_router/llms/base.py | 52 ++++++-- 2 files changed, 205 insertions(+), 10 deletions(-) create mode 100644 docs/10-debugging-discord-issue.ipynb diff --git a/docs/10-debugging-discord-issue.ipynb b/docs/10-debugging-discord-issue.ipynb new file mode 100644 index 00000000..9c781d4a --- /dev/null +++ b/docs/10-debugging-discord-issue.ipynb @@ -0,0 +1,163 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "\u001b[32m2024-05-13 17:10:58 INFO semantic_router.utils.logger local\u001b[0m\n", + "\u001b[32m2024-05-13 17:10:59 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n", + "\u001b[32m2024-05-13 17:11:01 INFO semantic_router.utils.logger LLM output: {\n", + " \"location\": \"berlin\"\n", + "}\u001b[0m\n", + "\u001b[32m2024-05-13 17:11:01 INFO semantic_router.utils.logger Function inputs: {'location': 'berlin'}\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-05-13 15:11\n" + ] + } + ], + "source": [ + "import datetime\n", + "import pytz\n", + "from semantic_router.llms.openrouter import OpenRouterLLM\n", + "from semantic_router import Route, RouteLayer\n", + "from semantic_router.encoders import HuggingFaceEncoder\n", + "from semantic_router.utils.function_call import get_schema\n", + "import geonamescache\n", + "\n", + "class Skill:\n", + " def __init__(self):\n", + " self.geocoder = geonamescache.GeonamesCache()\n", + " self.location = self.geocode_location()\n", + " self.route = Route(\n", + " name='time',\n", + " utterances=[\n", + " \"tell me what is the time\",\n", + " \"what is the date \",\n", + " \"time in varshava\",\n", + " \"date\",\n", + " \"what date is it today\",\n", + " \"time in ny\",\n", + " \"what is the time and date in boston\",\n", + " \"time\",\n", + " \"what is the time in makhachkala\",\n", + " \"date time in st petersburg\",\n", + " \"what's the date in vienna\",\n", + " \"date time\"\n", + " ],\n", + " function_schema=get_schema(self.run),\n", + "\n", + " )\n", + "\n", + " self.rl = RouteLayer(\n", + " encoder=HuggingFaceEncoder(),\n", + " routes=[self.route],\n", + " llm=OpenRouterLLM(\n", + " name='mistralai/mistral-7b-instruct:free',\n", + " openrouter_api_key='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'\n", + 
"\n", + " )\n", + " )\n", + "\n", + " def geocode_location(self, location_name=None):\n", + " if location_name:\n", + " location_name = location_name.title()\n", + " location = list(\n", + " self.geocoder.get_cities_by_name(location_name)[0].values() if self.geocoder.get_cities_by_name(\n", + " location_name) else self.geocoder.get_us_states_by_names(location_name)[\n", + " 0].values() if self.geocoder.get_us_states_by_names(location_name) else\n", + " self.geocoder.get_countries_by_names(location_name)[\n", + " 0].values() if self.geocoder.get_countries_by_names(location_name) else None)[0]\n", + " return location['timezone']\n", + " else:\n", + " return ''\n", + "\n", + " def run(self, location:str=None, day:int=0, hour:int=0, minute:int=0):\n", + " \"\"\"Finds the current time in a specific location.\n", + "\n", + " :param location: The location to find the current time in, should\n", + " be a valid location. Put the place name itself\n", + " like \"rome\", or \"new york\" in the lowercase.\n", + " :type location: str\n", + "\n", + " :param day: The offset in days from the current date.\n", + " Use positive integers for future dates (e.g., day=1 for tomorrow),\n", + " negative integers for past dates (e.g., day=-1 for yesterday),\n", + " and 0 for the current date.\n", + " :type day: int\n", + "\n", + " :param hour: The offset in hours from the current time.\n", + " Use positive integers for future times (e.g., hour=1 for one hour ahead),\n", + " negative integers for past times (e.g., hour=-1 for one hour ago),\n", + " and 0 to maintain the current hour.\n", + " :type hour: int\n", + "\n", + " :param minute: The offset in minutes from the current time.\n", + " Use positive integers for future minutes (e.g., minute=20 for twenty minutes ahead),\n", + " negative integers for past minutes (e.g., minute=-20 for twenty minutes ago),\n", + " and 0 to maintain the current minute.\n", + " :type minute: int\n", + "\n", + " :return: The time in the specified 
location.\"\"\"\n", + " timezone = self.geocode_location(location)\n", + " if timezone:\n", + " tz = pytz.timezone(timezone)\n", + " else:\n", + " tz = None\n", + "\n", + " current_time = datetime.datetime.now(tz) + datetime.timedelta(days=day)\n", + "\n", + " # Adding hours and minutes to the current time\n", + " current_time += datetime.timedelta(hours=hour, minutes=minute)\n", + "\n", + " # Format the date and time as required\n", + " formatted_time = current_time.strftime(\"%Y-%m-%d %H:%M\")\n", + "\n", + " return formatted_time\n", + "\n", + "s = Skill()\n", + "out = s.rl('time in berlin')\n", + "print(s.run(**out.function_call))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "semantic_router_3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py index 963c754a..658ed08d 100644 --- a/semantic_router/llms/base.py +++ b/semantic_router/llms/base.py @@ -18,23 +18,55 @@ class BaseLLM(BaseModel): def __call__(self, messages: List[Message]) -> Optional[str]: raise NotImplementedError("Subclasses must implement this method") - + + + def _check_for_mandatory_inputs(self, inputs: dict[str, Any], mandatory_params: List[str]) -> bool: + """Check for mandatory parameters in inputs""" + for name in mandatory_params: + if name not in inputs: + logger.error(f"Mandatory input {name} missing from query") + return False + return True + + def _check_for_extra_inputs(self, inputs: dict[str, Any], all_params: List[str]) -> bool: + """Check for extra parameters not defined in the 
signature""" + input_keys = set(inputs.keys()) + param_keys = set(all_params) + if not input_keys.issubset(param_keys): + extra_keys = input_keys - param_keys + logger.error(f"Extra inputs provided that are not in the signature: {extra_keys}") + return False + return True + def _is_valid_inputs( self, inputs: dict[str, Any], function_schema: dict[str, Any] ) -> bool: """Validate the extracted inputs against the function schema""" try: - # Extract parameter names and types from the signature string + # Extract parameter names and determine if they are optional signature = function_schema["signature"] param_info = [param.strip() for param in signature[1:-1].split(",")] - param_names = [info.split(":")[0].strip() for info in param_info] - param_types = [ - info.split(":")[1].strip().split("=")[0].strip() for info in param_info - ] - for name, type_str in zip(param_names, param_types): - if name not in inputs: - logger.error(f"Input {name} missing from query") - return False + mandatory_params = [] + all_params = [] + + for info in param_info: + parts = info.split("=") + name_type_pair = parts[0].strip() + name = name_type_pair.split(":")[0].strip() + all_params.append(name) + + # If there is no default value, it's a mandatory parameter + if len(parts) == 1: + mandatory_params.append(name) + + # Check for mandatory parameters + if not self._check_for_mandatory_inputs(inputs, mandatory_params): + return False + + # Check for extra parameters not defined in the signature + if not self._check_for_extra_inputs(inputs, all_params): + return False + return True except Exception as e: logger.error(f"Input validation error: {str(e)}") -- GitLab