From 548bd40393b8e4483d9092359051c699558c709f Mon Sep 17 00:00:00 2001 From: Ismail Ashraq <issey1455@gmail.com> Date: Sun, 7 Jan 2024 03:51:54 +0500 Subject: [PATCH] add tests for llms --- semantic_router/llms/base.py | 3 +- tests/unit/llms/test_llm_base.py | 16 +++++++ tests/unit/llms/test_llm_cohere.py | 52 +++++++++++++++++++++++ tests/unit/llms/test_llm_openai.py | 55 ++++++++++++++++++++++++ tests/unit/llms/test_llm_openrouter.py | 59 ++++++++++++++++++++++++++ tests/unit/test_route.py | 17 +++++++- tests/unit/test_schema.py | 27 +++++++++++- 7 files changed, 226 insertions(+), 3 deletions(-) create mode 100644 tests/unit/llms/test_llm_base.py create mode 100644 tests/unit/llms/test_llm_cohere.py create mode 100644 tests/unit/llms/test_llm_openai.py create mode 100644 tests/unit/llms/test_llm_openrouter.py diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py index 2a1a038e..dd8a0afa 100644 --- a/semantic_router/llms/base.py +++ b/semantic_router/llms/base.py @@ -1,4 +1,5 @@ from pydantic import BaseModel +from semantic_router.schema import Message class BaseLLM(BaseModel): @@ -7,5 +8,5 @@ class BaseLLM(BaseModel): class Config: arbitrary_types_allowed = True - def __call__(self, prompt) -> str | None: + def __call__(self, messages: list[Message]) -> str | None: raise NotImplementedError("Subclasses must implement this method") diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py new file mode 100644 index 00000000..df78d8f5 --- /dev/null +++ b/tests/unit/llms/test_llm_base.py @@ -0,0 +1,16 @@ +import pytest + +from semantic_router.llms import BaseLLM + + +class TestBaseLLM: + @pytest.fixture + def base_llm(self): + return BaseLLM(name="TestLLM") + + def test_base_llm_initialization(self, base_llm): + assert base_llm.name == "TestLLM", "Initialization of name failed" + + def test_base_llm_call_method_not_implemented(self, base_llm): + with pytest.raises(NotImplementedError): + base_llm("test") diff --git a/tests/unit/llms/test_llm_cohere.py b/tests/unit/llms/test_llm_cohere.py new file mode 100644 index 00000000..32443f04 --- /dev/null +++ b/tests/unit/llms/test_llm_cohere.py @@ -0,0 +1,52 @@ +import pytest + +from semantic_router.llms import Cohere +from semantic_router.schema import Message + + +@pytest.fixture +def cohere_llm(mocker): + mocker.patch("cohere.Client") + return Cohere(cohere_api_key="test_api_key") + + +class TestCohereLLM: + def test_initialization_with_api_key(self, cohere_llm): + assert cohere_llm.client is not None, "Client should be initialized" + assert cohere_llm.name == "command", "Default name not set correctly" + + def test_initialization_without_api_key(self, mocker, monkeypatch): + monkeypatch.delenv("COHERE_API_KEY", raising=False) + mocker.patch("cohere.Client") + with pytest.raises(ValueError): + Cohere() + + def test_call_method(self, cohere_llm, mocker): + mock_llm = mocker.MagicMock() + mock_llm.text = "test" + cohere_llm.client.chat.return_value = mock_llm + + llm_input = [Message(role="user", content="test")] + result = cohere_llm(llm_input) + assert isinstance(result, str), "Result should be a str" + cohere_llm.client.chat.assert_called_once() + + def test_raises_value_error_if_cohere_client_fails_to_initialize(self, mocker): + mocker.patch( + "cohere.Client", side_effect=Exception("Failed to initialize client") + ) + with pytest.raises(ValueError): + Cohere(cohere_api_key="test_api_key") + + def test_raises_value_error_if_cohere_client_is_not_initialized(self, mocker): + mocker.patch("cohere.Client", return_value=None) + llm = Cohere(cohere_api_key="test_api_key") + with pytest.raises(ValueError): + llm("test") + + def test_call_method_raises_error_on_api_failure(self, cohere_llm, mocker): + mocker.patch.object( + cohere_llm.client, "__call__", side_effect=Exception("API call failed") + ) + with pytest.raises(ValueError): + cohere_llm("test") diff --git a/tests/unit/llms/test_llm_openai.py b/tests/unit/llms/test_llm_openai.py new file mode 100644 index 00000000..4b2b2f54 --- /dev/null +++ b/tests/unit/llms/test_llm_openai.py @@ -0,0 +1,55 @@ +import pytest +from semantic_router.llms import OpenAI +from semantic_router.schema import Message + + +@pytest.fixture +def openai_llm(mocker): + mocker.patch("openai.Client") + return OpenAI(openai_api_key="test_api_key") + + +class TestOpenAILLM: + def test_openai_llm_init_with_api_key(self, openai_llm): + assert openai_llm.client is not None, "Client should be initialized" + assert openai_llm.name == "gpt-3.5-turbo", "Default name not set correctly" + + def test_openai_llm_init_success(self, mocker): + mocker.patch("os.getenv", return_value="fake-api-key") + llm = OpenAI() + assert llm.client is not None + + def test_openai_llm_init_without_api_key(self, mocker): + mocker.patch("os.getenv", return_value=None) + with pytest.raises(ValueError) as _: + OpenAI() + + def test_openai_llm_call_uninitialized_client(self, openai_llm): + # Set the client to None to simulate an uninitialized client + openai_llm.client = None + with pytest.raises(ValueError) as e: + llm_input = [Message(role="user", content="test")] + openai_llm(llm_input) + assert "OpenAI client is not initialized." in str(e.value) + + def test_openai_llm_init_exception(self, mocker): + mocker.patch("os.getenv", return_value="fake-api-key") + mocker.patch("openai.OpenAI", side_effect=Exception("Initialization error")) + with pytest.raises(ValueError) as e: + OpenAI() + assert ( + "OpenAI API client failed to initialize. Error: Initialization error" + in str(e.value) + ) + + def test_openai_llm_call_success(self, openai_llm, mocker): + mock_completion = mocker.MagicMock() + mock_completion.choices[0].message.content = "test" + + mocker.patch("os.getenv", return_value="fake-api-key") + mocker.patch.object( + openai_llm.client.chat.completions, "create", return_value=mock_completion + ) + llm_input = [Message(role="user", content="test")] + output = openai_llm(llm_input) + assert output == "test" diff --git a/tests/unit/llms/test_llm_openrouter.py b/tests/unit/llms/test_llm_openrouter.py new file mode 100644 index 00000000..3009e293 --- /dev/null +++ b/tests/unit/llms/test_llm_openrouter.py @@ -0,0 +1,59 @@ +import pytest +from semantic_router.llms import OpenRouter +from semantic_router.schema import Message + + +@pytest.fixture +def openrouter_llm(mocker): + mocker.patch("openai.Client") + return OpenRouter(openrouter_api_key="test_api_key") + + +class TestOpenRouterLLM: + def test_openrouter_llm_init_with_api_key(self, openrouter_llm): + assert openrouter_llm.client is not None, "Client should be initialized" + assert ( + openrouter_llm.name == "mistralai/mistral-7b-instruct" + ), "Default name not set correctly" + + def test_openrouter_llm_init_success(self, mocker): + mocker.patch("os.getenv", return_value="fake-api-key") + llm = OpenRouter() + assert llm.client is not None + + def test_openrouter_llm_init_without_api_key(self, mocker): + mocker.patch("os.getenv", return_value=None) + with pytest.raises(ValueError) as _: + OpenRouter() + + def test_openrouter_llm_call_uninitialized_client(self, openrouter_llm): + # Set the client to None to simulate an uninitialized client + openrouter_llm.client = None + with pytest.raises(ValueError) as e: + llm_input = [Message(role="user", content="test")] + openrouter_llm(llm_input) + assert "OpenRouter client is not initialized." in str(e.value) + + def test_openrouter_llm_init_exception(self, mocker): + mocker.patch("os.getenv", return_value="fake-api-key") + mocker.patch("openai.OpenAI", side_effect=Exception("Initialization error")) + with pytest.raises(ValueError) as e: + OpenRouter() + assert ( + "OpenRouter API client failed to initialize. Error: Initialization error" + in str(e.value) + ) + + def test_openrouter_llm_call_success(self, openrouter_llm, mocker): + mock_completion = mocker.MagicMock() + mock_completion.choices[0].message.content = "test" + + mocker.patch("os.getenv", return_value="fake-api-key") + mocker.patch.object( + openrouter_llm.client.chat.completions, + "create", + return_value=mock_completion, + ) + llm_input = [Message(role="user", content="test")] + output = openrouter_llm(llm_input) + assert output == "test" diff --git a/tests/unit/test_route.py b/tests/unit/test_route.py index 2eb784d4..e7842d39 100644 --- a/tests/unit/test_route.py +++ b/tests/unit/test_route.py @@ -1,6 +1,6 @@ from unittest.mock import patch # , AsyncMock -# import pytest +import pytest from semantic_router.llms import BaseLLM from semantic_router.route import Route, is_valid @@ -61,6 +61,21 @@ class MockLLM(BaseLLM): class TestRoute: + def test_value_error_in_route_call(self): + function_schema = {"name": "test_function", "type": "function"} + + route = Route( + name="test_function", + utterances=["utterance1", "utterance2"], + function_schema=function_schema, + ) + + with pytest.raises( + ValueError, + match="LLM is required for dynamic routes. Please ensure the 'llm' is set.", + ): + route("test_query") + def test_generate_dynamic_route(self): mock_llm = MockLLM(name="test") function_schema = {"name": "test_function", "type": "function"} diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py index 97b5028e..a9e794cb 100644 --- a/tests/unit/test_schema.py +++ b/tests/unit/test_schema.py @@ -1,9 +1,10 @@ import pytest - +from pydantic import ValidationError from semantic_router.schema import ( CohereEncoder, Encoder, EncoderType, + Message, OpenAIEncoder, ) @@ -38,3 +39,27 @@ class TestEncoderDataclass: encoder = Encoder(type="openai", name="test-engine") result = encoder(["test"]) assert result == [0.1, 0.2, 0.3] + + +class TestMessageDataclass: + def test_message_creation(self): + message = Message(role="user", content="Hello!") + assert message.role == "user" + assert message.content == "Hello!" + + with pytest.raises(ValidationError): + Message(user_role="invalid_role", message="Hello!") + + def test_message_to_openai(self): + message = Message(role="user", content="Hello!") + openai_format = message.to_openai() + assert openai_format == {"role": "user", "content": "Hello!"} + + message = Message(role="invalid_role", content="Hello!") + with pytest.raises(ValueError): + message.to_openai() + + def test_message_to_cohere(self): + message = Message(role="user", content="Hello!") + cohere_format = message.to_cohere() + assert cohere_format == {"role": "user", "message": "Hello!"} -- GitLab