From f04846d9a57384af540d3d0d59e5721f85f0781b Mon Sep 17 00:00:00 2001
From: Kurtis Massey <55586356+kurtismassey@users.noreply.github.com>
Date: Sat, 4 May 2024 19:16:47 +0100
Subject: [PATCH] Increase test coverage for bedrock cohere

---
 tests/unit/encoders/test_bedrock.py | 76 +++++++++++++++++++++++++++++
 1 file changed, 76 insertions(+)

diff --git a/tests/unit/encoders/test_bedrock.py b/tests/unit/encoders/test_bedrock.py
index 2076d36e..6d438824 100644
--- a/tests/unit/encoders/test_bedrock.py
+++ b/tests/unit/encoders/test_bedrock.py
@@ -88,3 +88,79 @@ class TestBedrockEncoder:
         )
         with pytest.raises(ValueError):
             bedrock_encoder(["test"])
+
+    def test_raises_value_error_if_no_aws_session_credentials(self, mocker):
+        mocker.patch("boto3.Session")
+        mock_session = mocker.Mock()
+        mock_session.get_credentials.return_value = None
+        with pytest.raises(ValueError, match="Could not get AWS session"):
+            BedrockEncoder(session=mock_session)
+
+    def test_raises_value_error_if_no_aws_region(self, mocker):
+        mocker.patch("boto3.Session")
+        mock_session = mocker.Mock()
+        mock_session.region_name = None
+        with pytest.raises(ValueError, match="No AWS region provided"):
+            BedrockEncoder(session=mock_session)
+
+    def test_raises_value_error_if_client_initialisation_fails(self, mocker):
+        mocker.patch("boto3.Session")
+        mock_session = mocker.Mock()
+        mock_session.client.side_effect = Exception("Client initialisation failed")
+        with pytest.raises(ValueError, match="Bedrock client failed to initialise"):
+            BedrockEncoder(session=mock_session)
+
+    def test_raises_value_error_for_unknown_model_name(self, mocker):
+        mocker.patch("boto3.Session")
+        mock_session = mocker.Mock()
+        mock_session.get_credentials.return_value = True
+        mocker.patch("boto3.Session.client")
+
+        unknown_model_name = "unknown_model"
+        bedrock_encoder = BedrockEncoder(
+            name=unknown_model_name,
+            session=mock_session,
+            region="us-west-2",
+        )
+
+        with pytest.raises(ValueError, match="Unknown model name"):
+            bedrock_encoder(["test"])
+
+
+@pytest.fixture
+def bedrock_encoder_with_cohere(mocker):
+    mocker.patch("boto3.Session")
+    mock_session = mocker.Mock()
+    mock_session.get_credentials.return_value = True
+    mocker.patch("boto3.Session.client")
+    return BedrockEncoder(name="cohere_model", session=mock_session, region="us-west-2")
+
+
+class TestBedrockEncoderWithCohere:
+    def test_cohere_embedding_single_chunk(self, bedrock_encoder_with_cohere):
+        response_content = json.dumps({"embeddings": [[0.1, 0.2, 0.3]]})
+        response_body = BytesIO(response_content.encode("utf-8"))
+        mock_response = {"body": response_body}
+
+        bedrock_encoder_with_cohere.client.invoke_model.return_value = mock_response
+
+        result = bedrock_encoder_with_cohere(["short test"])
+
+        assert isinstance(result, list), "Result should be a list"
+        assert all(
+            isinstance(item, list) for item in result
+        ), "Each item should be a list"
+        assert result == [[0.1, 0.2, 0.3]], "Expected embedding [0.1, 0.2, 0.3]"
+
+    def test_cohere_input_type(self, bedrock_encoder_with_cohere):
+        bedrock_encoder_with_cohere.input_type = "different_type"
+        response_content = json.dumps({"embeddings": [[0.1, 0.2, 0.3]]})
+        response_body = BytesIO(response_content.encode("utf-8"))
+        mock_response = {"body": response_body}
+
+        bedrock_encoder_with_cohere.client.invoke_model.return_value = mock_response
+
+        result = bedrock_encoder_with_cohere(["test with different input type"])
+
+        assert isinstance(result, list), "Result should be a list"
+        assert result == [[0.1, 0.2, 0.3]], "Expected specific embeddings"
-- 
GitLab