Skip to content
Snippets Groups Projects
Unverified Commit 8392895c authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

_is_valid_inputs() fixes.

Now no longer requires typehints in the function signature (it wasn't using these anyway, and would break when they weren't included.

Also, we now only check if mandatary arguments have been provided in input. None mandatory don't need to be present.

Finally, addde a check to ensure that, if there are extra arguments provided in input not present in the signature, then these result in false being returned.
parent 705e6043
No related branches found
No related tags found
No related merge requests found
%% Cell type:code id: tags:
``` python
import datetime
import pytz
from semantic_router.llms.openrouter import OpenRouterLLM
from semantic_router import Route, RouteLayer
from semantic_router.encoders import HuggingFaceEncoder
from semantic_router.utils.function_call import get_schema
import geonamescache
class Skill:
def __init__(self):
self.geocoder = geonamescache.GeonamesCache()
self.location = self.geocode_location()
self.route = Route(
name='time',
utterances=[
"tell me what is the time",
"what is the date ",
"time in varshava",
"date",
"what date is it today",
"time in ny",
"what is the time and date in boston",
"time",
"what is the time in makhachkala",
"date time in st petersburg",
"what's the date in vienna",
"date time"
],
function_schema=get_schema(self.run),
)
self.rl = RouteLayer(
encoder=HuggingFaceEncoder(),
routes=[self.route],
llm=OpenRouterLLM(
name='mistralai/mistral-7b-instruct:free',
openrouter_api_key='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'
)
)
def geocode_location(self, location_name=None):
if location_name:
location_name = location_name.title()
location = list(
self.geocoder.get_cities_by_name(location_name)[0].values() if self.geocoder.get_cities_by_name(
location_name) else self.geocoder.get_us_states_by_names(location_name)[
0].values() if self.geocoder.get_us_states_by_names(location_name) else
self.geocoder.get_countries_by_names(location_name)[
0].values() if self.geocoder.get_countries_by_names(location_name) else None)[0]
return location['timezone']
else:
return ''
def run(self, location:str=None, day:int=0, hour:int=0, minute:int=0):
"""Finds the current time in a specific location.
:param location: The location to find the current time in, should
be a valid location. Put the place name itself
like "rome", or "new york" in the lowercase.
:type location: str
:param day: The offset in days from the current date.
Use positive integers for future dates (e.g., day=1 for tomorrow),
negative integers for past dates (e.g., day=-1 for yesterday),
and 0 for the current date.
:type day: int
:param hour: The offset in hours from the current time.
Use positive integers for future times (e.g., hour=1 for one hour ahead),
negative integers for past times (e.g., hour=-1 for one hour ago),
and 0 to maintain the current hour.
:type hour: int
:param minute: The offset in minutes from the current time.
Use positive integers for future minutes (e.g., minute=20 for twenty minutes ahead),
negative integers for past minutes (e.g., minute=-20 for twenty minutes ago),
and 0 to maintain the current minute.
:type minute: int
:return: The time in the specified location."""
timezone = self.geocode_location(location)
if timezone:
tz = pytz.timezone(timezone)
else:
tz = None
current_time = datetime.datetime.now(tz) + datetime.timedelta(days=day)
# Adding hours and minutes to the current time
current_time += datetime.timedelta(hours=hour, minutes=minute)
# Format the date and time as required
formatted_time = current_time.strftime("%Y-%m-%d %H:%M")
return formatted_time
s = Skill()
out = s.rl('time in berlin')
print(s.run(**out.function_call))
```
%% Output
c:\Users\Siraj\Documents\Personal\Work\Aurelio\Virtual Environments\semantic_router_3\Lib\site-packages\tqdm\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
2024-05-13 17:10:58 INFO semantic_router.utils.logger local
2024-05-13 17:10:59 INFO semantic_router.utils.logger Extracting function input...
2024-05-13 17:11:01 INFO semantic_router.utils.logger LLM output: {
"location": "berlin"
}
2024-05-13 17:11:01 INFO semantic_router.utils.logger Function inputs: {'location': 'berlin'}
2024-05-13 15:11
%% Cell type:code id: tags:
``` python
```
......@@ -18,23 +18,55 @@ class BaseLLM(BaseModel):
def __call__(self, messages: List[Message]) -> Optional[str]:
raise NotImplementedError("Subclasses must implement this method")
def _check_for_mandatory_inputs(self, inputs: dict[str, Any], mandatory_params: List[str]) -> bool:
"""Check for mandatory parameters in inputs"""
for name in mandatory_params:
if name not in inputs:
logger.error(f"Mandatory input {name} missing from query")
return False
return True
def _check_for_extra_inputs(self, inputs: dict[str, Any], all_params: List[str]) -> bool:
"""Check for extra parameters not defined in the signature"""
input_keys = set(inputs.keys())
param_keys = set(all_params)
if not input_keys.issubset(param_keys):
extra_keys = input_keys - param_keys
logger.error(f"Extra inputs provided that are not in the signature: {extra_keys}")
return False
return True
def _is_valid_inputs(
self, inputs: dict[str, Any], function_schema: dict[str, Any]
) -> bool:
"""Validate the extracted inputs against the function schema"""
try:
# Extract parameter names and types from the signature string
# Extract parameter names and determine if they are optional
signature = function_schema["signature"]
param_info = [param.strip() for param in signature[1:-1].split(",")]
param_names = [info.split(":")[0].strip() for info in param_info]
param_types = [
info.split(":")[1].strip().split("=")[0].strip() for info in param_info
]
for name, type_str in zip(param_names, param_types):
if name not in inputs:
logger.error(f"Input {name} missing from query")
return False
mandatory_params = []
all_params = []
for info in param_info:
parts = info.split("=")
name_type_pair = parts[0].strip()
name = name_type_pair.split(":")[0].strip()
all_params.append(name)
# If there is no default value, it's a mandatory parameter
if len(parts) == 1:
mandatory_params.append(name)
# Check for mandatory parameters
if not self._check_for_mandatory_inputs(inputs, mandatory_params):
return False
# Check for extra parameters not defined in the signature
if not self._check_for_extra_inputs(inputs, all_params):
return False
return True
except Exception as e:
logger.error(f"Input validation error: {str(e)}")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment