Skip to content
Snippets Groups Projects
Unverified Commit 17230b39 authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Revert "Different methods of obtaining arguments for OpenAI"

This reverts commit 47a0fa3d.
parent 47a0fa3d
No related branches found
No related tags found
No related merge requests found
......@@ -2,6 +2,7 @@ import os
from typing import List, Optional, Any, Callable, Dict
import openai
from openai._types import NotGiven
from semantic_router.llms import BaseLLM
from semantic_router.schema import Message
......@@ -63,8 +64,8 @@ class OpenAILLM(BaseLLM):
if function_schemas:
tools = function_schemas
else:
tools = None
tools = NotGiven
completion = self.client.chat.completions.create(
model=self.name,
messages=[m.to_openai() for m in messages],
......@@ -95,192 +96,15 @@ class OpenAILLM(BaseLLM):
except Exception as e:
logger.error(f"LLM error: {e}")
raise Exception(f"LLM error: {e}") from e
def _extract_multiple_function_inputs(self, query: str, function_schemas: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
prompt = f"""
You are an accurate and reliable computer program that only outputs valid JSON.
Your task is to:
1) Pick the most relevant Python function schema(s) from FUNCTION_SCHEMAS below, based on the input QUERY. If only one schema is provided, choose that. If multiple schemas are relevant, output a list of JSON objects for each.
2) Output JSON representing the input arguments of the chosen function schema(s), including the function name, with argument values determined by information in the QUERY.
These are the Python functions' schema:
### FUNCTION_SCHEMAS Start ###
{json.dumps([schema['function'] for schema in function_schemas], indent=4)}
### FUNCTION_SCHEMAS End ###
This is the input query.
### QUERY Start ###
{query}
### QUERY End ###
The arguments that you need to provide values for, together with their datatypes, are stated in the "parameters" in the FUNCTION_SCHEMA.
The values these arguments must take are made clear by the QUERY.
Use the FUNCTION_SCHEMA "description" too, as this might provide helpful clues about the arguments and their values.
Include the function name in your JSON output.
Return only JSON, stating the function name and the argument names with their corresponding values.
### FORMATTING_INSTRUCTIONS Start ###
Return a response in valid JSON format. Do not return any other explanation or text, just the JSON.
The JSON output should always be an array of JSON objects. If only one function is relevant, return an array with a single JSON object.
Each JSON object should include a key 'function_name' with the value being the name of the function.
Under the key 'arguments', include a nested JSON object where the keys are the names of the arguments and the values are the values those arguments should take.
### FORMATTING_INSTRUCTIONS End ###
### EXAMPLE Start ###
=== EXAMPLE_INPUT_QUERY Start ===
"What is the temperature in Hawaii and New York right now in Celsius, and what is the humidity in Hawaii?"
=== EXAMPLE_INPUT_QUERY End ===
=== EXAMPLE_INPUT_SCHEMA Start ===
{{
"name": "get_temperature",
"description": "Useful to get the temperature in a specific location",
"parameters": {{
"type": "object",
"properties": {{
"location": {{
"type": "string",
"description": "The location to get the temperature from."
}},
"degree": {{
"type": "string",
"description": "The degree type, e.g., Celsius or Fahrenheit."
}}
}},
"required": ["location", "degree"]
}}
}}
=== EXAMPLE_INPUT_SCHEMA End ===
=== EXAMPLE_OUTPUT Start ===
[
{{
"function_name": "get_temperature",
"arguments": {{
"location": "Hawaii",
"degree": "Celsius"
}}
}},
{{
"function_name": "get_temperature",
"arguments": {{
"location": "New York",
"degree": "Celsius"
}}
}},
{{
"function_name": "get_humidity",
"arguments": {{
"location": "Hawaii"
}}
}}
]
=== EXAMPLE_OUTPUT End ===
### EXAMPLE End ###
Note: I will tip $500 for an accurate JSON output. You will be penalized for an inaccurate JSON output.
Provide JSON output now:
"""
llm_input = [Message(role="user", content=prompt)]
output = self(llm_input)
if not output:
raise Exception("No output generated for extract function input")
output = output.replace("'", '"').strip().rstrip(",")
logger.info(f"LLM output: {output}")
function_inputs = json.loads(output)
if not isinstance(function_inputs, list): # Local LLMs return a single JSON object that isn't in an array sometimes.
function_inputs = [function_inputs]
logger.info(f"Function inputs: {function_inputs}")
if not self._is_valid_inputs(function_inputs, function_schemas):
raise ValueError("Invalid inputs")
return function_inputs
def _extract_single_function_input(self, query: str, function_schemas: Dict[str, Any]) -> Dict[str, Any]:
def extract_function_inputs(
self, query: str, function_schemas: List[Dict[str, Any]]
) -> Dict:
messages = []
system_prompt = "You are an intelligent AI. Given a command or request from the user, call the function to complete the request."
messages.append(Message(role="system", content=system_prompt))
messages.append(Message(role="user", content=query))
function_inputs = self(messages=messages, function_schemas=function_schemas)
if not self._is_valid_inputs(function_inputs, function_schemas):
raise ValueError("Invalid inputs")
return function_inputs
def extract_function_inputs(
self, query: str, function_schemas: List[Dict[str, Any]]
) -> Dict:
if len(function_schemas) == 0:
raise ValueError("No function schemas provided")
elif len(function_schemas) == 1:
logger.info("Extracting single function input...")
return self._extract_single_function_input(query, function_schemas)
else:
logger.info("Extracting multiple function inputs...")
return self._extract_multiple_function_inputs(query, function_schemas)
def _is_valid_inputs(
self, inputs: List[Dict[str, Any]], function_schemas: List[Dict[str, Any]]
) -> bool:
"""Determine if the functions chosen by the LLM exist within the function_schemas,
and if the input arguments are valid for those functions."""
try:
for input_dict in inputs:
# Check if 'function_name' and 'arguments' keys exist in each input dictionary
if "function_name" not in input_dict or "arguments" not in input_dict:
logger.error("Missing 'function_name' or 'arguments' in inputs")
return False
function_name = input_dict["function_name"]
arguments = input_dict["arguments"]
# Find the matching function schema based on function_name
matching_schema = next((schema['function'] for schema in function_schemas if schema['function']['name'] == function_name), None)
if not matching_schema:
logger.error(f"No matching function schema found for function name: {function_name}")
return False
# Validate the inputs against the function schema
if not self._validate_single_function_inputs(arguments, matching_schema):
logger.error(f"Validation failed for function name: {function_name}")
return False
return True
except Exception as e:
logger.error(f"Input validation error: {str(e)}")
return False
def _validate_single_function_inputs(self, inputs: Dict[str, Any], function_schema: Dict[str, Any]) -> bool:
"""Validate the extracted inputs against the function schema"""
try:
# Access the parameters and their properties from the function schema directly
parameters = function_schema['parameters']['properties']
required_params = function_schema['parameters'].get('required', [])
# Check if all required parameters are present in the inputs
for param_name in required_params:
if param_name not in inputs:
logger.error(f"Required input '{param_name}' missing from query")
return False
# Check if the types of the inputs match the expected types (if type checking is needed)
for param_name, param_info in parameters.items():
if param_name in inputs:
expected_type = param_info['type']
# This is a simple type check, consider expanding it based on your needs
if expected_type == 'string' and not isinstance(inputs[param_name], str):
logger.error(f"Input type for '{param_name}' is not {expected_type}")
return False
return True
except Exception as e:
logger.error(f"Single input validation error: {str(e)}")
return False
return self(messages=messages, function_schemas=function_schemas)
def get_schemas_openai(items: List[Callable]) -> List[Dict[str, Any]]:
schemas = []
......@@ -330,4 +154,3 @@ def get_schemas_openai(items: List[Callable]) -> List[Dict[str, Any]]:
return schemas
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment