From c12d80e2c14f640927c64d14dd1179b786d395ec Mon Sep 17 00:00:00 2001
From: Rafal-Chrzanowski-IBM <Rafal.Chrzanowski@ibm.com>
Date: Thu, 27 Feb 2025 23:15:25 +0100
Subject: [PATCH] Update `WatsonxLLM.metadata` property to avoid validation
 error when `model_limits` field isn't present (#17839)

---
 .../llama_index/llms/ibm/base.py              |  24 ++-
 .../llms/llama-index-llms-ibm/pyproject.toml  |   2 +-
 .../llama-index-llms-ibm/tests/test_ibm.py    | 145 +++++++++++++++++-
 3 files changed, 156 insertions(+), 15 deletions(-)

diff --git a/llama-index-integrations/llms/llama-index-llms-ibm/llama_index/llms/ibm/base.py b/llama-index-integrations/llms/llama-index-llms-ibm/llama_index/llms/ibm/base.py
index 91f789688a..21b515bfef 100644
--- a/llama-index-integrations/llms/llama-index-llms-ibm/llama_index/llms/ibm/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-ibm/llama_index/llms/ibm/base.py
@@ -291,12 +291,9 @@ class WatsonxLLM(FunctionCallingLLM):
     @property
     def metadata(self) -> LLMMetadata:
         if self.model_id:
-            return LLMMetadata(
-                context_window=(
-                    self.model_info.get("model_limits", {}).get("max_sequence_length")
-                ),
-                num_output=(self.max_new_tokens or DEFAULT_MAX_TOKENS),
-                model_name=self.model_id,
+            model_id = self.model_id
+            context_window = self.model_info.get("model_limits", {}).get(
+                "max_sequence_length"
             )
         else:
             model_id = self.deployment_info.get("entity", {}).get("base_model_id")
@@ -305,13 +302,14 @@ class WatsonxLLM(FunctionCallingLLM):
                 .get("model_limits", {})
                 .get("max_sequence_length")
             )
-            return LLMMetadata(
-                context_window=context_window
-                or self._context_window
-                or DEFAULT_CONTEXT_WINDOW,
-                num_output=(self.max_new_tokens or DEFAULT_MAX_TOKENS),
-                model_name=model_id or self._model.deployment_id,
-            )
+
+        return LLMMetadata(
+            context_window=context_window
+            or self._context_window
+            or DEFAULT_CONTEXT_WINDOW,
+            num_output=self.max_new_tokens or DEFAULT_MAX_TOKENS,
+            model_name=model_id or self._model.deployment_id,
+        )
 
     @property
     def sample_generation_text_params(self) -> Dict[str, Any]:
diff --git a/llama-index-integrations/llms/llama-index-llms-ibm/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-ibm/pyproject.toml
index 5730f562b4..bd264bd238 100644
--- a/llama-index-integrations/llms/llama-index-llms-ibm/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-ibm/pyproject.toml
@@ -31,7 +31,7 @@ license = "MIT"
 name = "llama-index-llms-ibm"
 packages = [{include = "llama_index/"}]
 readme = "README.md"
-version = "0.3.2"
+version = "0.3.3"
 
 [tool.poetry.dependencies]
 python = ">=3.10,<3.13"
diff --git a/llama-index-integrations/llms/llama-index-llms-ibm/tests/test_ibm.py b/llama-index-integrations/llms/llama-index-llms-ibm/tests/test_ibm.py
index 892c138040..f300b5de5f 100644
--- a/llama-index-integrations/llms/llama-index-llms-ibm/tests/test_ibm.py
+++ b/llama-index-integrations/llms/llama-index-llms-ibm/tests/test_ibm.py
@@ -3,8 +3,9 @@ from unittest.mock import MagicMock, patch
 import warnings
 import pytest
 
-from llama_index.core.base.llms.types import ChatMessage
+from llama_index.core.base.llms.types import ChatMessage, LLMMetadata
 from llama_index.llms.ibm import WatsonxLLM
+from llama_index.llms.ibm.base import DEFAULT_MAX_TOKENS, DEFAULT_CONTEXT_WINDOW
 
 
 def mock_return_guardrails_stats(*args) -> Dict:
@@ -163,8 +164,60 @@ class TestWasonxLLMInference:
     TEST_URL = "https://us-south.ml.cloud.ibm.com"
     TEST_APIKEY = "12345"
     TEST_PROJECT_ID = "1234"
+    TEST_DEPLOYMENT_ID = "4321"
 
     TEST_MODEL = "google/flan-ul2"
+    TEST_CONTEXT_WINDOW = 1111
+    TEST_MAX_SEQUENCE_LENGTH = 2222
+    TEST_MAX_NEW_TOKENS = 3333
+
+    CONTEXT_WINDOW_PARAMETRIZATION = [
+        pytest.param(
+            {
+                "model_limits": {
+                    "max_sequence_length": TEST_MAX_SEQUENCE_LENGTH,
+                }
+            },
+            TEST_CONTEXT_WINDOW,
+            TEST_MAX_SEQUENCE_LENGTH,
+            id="max_sequence_length_with_context_window",
+        ),
+        pytest.param(
+            {
+                "model_limits": {
+                    "max_sequence_length": TEST_MAX_SEQUENCE_LENGTH,
+                }
+            },
+            None,
+            TEST_MAX_SEQUENCE_LENGTH,
+            id="max_sequence_length_only",
+        ),
+        pytest.param(
+            {},
+            TEST_CONTEXT_WINDOW,
+            TEST_CONTEXT_WINDOW,
+            id="context_window_only",
+        ),
+        pytest.param({}, None, DEFAULT_CONTEXT_WINDOW, id="default_context_window"),
+    ]
+
+    MAX_TOKENS_PARAMETRIZATION = [
+        pytest.param(TEST_MAX_NEW_TOKENS, TEST_MAX_NEW_TOKENS, id="max_new_tokens"),
+        pytest.param(None, DEFAULT_MAX_TOKENS, id="default_max_tokens"),
+    ]
+
+    MODEL_ID_PARAMETRIZATION = [
+        pytest.param(
+            {
+                "entity": {
+                    "base_model_id": TEST_MODEL,
+                }
+            },
+            TEST_MODEL,
+            id="base_model_id",
+        ),
+        pytest.param({}, TEST_DEPLOYMENT_ID, id="deployment_id"),
+    ]
 
     def test_initialization(self) -> None:
         with pytest.raises(ValueError, match=r"^Did not find") as e_info:
@@ -332,3 +385,93 @@ class TestWasonxLLMInference:
         assert chat_responses[-1].additional_kwargs["prompt_tokens"] == 10
         assert chat_responses[-1].additional_kwargs["completion_tokens"] == 6
         assert chat_responses[-1].additional_kwargs["total_tokens"] == 16
+
+    @pytest.mark.parametrize(
+        ("get_details_result", "instance_context_window", "expected_context_window"),
+        CONTEXT_WINDOW_PARAMETRIZATION,
+    )
+    @pytest.mark.parametrize(
+        ("instance_max_new_tokens", "expected_num_output"),
+        MAX_TOKENS_PARAMETRIZATION,
+    )
+    @patch("llama_index.llms.ibm.base.ModelInference")
+    def test_model_metadata_with_provided_model_id(
+        self,
+        MockModelInference: MagicMock,
+        get_details_result,
+        instance_context_window,
+        instance_max_new_tokens,
+        expected_context_window,
+        expected_num_output,
+    ) -> None:
+        mock_instance = MockModelInference.return_value
+        mock_instance.get_details.return_value = get_details_result
+
+        watson_llm = WatsonxLLM(
+            model_id=self.TEST_MODEL,
+            project_id=self.TEST_PROJECT_ID,
+            url=self.TEST_URL,
+            apikey=self.TEST_APIKEY,
+            context_window=instance_context_window,
+            max_new_tokens=instance_max_new_tokens,
+        )
+
+        metadata = watson_llm.metadata
+
+        assert metadata == LLMMetadata(
+            context_window=expected_context_window,
+            num_output=expected_num_output,
+            model_name=self.TEST_MODEL,
+        )
+
+    @pytest.mark.parametrize(
+        (
+            "get_model_specs_result",
+            "instance_context_window",
+            "expected_context_window",
+        ),
+        CONTEXT_WINDOW_PARAMETRIZATION,
+    )
+    @pytest.mark.parametrize(
+        ("instance_max_new_tokens", "expected_num_output"),
+        MAX_TOKENS_PARAMETRIZATION,
+    )
+    @pytest.mark.parametrize(
+        ("get_details_result", "expected_model_name"),
+        MODEL_ID_PARAMETRIZATION,
+    )
+    @patch("llama_index.llms.ibm.base.ModelInference")
+    def test_model_metadata_with_provided_deployment_id(
+        self,
+        MockModelInference: MagicMock,
+        get_details_result,
+        get_model_specs_result,
+        instance_context_window,
+        instance_max_new_tokens,
+        expected_context_window,
+        expected_num_output,
+        expected_model_name,
+    ):
+        mock_instance = MockModelInference.return_value
+        mock_instance.deployment_id = self.TEST_DEPLOYMENT_ID
+        mock_instance.get_details.return_value = get_details_result
+        mock_instance._client.foundation_models.get_model_specs.return_value = (
+            get_model_specs_result
+        )
+
+        watson_llm = WatsonxLLM(
+            deployment_id=self.TEST_DEPLOYMENT_ID,
+            project_id=self.TEST_PROJECT_ID,
+            url=self.TEST_URL,
+            apikey=self.TEST_APIKEY,
+            context_window=instance_context_window,
+            max_new_tokens=instance_max_new_tokens,
+        )
+
+        metadata = watson_llm.metadata
+
+        assert metadata == LLMMetadata(
+            context_window=expected_context_window,
+            num_output=expected_num_output,
+            model_name=expected_model_name,
+        )
-- 
GitLab