Skip to content
Snippets Groups Projects
Commit fe135a51 authored by jamescalam's avatar jamescalam
Browse files

fix: types for openai llm and encoders

parent 095fa2c5
Branches
Tags
No related merge requests found
import os
from typing import List, Optional, Any, Callable, Dict, Union
from pydantic import PrivateAttr
import openai
from openai._types import NotGiven, NOT_GIVEN
......@@ -21,8 +22,8 @@ from openai.types.chat.chat_completion_message_tool_call import (
class OpenAILLM(BaseLLM):
client: Optional[openai.OpenAI]
async_client: Optional[openai.AsyncOpenAI]
_client: Optional[openai.OpenAI] = PrivateAttr(default=None)
_async_client: Optional[openai.AsyncOpenAI] = PrivateAttr(default=None)
def __init__(
self,
......@@ -38,8 +39,8 @@ class OpenAILLM(BaseLLM):
if api_key is None:
raise ValueError("OpenAI API key cannot be 'None'.")
try:
self.async_client = openai.AsyncOpenAI(api_key=api_key)
self.client = openai.OpenAI(api_key=api_key)
self._async_client = openai.AsyncOpenAI(api_key=api_key)
self._client = openai.OpenAI(api_key=api_key)
except Exception as e:
raise ValueError(
f"OpenAI API client failed to initialize. Error: {e}"
......@@ -86,14 +87,14 @@ class OpenAILLM(BaseLLM):
messages: List[Message],
function_schemas: Optional[List[Dict[str, Any]]] = None,
) -> str:
if self.client is None:
if self._client is None:
raise ValueError("OpenAI client is not initialized.")
try:
tools: Union[List[Dict[str, Any]], NotGiven] = (
function_schemas if function_schemas else NOT_GIVEN
)
completion = self.client.chat.completions.create(
completion = self._client.chat.completions.create(
model=self.name,
messages=[m.to_openai() for m in messages],
temperature=self.temperature,
......@@ -130,14 +131,14 @@ class OpenAILLM(BaseLLM):
messages: List[Message],
function_schemas: Optional[List[Dict[str, Any]]] = None,
) -> str:
if self.async_client is None:
if self._async_client is None:
raise ValueError("OpenAI async_client is not initialized.")
try:
tools: Union[List[Dict[str, Any]], NotGiven] = (
function_schemas if function_schemas is not None else NOT_GIVEN
)
completion = await self.async_client.chat.completions.create(
completion = await self._async_client.chat.completions.create(
model=self.name,
messages=[m.to_openai() for m in messages],
temperature=self.temperature,
......
......@@ -30,7 +30,7 @@ class TestOpenAIEncoder:
side_effect = ["fake-model-name", "fake-api-key", "fake-org-id"]
mocker.patch("os.getenv", side_effect=side_effect)
encoder = OpenAIEncoder()
assert encoder.client is not None
assert encoder._client is not None
def test_openai_encoder_init_no_api_key(self, mocker):
mocker.patch("os.getenv", return_value=None)
......@@ -39,7 +39,7 @@ class TestOpenAIEncoder:
def test_openai_encoder_call_uninitialized_client(self, openai_encoder):
    """Calling the encoder with no underlying client must raise ValueError."""
    # Simulate an uninitialized client via the private attribute the encoder
    # actually checks; the old public `client` field was removed in the
    # PrivateAttr refactor, so assigning it here would be a stale diff artifact.
    openai_encoder._client = None
    with pytest.raises(ValueError) as e:
        openai_encoder(["test document"])
    assert "OpenAI client is not initialized." in str(e.value)
......@@ -74,7 +74,7 @@ class TestOpenAIEncoder:
responses = [OpenAIError("OpenAI error"), mock_response]
mocker.patch.object(
openai_encoder.client.embeddings, "create", side_effect=responses
openai_encoder._client.embeddings, "create", side_effect=responses
)
with patch("semantic_router.encoders.openai.sleep", return_value=None):
embeddings = openai_encoder(["test document"])
......@@ -84,7 +84,7 @@ class TestOpenAIEncoder:
mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch("time.sleep", return_value=None) # To speed up the test
mocker.patch.object(
openai_encoder.client.embeddings,
openai_encoder._client.embeddings,
"create",
side_effect=Exception("Non-OpenAIError"),
)
......@@ -114,7 +114,7 @@ class TestOpenAIEncoder:
responses = [OpenAIError("OpenAI error"), mock_response]
mocker.patch.object(
openai_encoder.client.embeddings, "create", side_effect=responses
openai_encoder._client.embeddings, "create", side_effect=responses
)
with patch("semantic_router.encoders.openai.sleep", return_value=None):
embeddings = openai_encoder(["test document"])
......
......@@ -12,13 +12,13 @@ def azure_openai_llm(mocker):
class TestOpenAILLM:
def test_azure_openai_llm_init_with_api_key(self, azure_openai_llm):
    """Fixture-provided Azure LLM should expose an initialized private client."""
    # The client moved to the private `_client` attribute (pydantic PrivateAttr);
    # the duplicated assert against the removed public `client` field is dropped.
    assert azure_openai_llm._client is not None, "Client should be initialized"
    assert azure_openai_llm.name == "gpt-4o", "Default name not set correctly"
def test_azure_openai_llm_init_success(self, mocker):
    """Constructing AzureOpenAILLM with an API key initializes `_client`."""
    mocker.patch("os.getenv", return_value="fake-api-key")
    llm = AzureOpenAILLM()
    # Only the private `_client` attribute exists after the PrivateAttr
    # refactor; the stale assert on the removed public `client` is dropped.
    assert llm._client is not None
def test_azure_openai_llm_init_without_api_key(self, mocker):
mocker.patch("os.getenv", return_value=None)
......@@ -44,7 +44,7 @@ class TestOpenAILLM:
def test_azure_openai_llm_call_uninitialized_client(self, azure_openai_llm):
    """Calling the Azure LLM with no underlying client must raise ValueError."""
    # Set the private client to None to simulate an uninitialized client; the
    # old public `client` field no longer exists, so only `_client` is cleared.
    azure_openai_llm._client = None
    with pytest.raises(ValueError) as e:
        llm_input = [Message(role="user", content="test")]
        azure_openai_llm(llm_input)
......@@ -83,7 +83,7 @@ class TestOpenAILLM:
mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch.object(
azure_openai_llm.client.chat.completions,
azure_openai_llm._client.chat.completions,
"create",
return_value=mock_completion,
)
......
......@@ -42,13 +42,13 @@ example_function_schema = {
class TestOpenAILLM:
def test_openai_llm_init_with_api_key(self, openai_llm):
    """Fixture-provided LLM should expose an initialized private client."""
    # Assert only against the private `_client` attribute; the duplicated
    # assert on the removed public `client` field is a stale diff artifact.
    assert openai_llm._client is not None, "Client should be initialized"
    assert openai_llm.name == "gpt-4o", "Default name not set correctly"
def test_openai_llm_init_success(self, mocker):
    """Constructing OpenAILLM with an API key initializes `_client`."""
    mocker.patch("os.getenv", return_value="fake-api-key")
    llm = OpenAILLM()
    # Only the private `_client` attribute exists after the PrivateAttr
    # refactor; the stale assert on the removed public `client` is dropped.
    assert llm._client is not None
def test_openai_llm_init_without_api_key(self, mocker):
mocker.patch("os.getenv", return_value=None)
......@@ -57,7 +57,7 @@ class TestOpenAILLM:
def test_openai_llm_call_uninitialized_client(self, openai_llm):
    """Calling the LLM with no underlying client must raise ValueError."""
    # Set the private client to None to simulate an uninitialized client; the
    # old public `client` field no longer exists, so only `_client` is cleared.
    openai_llm._client = None
    with pytest.raises(ValueError) as e:
        llm_input = [Message(role="user", content="test")]
        openai_llm(llm_input)
......@@ -79,7 +79,7 @@ class TestOpenAILLM:
mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch.object(
openai_llm.client.chat.completions, "create", return_value=mock_completion
openai_llm._client.chat.completions, "create", return_value=mock_completion
)
llm_input = [Message(role="user", content="test")]
output = openai_llm(llm_input)
......@@ -127,7 +127,7 @@ class TestOpenAILLM:
# mocker.MagicMock(function=mocker.MagicMock(arguments="result"))
# ]
# mocker.patch.object(
# openai_llm.client.chat.completions, "create", return_value=mock_completion
# openai_llm._client.chat.completions, "create", return_value=mock_completion
# )
# llm_input = [Message(role="user", content="test")]
# function_schemas = [{"type": "function", "name": "sample_function"}]
......@@ -145,7 +145,7 @@ class TestOpenAILLM:
mock_completion.choices[0].message.tool_calls = [mock_tool_call]
mocker.patch.object(
openai_llm.client.chat.completions, "create", return_value=mock_completion
openai_llm._client.chat.completions, "create", return_value=mock_completion
)
llm_input = [Message(role="user", content="test")]
......@@ -160,7 +160,7 @@ class TestOpenAILLM:
mock_completion = mocker.MagicMock()
mock_completion.choices[0].message.tool_calls = None
mocker.patch.object(
openai_llm.client.chat.completions, "create", return_value=mock_completion
openai_llm._client.chat.completions, "create", return_value=mock_completion
)
llm_input = [Message(role="user", content="test")]
function_schemas = [{"type": "function", "name": "sample_function"}]
......@@ -180,7 +180,7 @@ class TestOpenAILLM:
mocker.MagicMock(function=mocker.MagicMock(arguments=None))
]
mocker.patch.object(
openai_llm.client.chat.completions, "create", return_value=mock_completion
openai_llm._client.chat.completions, "create", return_value=mock_completion
)
llm_input = [Message(role="user", content="test")]
function_schemas = [{"type": "function", "name": "sample_function"}]
......@@ -230,7 +230,7 @@ class TestOpenAILLM:
# Patching the completions.create method to return the mocked completion
mocker.patch.object(
openai_llm.client.chat.completions, "create", return_value=mock_completion
openai_llm._client.chat.completions, "create", return_value=mock_completion
)
# Input message list
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment