From 307cc1b133711939c6bc3f373aa458a746b992d1 Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood <siraj@aurelio.ai> Date: Tue, 30 Apr 2024 02:19:53 +0400 Subject: [PATCH] Fixed bug where non-dynamic routes for OpenAI LLM weren't working. --- semantic_router/llms/openai.py | 49 ++++++++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 11 deletions(-) diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py index dcce3b45..920d954f 100644 --- a/semantic_router/llms/openai.py +++ b/semantic_router/llms/openai.py @@ -58,22 +58,49 @@ class OpenAILLM(BaseLLM): tools=tools, # type: ignore # MyPy expecting Iterable[ChatCompletionToolParam] | NotGiven, but dict is accepted by OpenAI. ) - tool_calls = completion.choices[0].message.tool_calls - if tool_calls is None: - raise ValueError("Invalid output, expected a tool call.") - if len(tool_calls) != 1: - raise ValueError( - "Invalid output, expected a single tool to be specified." - ) - arguments = tool_calls[0].function.arguments - if arguments is None: - raise ValueError("Invalid output, expected arguments to be specified.") - output = str(arguments) # str to keep MyPy happy. + if function_schema: + tool_calls = completion.choices[0].message.tool_calls + if tool_calls is None: + raise ValueError("Invalid output, expected a tool call.") + if len(tool_calls) != 1: + raise ValueError( + "Invalid output, expected a single tool to be specified." + ) + arguments = tool_calls[0].function.arguments + if arguments is None: + raise ValueError("Invalid output, expected arguments to be specified.") + output = str(arguments) # str to keep MyPy happy. + else: + content = completion.choices[0].message.content + if content is None: + raise ValueError("Invalid output, expected content.") + output = str(content) # str to keep MyPy happy. + return output except Exception as e: logger.error(f"LLM error: {e}") raise Exception(f"LLM error: {e}") from e + + def __call__(self, messages: List[Message]) -> str: + if self.client is None: + raise ValueError("OpenAI client is not initialized.") + try: + completion = self.client.chat.completions.create( + model=self.name, + messages=[m.to_openai() for m in messages], + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + + output = completion.choices[0].message.content + + if not output: + raise Exception("No output generated") + return output + except Exception as e: + logger.error(f"LLM error: {e}") + raise Exception(f"LLM error: {e}") from e def extract_function_inputs( self, query: str, function_schema: dict[str, Any] -- GitLab