Skip to content
Snippets Groups Projects
Unverified Commit 307cc1b1 authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Fixed bug where non-dynamic routes for OpenAI LLM weren't working.

parent 6186e925
No related branches found
No related tags found
No related merge requests found
......@@ -58,22 +58,49 @@ class OpenAILLM(BaseLLM):
tools=tools, # type: ignore # MyPy expecting Iterable[ChatCompletionToolParam] | NotGiven, but dict is accepted by OpenAI.
)
tool_calls = completion.choices[0].message.tool_calls
if tool_calls is None:
raise ValueError("Invalid output, expected a tool call.")
if len(tool_calls) != 1:
raise ValueError(
"Invalid output, expected a single tool to be specified."
)
arguments = tool_calls[0].function.arguments
if arguments is None:
raise ValueError("Invalid output, expected arguments to be specified.")
output = str(arguments) # str to keep MyPy happy.
if function_schema:
tool_calls = completion.choices[0].message.tool_calls
if tool_calls is None:
raise ValueError("Invalid output, expected a tool call.")
if len(tool_calls) != 1:
raise ValueError(
"Invalid output, expected a single tool to be specified."
)
arguments = tool_calls[0].function.arguments
if arguments is None:
raise ValueError("Invalid output, expected arguments to be specified.")
output = str(arguments) # str to keep MyPy happy.
else:
content = completion.choices[0].message.content
if content is None:
raise ValueError("Invalid output, expected content.")
output = str(content) # str to keep MyPy happy.
return output
except Exception as e:
logger.error(f"LLM error: {e}")
raise Exception(f"LLM error: {e}") from e
def __call__(self, messages: List[Message]) -> str:
if self.client is None:
raise ValueError("OpenAI client is not initialized.")
try:
completion = self.client.chat.completions.create(
model=self.name,
messages=[m.to_openai() for m in messages],
temperature=self.temperature,
max_tokens=self.max_tokens,
)
output = completion.choices[0].message.content
if not output:
raise Exception("No output generated")
return output
except Exception as e:
logger.error(f"LLM error: {e}")
raise Exception(f"LLM error: {e}") from e
def extract_function_inputs(
self, query: str, function_schema: dict[str, Any]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment