diff --git a/docs/10-debugging-discord-issue.ipynb b/docs/10-debugging-discord-issue.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..9c781d4a7a40d10680ff8c50177d3664e2d9a0d1 --- /dev/null +++ b/docs/10-debugging-discord-issue.ipynb @@ -0,0 +1,163 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "\u001b[32m2024-05-13 17:10:58 INFO semantic_router.utils.logger local\u001b[0m\n", + "\u001b[32m2024-05-13 17:10:59 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n", + "\u001b[32m2024-05-13 17:11:01 INFO semantic_router.utils.logger LLM output: {\n", + " \"location\": \"berlin\"\n", + "}\u001b[0m\n", + "\u001b[32m2024-05-13 17:11:01 INFO semantic_router.utils.logger Function inputs: {'location': 'berlin'}\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-05-13 15:11\n" + ] + } + ], + "source": [ + "import datetime\n", + "import pytz\n", + "from semantic_router.llms.openrouter import OpenRouterLLM\n", + "from semantic_router import Route, RouteLayer\n", + "from semantic_router.encoders import HuggingFaceEncoder\n", + "from semantic_router.utils.function_call import get_schema\n", + "import geonamescache\n", + "\n", + "class Skill:\n", + " def __init__(self):\n", + " self.geocoder = geonamescache.GeonamesCache()\n", + " self.location = self.geocode_location()\n", + " self.route = Route(\n", + " name='time',\n", + " utterances=[\n", + " \"tell me what is the time\",\n", + " \"what is the date \",\n", + " \"time in varshava\",\n", + " \"date\",\n", + " \"what date is it today\",\n", + " \"time in ny\",\n", + " \"what is the time and date in boston\",\n", + " \"time\",\n", + " \"what is the time in makhachkala\",\n", + " \"date time in st petersburg\",\n", + " \"what's the date in vienna\",\n", + " \"date time\"\n", + " ],\n", + " function_schema=get_schema(self.run),\n", + "\n", + " )\n", + "\n", + " self.rl = RouteLayer(\n", + " encoder=HuggingFaceEncoder(),\n", + " routes=[self.route],\n", + " llm=OpenRouterLLM(\n", + " name='mistralai/mistral-7b-instruct:free',\n", + " openrouter_api_key='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'\n", + "\n", + " )\n", + " )\n", + "\n", + " def geocode_location(self, location_name=None):\n", + " if location_name:\n", + " location_name = location_name.title()\n", + " location = list(\n", + " self.geocoder.get_cities_by_name(location_name)[0].values() if self.geocoder.get_cities_by_name(\n", + " location_name) else self.geocoder.get_us_states_by_names(location_name)[\n", + " 0].values() if self.geocoder.get_us_states_by_names(location_name) else\n", + " self.geocoder.get_countries_by_names(location_name)[\n", + " 0].values() if self.geocoder.get_countries_by_names(location_name) else None)[0]\n", + " return location['timezone']\n", + " else:\n", + " return ''\n", + "\n", + " def run(self, location:str=None, day:int=0, hour:int=0, minute:int=0):\n", + " \"\"\"Finds the current time in a specific location.\n", + "\n", + " :param location: The location to find the current time in, should\n", + " be a valid location. Put the place name itself\n", + " like \"rome\", or \"new york\" in the lowercase.\n", + " :type location: str\n", + "\n", + " :param day: The offset in days from the current date.\n", + " Use positive integers for future dates (e.g., day=1 for tomorrow),\n", + " negative integers for past dates (e.g., day=-1 for yesterday),\n", + " and 0 for the current date.\n", + " :type day: int\n", + "\n", + " :param hour: The offset in hours from the current time.\n", + " Use positive integers for future times (e.g., hour=1 for one hour ahead),\n", + " negative integers for past times (e.g., hour=-1 for one hour ago),\n", + " and 0 to maintain the current hour.\n", + " :type hour: int\n", + "\n", + " :param minute: The offset in minutes from the current time.\n", + " Use positive integers for future minutes (e.g., minute=20 for twenty minutes ahead),\n", + " negative integers for past minutes (e.g., minute=-20 for twenty minutes ago),\n", + " and 0 to maintain the current minute.\n", + " :type minute: int\n", + "\n", + " :return: The time in the specified location.\"\"\"\n", + " timezone = self.geocode_location(location)\n", + " if timezone:\n", + " tz = pytz.timezone(timezone)\n", + " else:\n", + " tz = None\n", + "\n", + " current_time = datetime.datetime.now(tz) + datetime.timedelta(days=day)\n", + "\n", + " # Adding hours and minutes to the current time\n", + " current_time += datetime.timedelta(hours=hour, minutes=minute)\n", + "\n", + " # Format the date and time as required\n", + " formatted_time = current_time.strftime(\"%Y-%m-%d %H:%M\")\n", + "\n", + " return formatted_time\n", + "\n", + "s = Skill()\n", + "out = s.rl('time in berlin')\n", + "print(s.run(**out.function_call))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "semantic_router_3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py index 963c754a1c49ec4c9d2308562dd6f6f17e090090..658ed08da6a1015df839f0c349524ad955605108 100644 --- a/semantic_router/llms/base.py +++ b/semantic_router/llms/base.py @@ -18,23 +18,55 @@ class BaseLLM(BaseModel): def __call__(self, messages: List[Message]) -> Optional[str]: raise NotImplementedError("Subclasses must implement this method") - + + + def _check_for_mandatory_inputs(self, inputs: dict[str, Any], mandatory_params: List[str]) -> bool: + """Check for mandatory parameters in inputs""" + for name in mandatory_params: + if name not in inputs: + logger.error(f"Mandatory input {name} missing from query") + return False + return True + + def _check_for_extra_inputs(self, inputs: dict[str, Any], all_params: List[str]) -> bool: + """Check for extra parameters not defined in the signature""" + input_keys = set(inputs.keys()) + param_keys = set(all_params) + if not input_keys.issubset(param_keys): + extra_keys = input_keys - param_keys + logger.error(f"Extra inputs provided that are not in the signature: {extra_keys}") + return False + return True + def _is_valid_inputs( self, inputs: dict[str, Any], function_schema: dict[str, Any] ) -> bool: """Validate the extracted inputs against the function schema""" try: - # Extract parameter names and types from the signature string + # Extract parameter names and determine if they are optional signature = function_schema["signature"] param_info = [param.strip() for param in signature[1:-1].split(",")] - param_names = [info.split(":")[0].strip() for info in param_info] - param_types = [ - info.split(":")[1].strip().split("=")[0].strip() for info in param_info - ] - for name, type_str in zip(param_names, param_types): - if name not in inputs: - logger.error(f"Input {name} missing from query") - return False + mandatory_params = [] + all_params = [] + + for info in param_info: + parts = info.split("=") + name_type_pair = parts[0].strip() + name = name_type_pair.split(":")[0].strip() + all_params.append(name) + + # If there is no default value, it's a mandatory parameter + if len(parts) == 1: + mandatory_params.append(name) + + # Check for mandatory parameters + if not self._check_for_mandatory_inputs(inputs, mandatory_params): + return False + + # Check for extra parameters not defined in the signature + if not self._check_for_extra_inputs(inputs, all_params): + return False + return True except Exception as e: logger.error(f"Input validation error: {str(e)}")