diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py index 4e88030b16ae62b2e5da8588fd001de3ae71a2ce..2b41f17a06b3bbe15ba8a3343792e36ff545852e 100644 --- a/semantic_router/llms/openai.py +++ b/semantic_router/llms/openai.py @@ -63,16 +63,20 @@ class OpenAILLM(BaseLLM): if tool_calls is None: raise ValueError("Invalid output, expected a tool call.") if len(tool_calls) != 1: - raise ValueError("Invalid output, expected a single tool to be specified.") + raise ValueError( + "Invalid output, expected a single tool to be specified." + ) arguments = tool_calls[0].function.arguments if arguments is None: - raise ValueError("Invalid output, expected arguments to be specified.") - output = str(arguments) # str to keep MyPy happy. + raise ValueError( + "Invalid output, expected arguments to be specified." + ) + output = str(arguments) # str to keep MyPy happy. else: content = completion.choices[0].message.content if content is None: raise ValueError("Invalid output, expected content.") - output = str(content) # str to keep MyPy happy. + output = str(content) # str to keep MyPy happy. return output except Exception as e: logger.error(f"LLM error: {e}") @@ -85,6 +89,8 @@ class OpenAILLM(BaseLLM): system_prompt = "You are an intelligent AI. Given a command or request from the user, call the function to complete the request." messages.append(Message(role="system", content=system_prompt)) messages.append(Message(role="user", content=query)) - function_inputs_str = self(messages=messages, openai_function_schema=openai_function_schema) + function_inputs_str = self( + messages=messages, openai_function_schema=openai_function_schema + ) function_inputs = json.loads(function_inputs_str) return function_inputs diff --git a/semantic_router/utils/function_call.py b/semantic_router/utils/function_call.py index 562da07b35018dcb8d6621f164578dbfa477afb7..cddbff3b2f4be7970b3094e3feec56a05b8ad8cf 100644 --- a/semantic_router/utils/function_call.py +++ b/semantic_router/utils/function_call.py @@ -90,13 +90,13 @@ def get_schema_openai_func_calling(item: Callable) -> Dict[str, Any]: if match: param_description = match.group(1).strip() - schema["function"]["parameters"]["properties"][param_name] = { + schema["function"]["parameters"]["properties"][param_name] = { # type: ignore "type": convert_param_type_to_json_type(param_type), "description": param_description, } if param_required: - schema["function"]["parameters"]["required"].append(param_name) + schema["function"]["parameters"]["required"].append(param_name) # type: ignore return schema