Code owners
Assign users and groups as approvers for specific file changes. Learn more.
function_call.py 3.91 KiB
import inspect
import json
from typing import Any, Callable, Union
from pydantic import BaseModel
from semantic_router.llms import BaseLLM
from semantic_router.schema import Message, RouteChoice
from semantic_router.utils.logger import logger
def _annotation_name(annotation: Any) -> str:
    """Return a display name for a field annotation.

    Plain classes expose ``__name__``; some annotations (for example
    ``int | None`` or a string annotation) may not, so fall back to
    ``str(annotation)`` instead of raising AttributeError.
    """
    name = getattr(annotation, "__name__", None)
    return name if isinstance(name, str) else str(annotation)


def _field_default_repr(field_info: Any) -> Union[str, None]:
    """Return ``repr`` of a pydantic field's static default, or None if it has none.

    Requiredness is taken from ``FieldInfo.is_required()`` (pydantic v2) or the
    ``required`` attribute (pydantic v1 ``ModelField``) rather than from the
    truthiness of ``default``. Falsy defaults (``0``, ``False``, ``""``) are
    therefore still shown, and v2's undefined-default sentinel is not.
    If neither API is present, fall back to a truthiness test on ``default``.
    """
    is_required = getattr(field_info, "is_required", None)
    if callable(is_required):
        required = is_required()
    else:
        required = getattr(field_info, "required", None)
    if isinstance(required, bool):
        if required or getattr(field_info, "default_factory", None) is not None:
            # Required fields get no "= ..." suffix. Neither do factory
            # defaults, whose value only exists once the model is built.
            return None
        return repr(field_info.default)
    default_value = getattr(field_info, "default", None)
    return repr(default_value) if default_value else None


def get_schema(item: Union[BaseModel, Callable]) -> dict[str, Any]:
    """Build a prompt-friendly schema dict for a pydantic model or a callable.

    Args:
        item: A pydantic ``BaseModel`` *instance* (checked with
            ``isinstance``), or any callable ``inspect.signature`` accepts.

    Returns:
        For a model instance: ``name`` (the class name), ``description`` (its
        ``__doc__``, possibly None) and ``signature``. The signature is built
        from the model's ``__annotations__`` and always ends in ``-> str``.
        For a callable: ``name``, ``description`` (``str`` of
        ``inspect.getdoc``, i.e. ``"None"`` when undocumented), ``signature``
        and ``output`` (the stringified return annotation).
    """
    if isinstance(item, BaseModel):
        signature_parts = []
        for field_name, field_model in item.__annotations__.items():
            field_info = item.__fields__[field_name]
            type_name = _annotation_name(field_model)
            default_repr = _field_default_repr(field_info)
            if default_repr is not None:
                signature_parts.append(f"{field_name}: {type_name} = {default_repr}")
            else:
                signature_parts.append(f"{field_name}: {type_name}")
        signature = f"({', '.join(signature_parts)}) -> str"
        schema = {
            "name": item.__class__.__name__,
            "description": item.__doc__,
            "signature": signature,
        }
    else:
        item_signature = inspect.signature(item)
        schema = {
            "name": item.__name__,
            "description": str(inspect.getdoc(item)),
            "signature": str(item_signature),
            "output": str(item_signature.return_annotation),
        }
    return schema
def _parse_llm_json(raw_output: str) -> Any:
    """Parse an LLM completion that is expected to hold a single JSON value.

    Tolerates surrounding whitespace, a trailing comma, and a Markdown code
    fence. Strict JSON is tried first. Only if that fails are single quotes
    rewritten to double quotes, to accept Python-dict-style output. This keeps
    apostrophes inside properly double-quoted values intact.

    Raises:
        json.JSONDecodeError: If neither attempt yields valid JSON.
    """
    text = raw_output.strip()
    if text.startswith("```"):
        # Turn "```json\n{...}\n```" into "{...}".
        text = text.strip("`").strip()
        if text[:4].lower() == "json":
            text = text[4:]
        text = text.strip()
    text = text.rstrip(",")
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return json.loads(text.replace("'", '"'))


def extract_function_inputs(
    query: str, llm: BaseLLM, function_schema: dict[str, Any]
) -> dict:
    """Ask the LLM to extract call arguments for a function from a query.

    Args:
        query: The user query to pull parameter values from.
        llm: Called with a list containing one user ``Message`` (the prompt);
            its return value is treated as the completion text.
        function_schema: Schema dict interpolated into the prompt, presumably
            built by ``get_schema``. Its ``"signature"`` entry is used by
            ``is_valid_inputs`` for validation.

    Returns:
        The extracted arguments, keyed by parameter name.

    Raises:
        Exception: If the LLM returns no (or empty) output.
        json.JSONDecodeError: If the output cannot be parsed as JSON.
        ValueError: If the parsed value is not a dict or fails validation.
    """
    logger.info("Extracting function input...")
    prompt = f"""
You are a helpful assistant designed to output JSON.
Given the following function schema
<< {function_schema} >>
and query
<< {query} >>
extract the parameters values from the query, in a valid JSON format.
Example:
Input:
query: "How is the weather in Hawaii right now in International units?"
schema:
{{
"name": "get_weather",
"description": "Useful to get the weather in a specific location",
"signature": "(location: str, degree: str) -> float",
"output": "<class 'float'>"
}}
Result: {{
"location": "Hawaii",
"degree": "Kelvin"
}}
Input:
query: {query}
schema: {function_schema}
Result:
"""
    llm_input = [Message(role="user", content=prompt)]
    output = llm(llm_input)
    if not output:
        raise Exception("No output generated for extract function input")
    function_inputs = _parse_llm_json(output)
    # A non-dict (e.g. a JSON list) cannot be passed as **kwargs, so reject it.
    if not isinstance(function_inputs, dict) or not is_valid_inputs(
        function_inputs, function_schema
    ):
        raise ValueError("Invalid inputs")
    return function_inputs
def is_valid_inputs(inputs: dict[str, Any], function_schema: dict[str, Any]) -> bool:
"""Validate the extracted inputs against the function schema"""
try:
# Extract parameter names and types from the signature string
signature = function_schema["signature"]
param_info = [param.strip() for param in signature[1:-1].split(",")]
param_names = [info.split(":")[0].strip() for info in param_info]
param_types = [
info.split(":")[1].strip().split("=")[0].strip() for info in param_info
]
for name, type_str in zip(param_names, param_types):
if name not in inputs:
logger.error(f"Input {name} missing from query")
return False
return True
except Exception as e:
logger.error(f"Input validation error: {str(e)}")
return False
# TODO: Add route layer object to the input, solve circular import issue
async def route_and_execute(
    query: str, llm: BaseLLM, functions: list[Callable], layer
) -> Any:
    """Route *query* and call the matching function, else fall back to the LLM.

    Args:
        query: The user query.
        llm: Fallback model, called with a single user ``Message`` holding
            *query* when no function is executed.
        functions: Candidate callables, matched to the route by ``__name__``.
        layer: Called as ``layer(query)``; its result is used as a
            ``RouteChoice`` (``.name`` and ``.function_call`` are read).

    Returns:
        The matched function's return value, or whatever *llm* returns.
        Nothing is awaited here, so the result is returned exactly as the
        function or LLM produced it.
    """
    route_choice: RouteChoice = layer(query)
    for function in functions:
        if function.__name__ != route_choice.name:
            continue
        # Compare against None rather than truthiness: an empty argument dict
        # (a zero-parameter function) is a valid call, not a missing one.
        if route_choice.function_call is not None:
            return function(**route_choice.function_call)
    logger.warning("No function found, calling LLM.")
    llm_input = [Message(role="user", content=query)]
    return llm(llm_input)