import inspect
import json
import re
from typing import Any, Callable, Optional, Union

from pydantic import BaseModel

from semantic_router.utils.llm import llm
from semantic_router.utils.logger import logger

# Matches a comma directly followed (modulo whitespace) by a closing bracket,
# i.e. a JSON trailing comma such as `"a": 1,\n}`.
_TRAILING_COMMA = re.compile(r",\s*([}\]])")


def _type_name(annotation: Any) -> str:
    """Return a readable name for a type annotation.

    Plain classes expose ``__name__``; some typing constructs do not, so fall
    back to their ``str()`` form instead of raising ``AttributeError``.
    """
    return getattr(annotation, "__name__", None) or str(annotation)


def get_schema(item: Union[BaseModel, Callable]) -> dict[str, Any]:
    """Build a prompt-friendly schema for a pydantic model instance or a callable.

    Args:
        item: A pydantic ``BaseModel`` instance, or any callable (typically a
            plain function).

    Returns:
        A dict with ``name``, ``description`` and ``signature`` keys. For
        callables an ``output`` key holding the stringified return annotation
        is included as well. For models the signature is synthesised as
        ``(field: type = default, ...) -> str``.
    """
    if isinstance(item, BaseModel):
        signature_parts = []
        # NOTE(review): instance attribute lookup of __annotations__ returns
        # the most-derived class's own annotations; fields inherited from a
        # base model are presumably missed — confirm.
        for field_name, field_model in item.__annotations__.items():
            field_info = item.__fields__[field_name]
            default_value = field_info.default
            # Falsy defaults (0, "", False, None) are rendered as if the field
            # had no default at all.
            if default_value:
                default_repr = repr(default_value)
                signature_part = (
                    f"{field_name}: {_type_name(field_model)} = {default_repr}"
                )
            else:
                signature_part = f"{field_name}: {_type_name(field_model)}"
            signature_parts.append(signature_part)
        signature = f"({', '.join(signature_parts)}) -> str"
        schema = {
            "name": item.__class__.__name__,
            "description": item.__doc__,
            "signature": signature,
        }
    else:
        signature_obj = inspect.signature(item)
        schema = {
            "name": item.__name__,
            "description": str(inspect.getdoc(item)),
            "signature": str(signature_obj),
            "output": str(signature_obj.return_annotation),
        }
    return schema


def _parse_llm_json(output: str) -> Any:
    """Parse LLM output as JSON, tolerating common near-JSON mistakes.

    Strict JSON is tried first so that values containing apostrophes survive.
    Only if that fails is the legacy repair applied (single quotes turned into
    double quotes), and as a last resort trailing commas before a closing
    bracket are dropped.

    Raises:
        json.JSONDecodeError: If no variant of the output parses.
    """
    text = output.strip().rstrip(",")
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    # Python-style single-quoted output. Tried only after strict parsing
    # fails because the blanket replace corrupts values like "O'Brien".
    repaired = text.replace("'", '"')
    try:
        return json.loads(repaired)
    except json.JSONDecodeError:
        return json.loads(_TRAILING_COMMA.sub(r"\1", repaired))


def extract_function_inputs(query: str, function_schema: dict[str, Any]) -> dict:
    """Ask the LLM to extract call arguments for ``function_schema`` from ``query``.

    Args:
        query: The user's natural-language request.
        function_schema: A schema as produced by ``get_schema``.

    Returns:
        The parsed argument mapping, suitable for ``call_function``.

    Raises:
        Exception: If the LLM returns no output.
        json.JSONDecodeError: If the output cannot be parsed as JSON.
        ValueError: If the parsed inputs fail ``is_valid_inputs``.
    """
    logger.info("Extracting function input...")

    # The few-shot example must itself be valid JSON (no trailing commas) and
    # consistent with its query, since the model tends to imitate it.
    prompt = f"""
You are a helpful assistant designed to output JSON.
Given the following function schema
<< {function_schema} >>
and query
<< {query} >>
extract the parameter values from the query, in a valid JSON format.
Example:
Input:
query: "How is the weather in Hawaii right now in International units?"
schema:
{{
    "name": "get_weather",
    "description": "Useful to get the weather in a specific location",
    "signature": "(location: str, degree: str) -> str",
    "output": "<class 'str'>"
}}

Result: {{
    "location": "Hawaii",
    "degree": "Celsius"
}}

Input:
query: {query}
schema: {function_schema}
Result:
"""
    output = llm(prompt)
    if not output:
        raise Exception("No output generated for extract function input")

    function_inputs = _parse_llm_json(output)
    if not is_valid_inputs(function_inputs, function_schema):
        raise ValueError("Invalid inputs")
    return function_inputs


def _split_signature_params(signature: str) -> list[str]:
    """Split the parenthesised parameter list of a signature string.

    Commas nested inside brackets or quoted strings (e.g. ``dict[str, int]``
    or a default of ``(1, 2)``) do not split parameters, and anything after
    the closing parenthesis (the ``-> ...`` return annotation) is ignored.

    Raises:
        ValueError: If ``signature`` does not start with ``(``.
    """
    text = signature.strip()
    if not text.startswith("("):
        raise ValueError(f"Unrecognised signature: {signature!r}")

    params: list[str] = []
    current: list[str] = []
    depth = 0
    quote: Optional[str] = None
    escaped = False
    for char in text[1:]:
        if quote is not None:
            # Inside a string literal: only watch for its closing quote.
            current.append(char)
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == quote:
                quote = None
            continue
        if char in "([{":
            depth += 1
        elif char in ")]}":
            if depth == 0:
                break  # closing parenthesis of the parameter list
            depth -= 1
        elif char == "," and depth == 0:
            params.append("".join(current).strip())
            current = []
            continue
        elif char in "'\"":
            quote = char
        current.append(char)
    params.append("".join(current).strip())
    return [param for param in params if param]


def is_valid_inputs(inputs: dict[str, Any], function_schema: dict[str, Any]) -> bool:
    """Validate the extracted inputs against the function schema.

    Every named parameter in ``function_schema["signature"]`` that has no
    default value must be present as a key of ``inputs``. Variadic parameters
    (``*args``, ``**kwargs``) and the bare ``*`` / ``/`` markers are never
    required; parameters with defaults may be omitted.

    Returns:
        True if all required parameters are present, otherwise False. Problems
        are logged rather than raised.
    """
    if not isinstance(inputs, dict):
        logger.error(
            "Input validation error: expected a JSON object, "
            f"got {type(inputs).__name__}"
        )
        return False
    try:
        params = _split_signature_params(function_schema["signature"])
    except (KeyError, TypeError, AttributeError, ValueError) as e:
        logger.error(f"Input validation error: {str(e)}")
        return False

    for param in params:
        if param.startswith("*") or param == "/":
            continue
        # "name", "name: type", "name=default" or "name: type = default".
        name = re.split(r"[:=]", param, maxsplit=1)[0].strip()
        if "=" in param:
            continue  # has a default, so the call works without it
        if name not in inputs:
            logger.error(f"Input {name} missing from query")
            return False
    return True


def call_function(function: Callable, inputs: dict[str, str]):
    """Call ``function`` with ``inputs`` as keyword arguments.

    Returns:
        The function's result, or None if the call raised ``TypeError``
        (e.g. unexpected or missing keyword arguments), which is logged.
        Note that a ``TypeError`` raised inside ``function`` itself is
        swallowed the same way.
    """
    try:
        return function(**inputs)
    except TypeError as e:
        logger.error(f"Error calling function: {e}")
        return None


# TODO: Add route layer object to the input, solve circular import issue
async def route_and_execute(query: str, functions: list[Callable], route_layer):
    """Route ``query`` to one of ``functions`` and execute it.

    Args:
        query: The user's request.
        functions: Candidate callables, matched against the route result by
            ``__name__``.
        route_layer: Callable applied to ``query``; its result is compared to
            function names, and a falsy result means "no route".

    Returns:
        The matched function's result (None if its call failed with
        ``TypeError``), the LLM's direct answer when no route matched, or
        None when the route name matches none of ``functions``.
    """
    function_name = route_layer(query)
    if not function_name:
        logger.warning("No function found, calling LLM...")
        return llm(query)

    for function in functions:
        if function.__name__ == function_name:
            logger.info(f"Calling function: {function.__name__}")
            schema = get_schema(function)
            inputs = extract_function_inputs(query, schema)
            return call_function(function, inputs)

    logger.warning(f"No function named {function_name} in the provided functions")
    return None