diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py index 1e121a3e9952b71bb1f4ca812266cd9db24d584a..6b1bebba65cb9c88c6ab7afbcfc352b75aa21467 100644 --- a/semantic_router/llms/openai.py +++ b/semantic_router/llms/openai.py @@ -2,7 +2,6 @@ import os from typing import List, Optional, Any, Callable, Dict import openai -from openai._types import NotGiven from semantic_router.llms import BaseLLM from semantic_router.schema import Message @@ -64,8 +63,8 @@ class OpenAILLM(BaseLLM): if function_schemas: tools = function_schemas else: - tools = NotGiven - + tools = None + completion = self.client.chat.completions.create( model=self.name, messages=[m.to_openai() for m in messages], @@ -96,15 +95,192 @@ class OpenAILLM(BaseLLM): except Exception as e: logger.error(f"LLM error: {e}") raise Exception(f"LLM error: {e}") from e + + + def _extract_multiple_function_inputs(self, query: str, function_schemas: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - def extract_function_inputs( - self, query: str, function_schemas: List[Dict[str, Any]] - ) -> Dict: + prompt = f""" +You are an accurate and reliable computer program that only outputs valid JSON. +Your task is to: + 1) Pick the most relevant Python function schema(s) from FUNCTION_SCHEMAS below, based on the input QUERY. If only one schema is provided, choose that. If multiple schemas are relevant, output a list of JSON objects for each. + 2) Output JSON representing the input arguments of the chosen function schema(s), including the function name, with argument values determined by information in the QUERY. + +These are the Python functions' schema: + +### FUNCTION_SCHEMAS Start ### + {json.dumps([schema['function'] for schema in function_schemas], indent=4)} +### FUNCTION_SCHEMAS End ### + +This is the input query. + +### QUERY Start ### + {query} +### QUERY End ### + +The arguments that you need to provide values for, together with their datatypes, are stated in the "parameters" in the FUNCTION_SCHEMA. +The values these arguments must take are made clear by the QUERY. +Use the FUNCTION_SCHEMA "description" too, as this might provide helpful clues about the arguments and their values. +Include the function name in your JSON output. +Return only JSON, stating the function name and the argument names with their corresponding values. + +### FORMATTING_INSTRUCTIONS Start ### + Return a response in valid JSON format. Do not return any other explanation or text, just the JSON. + The JSON output should always be an array of JSON objects. If only one function is relevant, return an array with a single JSON object. + Each JSON object should include a key 'function_name' with the value being the name of the function. + Under the key 'arguments', include a nested JSON object where the keys are the names of the arguments and the values are the values those arguments should take. +### FORMATTING_INSTRUCTIONS End ### + +### EXAMPLE Start ### + === EXAMPLE_INPUT_QUERY Start === + "What is the temperature in Hawaii and New York right now in Celsius, and what is the humidity in Hawaii?" + === EXAMPLE_INPUT_QUERY End === + === EXAMPLE_INPUT_SCHEMA Start === + {{ + "name": "get_temperature", + "description": "Useful to get the temperature in a specific location", + "parameters": {{ + "type": "object", + "properties": {{ + "location": {{ + "type": "string", + "description": "The location to get the temperature from." + }}, + "degree": {{ + "type": "string", + "description": "The degree type, e.g., Celsius or Fahrenheit." + }} + }}, + "required": ["location", "degree"] + }} + }} + === EXAMPLE_INPUT_SCHEMA End === + === EXAMPLE_OUTPUT Start === + [ + {{ + "function_name": "get_temperature", + "arguments": {{ + "location": "Hawaii", + "degree": "Celsius" + }} + }}, + {{ + "function_name": "get_temperature", + "arguments": {{ + "location": "New York", + "degree": "Celsius" + }} + }}, + {{ + "function_name": "get_humidity", + "arguments": {{ + "location": "Hawaii" + }} + }} + ] + === EXAMPLE_OUTPUT End === +### EXAMPLE End ### + +Note: I will tip $500 for an accurate JSON output. You will be penalized for an inaccurate JSON output. + +Provide JSON output now: +""" + llm_input = [Message(role="user", content=prompt)] + output = self(llm_input) + if not output: + raise Exception("No output generated for extract function input") + + output = output.replace("'", '"').strip().rstrip(",") + logger.info(f"LLM output: {output}") + function_inputs = json.loads(output) + if not isinstance(function_inputs, list): # Local LLMs return a single JSON object that isn't in an array sometimes. + function_inputs = [function_inputs] + logger.info(f"Function inputs: {function_inputs}") + if not self._is_valid_inputs(function_inputs, function_schemas): + raise ValueError("Invalid inputs") + return function_inputs + + def _extract_single_function_input(self, query: str, function_schemas: Dict[str, Any]) -> Dict[str, Any]: messages = [] system_prompt = "You are an intelligent AI. Given a command or request from the user, call the function to complete the request." messages.append(Message(role="system", content=system_prompt)) messages.append(Message(role="user", content=query)) - return self(messages=messages, function_schemas=function_schemas) + function_inputs = self(messages=messages, function_schemas=function_schemas) + + if not self._is_valid_inputs(function_inputs, function_schemas): + raise ValueError("Invalid inputs") + return function_inputs + + def extract_function_inputs( + self, query: str, function_schemas: List[Dict[str, Any]] + ) -> Dict: + + if len(function_schemas) == 0: + raise ValueError("No function schemas provided") + elif len(function_schemas) == 1: + logger.info("Extracting single function input...") + return self._extract_single_function_input(query, function_schemas) + else: + logger.info("Extracting multiple function inputs...") + return self._extract_multiple_function_inputs(query, function_schemas) + + def _is_valid_inputs( + self, inputs: List[Dict[str, Any]], function_schemas: List[Dict[str, Any]] + ) -> bool: + """Determine if the functions chosen by the LLM exist within the function_schemas, + and if the input arguments are valid for those functions.""" + try: + for input_dict in inputs: + # Check if 'function_name' and 'arguments' keys exist in each input dictionary + if "function_name" not in input_dict or "arguments" not in input_dict: + logger.error("Missing 'function_name' or 'arguments' in inputs") + return False + + function_name = input_dict["function_name"] + arguments = input_dict["arguments"] + + # Find the matching function schema based on function_name + matching_schema = next((schema['function'] for schema in function_schemas if schema['function']['name'] == function_name), None) + if not matching_schema: + logger.error(f"No matching function schema found for function name: {function_name}") + return False + + # Validate the inputs against the function schema + if not self._validate_single_function_inputs(arguments, matching_schema): + logger.error(f"Validation failed for function name: {function_name}") + return False + + return True + except Exception as e: + logger.error(f"Input validation error: {str(e)}") + return False + + def _validate_single_function_inputs(self, inputs: Dict[str, Any], function_schema: Dict[str, Any]) -> bool: + """Validate the extracted inputs against the function schema""" + try: + # Access the parameters and their properties from the function schema directly + parameters = function_schema['parameters']['properties'] + required_params = function_schema['parameters'].get('required', []) + + # Check if all required parameters are present in the inputs + for param_name in required_params: + if param_name not in inputs: + logger.error(f"Required input '{param_name}' missing from query") + return False + + # Check if the types of the inputs match the expected types (if type checking is needed) + for param_name, param_info in parameters.items(): + if param_name in inputs: + expected_type = param_info['type'] + # This is a simple type check, consider expanding it based on your needs + if expected_type == 'string' and not isinstance(inputs[param_name], str): + logger.error(f"Input type for '{param_name}' is not {expected_type}") + return False + + return True + except Exception as e: + logger.error(f"Single input validation error: {str(e)}") + return False + def get_schemas_openai(items: List[Callable]) -> List[Dict[str, Any]]: schemas = [] @@ -154,3 +330,4 @@ def get_schemas_openai(items: List[Callable]) -> List[Dict[str, Any]]: return schemas +