Skip to content
Snippets Groups Projects
Commit 548bd403 authored by Ismail Ashraq's avatar Ismail Ashraq
Browse files

add tests for llms

parent 5141474a
No related branches found
No related tags found
No related merge requests found
from pydantic import BaseModel
from semantic_router.schema import Message
class BaseLLM(BaseModel):
......@@ -7,5 +8,5 @@ class BaseLLM(BaseModel):
class Config:
arbitrary_types_allowed = True
def __call__(self, prompt) -> str | None:
def __call__(self, messages: list[Message]) -> str | None:
raise NotImplementedError("Subclasses must implement this method")
import pytest
from semantic_router.llms import BaseLLM
class TestBaseLLM:
@pytest.fixture
def base_llm(self):
return BaseLLM(name="TestLLM")
def test_base_llm_initialization(self, base_llm):
assert base_llm.name == "TestLLM", "Initialization of name failed"
def test_base_llm_call_method_not_implemented(self, base_llm):
with pytest.raises(NotImplementedError):
base_llm("test")
import pytest
from semantic_router.llms import Cohere
from semantic_router.schema import Message
@pytest.fixture
def cohere_llm(mocker):
mocker.patch("cohere.Client")
return Cohere(cohere_api_key="test_api_key")
class TestCohereLLM:
def test_initialization_with_api_key(self, cohere_llm):
assert cohere_llm.client is not None, "Client should be initialized"
assert cohere_llm.name == "command", "Default name not set correctly"
def test_initialization_without_api_key(self, mocker, monkeypatch):
monkeypatch.delenv("COHERE_API_KEY", raising=False)
mocker.patch("cohere.Client")
with pytest.raises(ValueError):
Cohere()
def test_call_method(self, cohere_llm, mocker):
mock_llm = mocker.MagicMock()
mock_llm.text = "test"
cohere_llm.client.chat.return_value = mock_llm
llm_input = [Message(role="user", content="test")]
result = cohere_llm(llm_input)
assert isinstance(result, str), "Result should be a str"
cohere_llm.client.chat.assert_called_once()
def test_raises_value_error_if_cohere_client_fails_to_initialize(self, mocker):
mocker.patch(
"cohere.Client", side_effect=Exception("Failed to initialize client")
)
with pytest.raises(ValueError):
Cohere(cohere_api_key="test_api_key")
def test_raises_value_error_if_cohere_client_is_not_initialized(self, mocker):
mocker.patch("cohere.Client", return_value=None)
llm = Cohere(cohere_api_key="test_api_key")
with pytest.raises(ValueError):
llm("test")
def test_call_method_raises_error_on_api_failure(self, cohere_llm, mocker):
mocker.patch.object(
cohere_llm.client, "__call__", side_effect=Exception("API call failed")
)
with pytest.raises(ValueError):
cohere_llm("test")
import pytest
from semantic_router.llms import OpenAI
from semantic_router.schema import Message
@pytest.fixture
def openai_llm(mocker):
mocker.patch("openai.Client")
return OpenAI(openai_api_key="test_api_key")
class TestOpenAILLM:
def test_openai_llm_init_with_api_key(self, openai_llm):
assert openai_llm.client is not None, "Client should be initialized"
assert openai_llm.name == "gpt-3.5-turbo", "Default name not set correctly"
def test_openai_llm_init_success(self, mocker):
mocker.patch("os.getenv", return_value="fake-api-key")
llm = OpenAI()
assert llm.client is not None
def test_openai_llm_init_without_api_key(self, mocker):
mocker.patch("os.getenv", return_value=None)
with pytest.raises(ValueError) as _:
OpenAI()
def test_openai_llm_call_uninitialized_client(self, openai_llm):
# Set the client to None to simulate an uninitialized client
openai_llm.client = None
with pytest.raises(ValueError) as e:
llm_input = [Message(role="user", content="test")]
openai_llm(llm_input)
assert "OpenAI client is not initialized." in str(e.value)
def test_openai_llm_init_exception(self, mocker):
mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch("openai.OpenAI", side_effect=Exception("Initialization error"))
with pytest.raises(ValueError) as e:
OpenAI()
assert (
"OpenAI API client failed to initialize. Error: Initialization error"
in str(e.value)
)
def test_openai_llm_call_success(self, openai_llm, mocker):
mock_completion = mocker.MagicMock()
mock_completion.choices[0].message.content = "test"
mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch.object(
openai_llm.client.chat.completions, "create", return_value=mock_completion
)
llm_input = [Message(role="user", content="test")]
output = openai_llm(llm_input)
assert output == "test"
import pytest
from semantic_router.llms import OpenRouter
from semantic_router.schema import Message
@pytest.fixture
def openrouter_llm(mocker):
mocker.patch("openai.Client")
return OpenRouter(openrouter_api_key="test_api_key")
class TestOpenRouterLLM:
def test_openrouter_llm_init_with_api_key(self, openrouter_llm):
assert openrouter_llm.client is not None, "Client should be initialized"
assert (
openrouter_llm.name == "mistralai/mistral-7b-instruct"
), "Default name not set correctly"
def test_openrouter_llm_init_success(self, mocker):
mocker.patch("os.getenv", return_value="fake-api-key")
llm = OpenRouter()
assert llm.client is not None
def test_openrouter_llm_init_without_api_key(self, mocker):
mocker.patch("os.getenv", return_value=None)
with pytest.raises(ValueError) as _:
OpenRouter()
def test_openrouter_llm_call_uninitialized_client(self, openrouter_llm):
# Set the client to None to simulate an uninitialized client
openrouter_llm.client = None
with pytest.raises(ValueError) as e:
llm_input = [Message(role="user", content="test")]
openrouter_llm(llm_input)
assert "OpenRouter client is not initialized." in str(e.value)
def test_openrouter_llm_init_exception(self, mocker):
mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch("openai.OpenAI", side_effect=Exception("Initialization error"))
with pytest.raises(ValueError) as e:
OpenRouter()
assert (
"OpenRouter API client failed to initialize. Error: Initialization error"
in str(e.value)
)
def test_openrouter_llm_call_success(self, openrouter_llm, mocker):
mock_completion = mocker.MagicMock()
mock_completion.choices[0].message.content = "test"
mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch.object(
openrouter_llm.client.chat.completions,
"create",
return_value=mock_completion,
)
llm_input = [Message(role="user", content="test")]
output = openrouter_llm(llm_input)
assert output == "test"
from unittest.mock import patch # , AsyncMock
# import pytest
import pytest
from semantic_router.llms import BaseLLM
from semantic_router.route import Route, is_valid
......@@ -61,6 +61,21 @@ class MockLLM(BaseLLM):
class TestRoute:
def test_value_error_in_route_call(self):
function_schema = {"name": "test_function", "type": "function"}
route = Route(
name="test_function",
utterances=["utterance1", "utterance2"],
function_schema=function_schema,
)
with pytest.raises(
ValueError,
match="LLM is required for dynamic routes. Please ensure the 'llm' is set.",
):
route("test_query")
def test_generate_dynamic_route(self):
mock_llm = MockLLM(name="test")
function_schema = {"name": "test_function", "type": "function"}
......
import pytest
from pydantic import ValidationError
from semantic_router.schema import (
CohereEncoder,
Encoder,
EncoderType,
Message,
OpenAIEncoder,
)
......@@ -38,3 +39,27 @@ class TestEncoderDataclass:
encoder = Encoder(type="openai", name="test-engine")
result = encoder(["test"])
assert result == [0.1, 0.2, 0.3]
class TestMessageDataclass:
def test_message_creation(self):
message = Message(role="user", content="Hello!")
assert message.role == "user"
assert message.content == "Hello!"
with pytest.raises(ValidationError):
Message(user_role="invalid_role", message="Hello!")
def test_message_to_openai(self):
message = Message(role="user", content="Hello!")
openai_format = message.to_openai()
assert openai_format == {"role": "user", "content": "Hello!"}
message = Message(role="invalid_role", content="Hello!")
with pytest.raises(ValueError):
message.to_openai()
def test_message_to_cohere(self):
message = Message(role="user", content="Hello!")
cohere_format = message.to_cohere()
assert cohere_format == {"role": "user", "message": "Hello!"}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment