import json
import os
from typing import Any, List, Optional

import openai

from semantic_router.llms import BaseLLM
from semantic_router.schema import Message
from semantic_router.utils.defaults import EncoderDefault
from semantic_router.utils.logger import logger


class OpenAILLM(BaseLLM):
    """LLM wrapper around the OpenAI chat completions API."""

    client: Optional[openai.OpenAI]
    temperature: Optional[float]
    max_tokens: Optional[int]

    def __init__(
        self,
        name: Optional[str] = None,
        openai_api_key: Optional[str] = None,
        temperature: float = 0.01,
        max_tokens: int = 200,
    ):
        if name is None:
            name = EncoderDefault.OPENAI.value["language_model"]
        super().__init__(name=name)
        api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
        if api_key is None:
            raise ValueError("OpenAI API key cannot be 'None'.")
        try:
            self.client = openai.OpenAI(api_key=api_key)
        except Exception as e:
            raise ValueError(
                f"OpenAI API client failed to initialize. Error: {e}"
            ) from e
        self.temperature = temperature
        self.max_tokens = max_tokens

    def __call__(
        self,
        messages: List[Message],
        openai_function_schema: Optional[dict[str, Any]] = None,
    ) -> str:
        """Send `messages` to the model and return its text output.

        If `openai_function_schema` is provided, it is passed as a tool and the
        arguments of the single expected tool call are returned as a JSON string.
        """
        if self.client is None:
            raise ValueError("OpenAI client is not initialized.")
        try:
            if openai_function_schema:
                tools = [openai_function_schema]
            else:
                tools = None
            completion = self.client.chat.completions.create(
                model=self.name,
                messages=[m.to_openai() for m in messages],
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                tools=tools,  # type: ignore # MyPy expects Iterable[ChatCompletionToolParam] | NotGiven, but a dict is accepted by OpenAI.
            )
            if openai_function_schema:
                tool_calls = completion.choices[0].message.tool_calls
                if tool_calls is None:
                    raise ValueError("Invalid output, expected a tool call.")
                if len(tool_calls) != 1:
                    raise ValueError(
                        "Invalid output, expected a single tool to be specified."
                    )
                arguments = tool_calls[0].function.arguments
                if arguments is None:
                    raise ValueError(
                        "Invalid output, expected arguments to be specified."
                    )
                output = str(arguments)  # str to keep MyPy happy.
            else:
                content = completion.choices[0].message.content
                if content is None:
                    raise ValueError("Invalid output, expected content.")
                output = str(content)  # str to keep MyPy happy.
            return output
        except Exception as e:
            logger.error(f"LLM error: {e}")
            raise Exception(f"LLM error: {e}") from e

    def extract_function_inputs_openai(
        self, query: str, openai_function_schema: dict[str, Any]
    ) -> dict:
        """Ask the model to extract function-call arguments for `query` as a dict."""
        messages = []
        system_prompt = "You are an intelligent AI. Given a command or request from the user, call the function to complete the request."
        messages.append(Message(role="system", content=system_prompt))
        messages.append(Message(role="user", content=query))
        function_inputs_str = self(
            messages=messages, openai_function_schema=openai_function_schema
        )
        function_inputs = json.loads(function_inputs_str)
        return function_inputs
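

# --- Usage sketch (illustration only, not part of the library) --------------
# A minimal example of how this class might be exercised, assuming
# OPENAI_API_KEY is set in the environment. The `get_time` schema below is a
# hypothetical tool definition in the OpenAI tools format; substitute the
# schema generated for your own function.
if __name__ == "__main__":
    example_schema = {
        "type": "function",
        "function": {
            "name": "get_time",
            "description": "Return the current time for a given timezone.",
            "parameters": {
                "type": "object",
                "properties": {
                    "timezone": {
                        "type": "string",
                        "description": "An IANA timezone name, e.g. 'Europe/London'.",
                    }
                },
                "required": ["timezone"],
            },
        },
    }

    llm = OpenAILLM()  # falls back to the default model from EncoderDefault.OPENAI
    inputs = llm.extract_function_inputs_openai(
        query="What time is it in London right now?",
        openai_function_schema=example_schema,
    )
    print(inputs)  # expected shape: {"timezone": "Europe/London"}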