From 548bd40393b8e4483d9092359051c699558c709f Mon Sep 17 00:00:00 2001
From: Ismail Ashraq <issey1455@gmail.com>
Date: Sun, 7 Jan 2024 03:51:54 +0500
Subject: [PATCH] Add tests for LLMs

---
 semantic_router/llms/base.py           |  3 +-
 tests/unit/llms/test_llm_base.py       | 16 +++++++
 tests/unit/llms/test_llm_cohere.py     | 52 +++++++++++++++++++++++
 tests/unit/llms/test_llm_openai.py     | 55 ++++++++++++++++++++++++
 tests/unit/llms/test_llm_openrouter.py | 59 ++++++++++++++++++++++++++
 tests/unit/test_route.py               | 17 +++++++-
 tests/unit/test_schema.py              | 27 +++++++++++-
 7 files changed, 226 insertions(+), 3 deletions(-)
 create mode 100644 tests/unit/llms/test_llm_base.py
 create mode 100644 tests/unit/llms/test_llm_cohere.py
 create mode 100644 tests/unit/llms/test_llm_openai.py
 create mode 100644 tests/unit/llms/test_llm_openrouter.py

diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py
index 2a1a038e..dd8a0afa 100644
--- a/semantic_router/llms/base.py
+++ b/semantic_router/llms/base.py
@@ -1,4 +1,5 @@
 from pydantic import BaseModel
+from semantic_router.schema import Message
 
 
 class BaseLLM(BaseModel):
@@ -7,5 +8,5 @@ class BaseLLM(BaseModel):
     class Config:
         arbitrary_types_allowed = True
 
-    def __call__(self, prompt) -> str | None:
+    def __call__(self, messages: list[Message]) -> str | None:
         raise NotImplementedError("Subclasses must implement this method")
diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py
new file mode 100644
index 00000000..df78d8f5
--- /dev/null
+++ b/tests/unit/llms/test_llm_base.py
@@ -0,0 +1,16 @@
+import pytest
+
+from semantic_router.llms import BaseLLM
+
+
+class TestBaseLLM:
+    @pytest.fixture
+    def base_llm(self):
+        return BaseLLM(name="TestLLM")
+
+    def test_base_llm_initialization(self, base_llm):
+        assert base_llm.name == "TestLLM", "Initialization of name failed"
+
+    def test_base_llm_call_method_not_implemented(self, base_llm):
+        with pytest.raises(NotImplementedError):
+            base_llm("test")
diff --git a/tests/unit/llms/test_llm_cohere.py b/tests/unit/llms/test_llm_cohere.py
new file mode 100644
index 00000000..32443f04
--- /dev/null
+++ b/tests/unit/llms/test_llm_cohere.py
@@ -0,0 +1,52 @@
+import pytest
+
+from semantic_router.llms import Cohere
+from semantic_router.schema import Message
+
+
+@pytest.fixture
+def cohere_llm(mocker):
+    mocker.patch("cohere.Client")
+    return Cohere(cohere_api_key="test_api_key")
+
+
+class TestCohereLLM:
+    def test_initialization_with_api_key(self, cohere_llm):
+        assert cohere_llm.client is not None, "Client should be initialized"
+        assert cohere_llm.name == "command", "Default name not set correctly"
+
+    def test_initialization_without_api_key(self, mocker, monkeypatch):
+        monkeypatch.delenv("COHERE_API_KEY", raising=False)
+        mocker.patch("cohere.Client")
+        with pytest.raises(ValueError):
+            Cohere()
+
+    def test_call_method(self, cohere_llm, mocker):
+        mock_llm = mocker.MagicMock()
+        mock_llm.text = "test"
+        cohere_llm.client.chat.return_value = mock_llm
+
+        llm_input = [Message(role="user", content="test")]
+        result = cohere_llm(llm_input)
+        assert isinstance(result, str), "Result should be a str"
+        cohere_llm.client.chat.assert_called_once()
+
+    def test_raises_value_error_if_cohere_client_fails_to_initialize(self, mocker):
+        mocker.patch(
+            "cohere.Client", side_effect=Exception("Failed to initialize client")
+        )
+        with pytest.raises(ValueError):
+            Cohere(cohere_api_key="test_api_key")
+
+    def test_raises_value_error_if_cohere_client_is_not_initialized(self, mocker):
+        mocker.patch("cohere.Client", return_value=None)
+        llm = Cohere(cohere_api_key="test_api_key")
+        with pytest.raises(ValueError):
+            llm("test")
+
+    def test_call_method_raises_error_on_api_failure(self, cohere_llm, mocker):
+        mocker.patch.object(
+            cohere_llm.client, "chat", side_effect=Exception("API call failed")
+        )
+        with pytest.raises(ValueError):
+            cohere_llm([Message(role="user", content="test")])
diff --git a/tests/unit/llms/test_llm_openai.py b/tests/unit/llms/test_llm_openai.py
new file mode 100644
index 00000000..4b2b2f54
--- /dev/null
+++ b/tests/unit/llms/test_llm_openai.py
@@ -0,0 +1,55 @@
+import pytest
+from semantic_router.llms import OpenAI
+from semantic_router.schema import Message
+
+
+@pytest.fixture
+def openai_llm(mocker):
+    mocker.patch("openai.Client")
+    return OpenAI(openai_api_key="test_api_key")
+
+
+class TestOpenAILLM:
+    def test_openai_llm_init_with_api_key(self, openai_llm):
+        assert openai_llm.client is not None, "Client should be initialized"
+        assert openai_llm.name == "gpt-3.5-turbo", "Default name not set correctly"
+
+    def test_openai_llm_init_success(self, mocker):
+        mocker.patch("os.getenv", return_value="fake-api-key")
+        llm = OpenAI()
+        assert llm.client is not None
+
+    def test_openai_llm_init_without_api_key(self, mocker):
+        mocker.patch("os.getenv", return_value=None)
+        with pytest.raises(ValueError):
+            OpenAI()
+
+    def test_openai_llm_call_uninitialized_client(self, openai_llm):
+        # Set the client to None to simulate an uninitialized client
+        openai_llm.client = None
+        with pytest.raises(ValueError) as e:
+            llm_input = [Message(role="user", content="test")]
+            openai_llm(llm_input)
+        assert "OpenAI client is not initialized." in str(e.value)
+
+    def test_openai_llm_init_exception(self, mocker):
+        mocker.patch("os.getenv", return_value="fake-api-key")
+        mocker.patch("openai.OpenAI", side_effect=Exception("Initialization error"))
+        with pytest.raises(ValueError) as e:
+            OpenAI()
+        assert (
+            "OpenAI API client failed to initialize. Error: Initialization error"
+            in str(e.value)
+        )
+
+    def test_openai_llm_call_success(self, openai_llm, mocker):
+        mock_completion = mocker.MagicMock()
+        mock_completion.choices[0].message.content = "test"
+
+        mocker.patch("os.getenv", return_value="fake-api-key")
+        mocker.patch.object(
+            openai_llm.client.chat.completions, "create", return_value=mock_completion
+        )
+        llm_input = [Message(role="user", content="test")]
+        output = openai_llm(llm_input)
+        assert output == "test"
diff --git a/tests/unit/llms/test_llm_openrouter.py b/tests/unit/llms/test_llm_openrouter.py
new file mode 100644
index 00000000..3009e293
--- /dev/null
+++ b/tests/unit/llms/test_llm_openrouter.py
@@ -0,0 +1,59 @@
+import pytest
+from semantic_router.llms import OpenRouter
+from semantic_router.schema import Message
+
+
+@pytest.fixture
+def openrouter_llm(mocker):
+    mocker.patch("openai.Client")
+    return OpenRouter(openrouter_api_key="test_api_key")
+
+
+class TestOpenRouterLLM:
+    def test_openrouter_llm_init_with_api_key(self, openrouter_llm):
+        assert openrouter_llm.client is not None, "Client should be initialized"
+        assert (
+            openrouter_llm.name == "mistralai/mistral-7b-instruct"
+        ), "Default name not set correctly"
+
+    def test_openrouter_llm_init_success(self, mocker):
+        mocker.patch("os.getenv", return_value="fake-api-key")
+        llm = OpenRouter()
+        assert llm.client is not None
+
+    def test_openrouter_llm_init_without_api_key(self, mocker):
+        mocker.patch("os.getenv", return_value=None)
+        with pytest.raises(ValueError):
+            OpenRouter()
+
+    def test_openrouter_llm_call_uninitialized_client(self, openrouter_llm):
+        # Set the client to None to simulate an uninitialized client
+        openrouter_llm.client = None
+        with pytest.raises(ValueError) as e:
+            llm_input = [Message(role="user", content="test")]
+            openrouter_llm(llm_input)
+        assert "OpenRouter client is not initialized." in str(e.value)
+
+    def test_openrouter_llm_init_exception(self, mocker):
+        mocker.patch("os.getenv", return_value="fake-api-key")
+        mocker.patch("openai.OpenAI", side_effect=Exception("Initialization error"))
+        with pytest.raises(ValueError) as e:
+            OpenRouter()
+        assert (
+            "OpenRouter API client failed to initialize. Error: Initialization error"
+            in str(e.value)
+        )
+
+    def test_openrouter_llm_call_success(self, openrouter_llm, mocker):
+        mock_completion = mocker.MagicMock()
+        mock_completion.choices[0].message.content = "test"
+
+        mocker.patch("os.getenv", return_value="fake-api-key")
+        mocker.patch.object(
+            openrouter_llm.client.chat.completions,
+            "create",
+            return_value=mock_completion,
+        )
+        llm_input = [Message(role="user", content="test")]
+        output = openrouter_llm(llm_input)
+        assert output == "test"
diff --git a/tests/unit/test_route.py b/tests/unit/test_route.py
index 2eb784d4..e7842d39 100644
--- a/tests/unit/test_route.py
+++ b/tests/unit/test_route.py
@@ -1,6 +1,6 @@
 from unittest.mock import patch  # , AsyncMock
 
-# import pytest
+import pytest
 from semantic_router.llms import BaseLLM
 from semantic_router.route import Route, is_valid
 
@@ -61,6 +61,21 @@ class MockLLM(BaseLLM):
 
 
 class TestRoute:
+    def test_value_error_in_route_call(self):
+        function_schema = {"name": "test_function", "type": "function"}
+
+        route = Route(
+            name="test_function",
+            utterances=["utterance1", "utterance2"],
+            function_schema=function_schema,
+        )
+
+        with pytest.raises(
+            ValueError,
+            match="LLM is required for dynamic routes. Please ensure the 'llm' is set.",
+        ):
+            route("test_query")
+
     def test_generate_dynamic_route(self):
         mock_llm = MockLLM(name="test")
         function_schema = {"name": "test_function", "type": "function"}
diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py
index 97b5028e..a9e794cb 100644
--- a/tests/unit/test_schema.py
+++ b/tests/unit/test_schema.py
@@ -1,9 +1,10 @@
 import pytest
-
+from pydantic import ValidationError
 from semantic_router.schema import (
     CohereEncoder,
     Encoder,
     EncoderType,
+    Message,
     OpenAIEncoder,
 )
 
@@ -38,3 +39,27 @@ class TestEncoderDataclass:
         encoder = Encoder(type="openai", name="test-engine")
         result = encoder(["test"])
         assert result == [0.1, 0.2, 0.3]
+
+
+class TestMessageDataclass:
+    def test_message_creation(self):
+        message = Message(role="user", content="Hello!")
+        assert message.role == "user"
+        assert message.content == "Hello!"
+
+        with pytest.raises(ValidationError):
+            Message(user_role="invalid_role", message="Hello!")
+
+    def test_message_to_openai(self):
+        message = Message(role="user", content="Hello!")
+        openai_format = message.to_openai()
+        assert openai_format == {"role": "user", "content": "Hello!"}
+
+        message = Message(role="invalid_role", content="Hello!")
+        with pytest.raises(ValueError):
+            message.to_openai()
+
+    def test_message_to_cohere(self):
+        message = Message(role="user", content="Hello!")
+        cohere_format = message.to_cohere()
+        assert cohere_format == {"role": "user", "message": "Hello!"}
-- 
GitLab