Skip to content
Snippets Groups Projects
Unverified Commit 0ca020bc authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Linting.

parent 6ed79ba2
No related branches found
No related tags found
No related merge requests found
...@@ -63,16 +63,20 @@ class OpenAILLM(BaseLLM): ...@@ -63,16 +63,20 @@ class OpenAILLM(BaseLLM):
if tool_calls is None: if tool_calls is None:
raise ValueError("Invalid output, expected a tool call.") raise ValueError("Invalid output, expected a tool call.")
if len(tool_calls) != 1: if len(tool_calls) != 1:
raise ValueError("Invalid output, expected a single tool to be specified.") raise ValueError(
"Invalid output, expected a single tool to be specified."
)
arguments = tool_calls[0].function.arguments arguments = tool_calls[0].function.arguments
if arguments is None: if arguments is None:
raise ValueError("Invalid output, expected arguments to be specified.") raise ValueError(
output = str(arguments) # str to keep MyPy happy. "Invalid output, expected arguments to be specified."
)
output = str(arguments) # str to keep MyPy happy.
else: else:
content = completion.choices[0].message.content content = completion.choices[0].message.content
if content is None: if content is None:
raise ValueError("Invalid output, expected content.") raise ValueError("Invalid output, expected content.")
output = str(content) # str to keep MyPy happy. output = str(content) # str to keep MyPy happy.
return output return output
except Exception as e: except Exception as e:
logger.error(f"LLM error: {e}") logger.error(f"LLM error: {e}")
...@@ -85,6 +89,8 @@ class OpenAILLM(BaseLLM): ...@@ -85,6 +89,8 @@ class OpenAILLM(BaseLLM):
system_prompt = "You are an intelligent AI. Given a command or request from the user, call the function to complete the request." system_prompt = "You are an intelligent AI. Given a command or request from the user, call the function to complete the request."
messages.append(Message(role="system", content=system_prompt)) messages.append(Message(role="system", content=system_prompt))
messages.append(Message(role="user", content=query)) messages.append(Message(role="user", content=query))
function_inputs_str = self(messages=messages, openai_function_schema=openai_function_schema) function_inputs_str = self(
messages=messages, openai_function_schema=openai_function_schema
)
function_inputs = json.loads(function_inputs_str) function_inputs = json.loads(function_inputs_str)
return function_inputs return function_inputs
...@@ -90,13 +90,13 @@ def get_schema_openai_func_calling(item: Callable) -> Dict[str, Any]: ...@@ -90,13 +90,13 @@ def get_schema_openai_func_calling(item: Callable) -> Dict[str, Any]:
if match: if match:
param_description = match.group(1).strip() param_description = match.group(1).strip()
schema["function"]["parameters"]["properties"][param_name] = { schema["function"]["parameters"]["properties"][param_name] = { # type: ignore
"type": convert_param_type_to_json_type(param_type), "type": convert_param_type_to_json_type(param_type),
"description": param_description, "description": param_description,
} }
if param_required: if param_required:
schema["function"]["parameters"]["required"].append(param_name) schema["function"]["parameters"]["required"].append(param_name) # type: ignore
return schema return schema
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment