From c12d80e2c14f640927c64d14dd1179b786d395ec Mon Sep 17 00:00:00 2001 From: Rafal-Chrzanowski-IBM <Rafal.Chrzanowski@ibm.com> Date: Thu, 27 Feb 2025 23:15:25 +0100 Subject: [PATCH] Update `WatsonxLLM.metadata` property to avoid validation error when `model_limits` field isn't present (#17839) --- .../llama_index/llms/ibm/base.py | 24 ++- .../llms/llama-index-llms-ibm/pyproject.toml | 2 +- .../llama-index-llms-ibm/tests/test_ibm.py | 145 +++++++++++++++++- 3 files changed, 156 insertions(+), 15 deletions(-) diff --git a/llama-index-integrations/llms/llama-index-llms-ibm/llama_index/llms/ibm/base.py b/llama-index-integrations/llms/llama-index-llms-ibm/llama_index/llms/ibm/base.py index 91f789688a..21b515bfef 100644 --- a/llama-index-integrations/llms/llama-index-llms-ibm/llama_index/llms/ibm/base.py +++ b/llama-index-integrations/llms/llama-index-llms-ibm/llama_index/llms/ibm/base.py @@ -291,12 +291,9 @@ class WatsonxLLM(FunctionCallingLLM): @property def metadata(self) -> LLMMetadata: if self.model_id: - return LLMMetadata( - context_window=( - self.model_info.get("model_limits", {}).get("max_sequence_length") - ), - num_output=(self.max_new_tokens or DEFAULT_MAX_TOKENS), - model_name=self.model_id, + model_id = self.model_id + context_window = self.model_info.get("model_limits", {}).get( + "max_sequence_length" ) else: model_id = self.deployment_info.get("entity", {}).get("base_model_id") @@ -305,13 +302,14 @@ class WatsonxLLM(FunctionCallingLLM): .get("model_limits", {}) .get("max_sequence_length") ) - return LLMMetadata( - context_window=context_window - or self._context_window - or DEFAULT_CONTEXT_WINDOW, - num_output=(self.max_new_tokens or DEFAULT_MAX_TOKENS), - model_name=model_id or self._model.deployment_id, - ) + + return LLMMetadata( + context_window=context_window + or self._context_window + or DEFAULT_CONTEXT_WINDOW, + num_output=self.max_new_tokens or DEFAULT_MAX_TOKENS, + model_name=model_id or self._model.deployment_id, + ) @property def sample_generation_text_params(self) -> Dict[str, Any]: diff --git a/llama-index-integrations/llms/llama-index-llms-ibm/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-ibm/pyproject.toml index 5730f562b4..bd264bd238 100644 --- a/llama-index-integrations/llms/llama-index-llms-ibm/pyproject.toml +++ b/llama-index-integrations/llms/llama-index-llms-ibm/pyproject.toml @@ -31,7 +31,7 @@ license = "MIT" name = "llama-index-llms-ibm" packages = [{include = "llama_index/"}] readme = "README.md" -version = "0.3.2" +version = "0.3.3" [tool.poetry.dependencies] python = ">=3.10,<3.13" diff --git a/llama-index-integrations/llms/llama-index-llms-ibm/tests/test_ibm.py b/llama-index-integrations/llms/llama-index-llms-ibm/tests/test_ibm.py index 892c138040..f300b5de5f 100644 --- a/llama-index-integrations/llms/llama-index-llms-ibm/tests/test_ibm.py +++ b/llama-index-integrations/llms/llama-index-llms-ibm/tests/test_ibm.py @@ -3,8 +3,9 @@ from unittest.mock import MagicMock, patch import warnings import pytest -from llama_index.core.base.llms.types import ChatMessage +from llama_index.core.base.llms.types import ChatMessage, LLMMetadata from llama_index.llms.ibm import WatsonxLLM +from llama_index.llms.ibm.base import DEFAULT_MAX_TOKENS, DEFAULT_CONTEXT_WINDOW def mock_return_guardrails_stats(*args) -> Dict: @@ -163,8 +164,60 @@ class TestWasonxLLMInference: TEST_URL = "https://us-south.ml.cloud.ibm.com" TEST_APIKEY = "12345" TEST_PROJECT_ID = "1234" + TEST_DEPLOYMENT_ID = "4321" TEST_MODEL = "google/flan-ul2" + TEST_CONTEXT_WINDOW = 1111 + TEST_MAX_SEQUENCE_LENGTH = 2222 + TEST_MAX_NEW_TOKENS = 3333 + + CONTEXT_WINDOW_PARAMETRIZATION = [ + pytest.param( + { + "model_limits": { + "max_sequence_length": TEST_MAX_SEQUENCE_LENGTH, + } + }, + TEST_CONTEXT_WINDOW, + TEST_MAX_SEQUENCE_LENGTH, + id="max_sequence_length_with_context_window", + ), + pytest.param( + { + "model_limits": { + "max_sequence_length": TEST_MAX_SEQUENCE_LENGTH, + } + }, + None, + TEST_MAX_SEQUENCE_LENGTH, + id="max_sequence_length_only", + ), + pytest.param( + {}, + TEST_CONTEXT_WINDOW, + TEST_CONTEXT_WINDOW, + id="context_window_only", + ), + pytest.param({}, None, DEFAULT_CONTEXT_WINDOW, id="default_context_window"), + ] + + MAX_TOKENS_PARAMETRIZATION = [ + pytest.param(TEST_MAX_NEW_TOKENS, TEST_MAX_NEW_TOKENS, id="max_new_tokens"), + pytest.param(None, DEFAULT_MAX_TOKENS, id="default_max_tokens"), + ] + + MODEL_ID_PARAMETRIZATION = [ + pytest.param( + { + "entity": { + "base_model_id": TEST_MODEL, + } + }, + TEST_MODEL, + id="base_model_id", + ), + pytest.param({}, TEST_DEPLOYMENT_ID, id="deployment_id"), + ] def test_initialization(self) -> None: with pytest.raises(ValueError, match=r"^Did not find") as e_info: @@ -332,3 +385,93 @@ class TestWasonxLLMInference: assert chat_responses[-1].additional_kwargs["prompt_tokens"] == 10 assert chat_responses[-1].additional_kwargs["completion_tokens"] == 6 assert chat_responses[-1].additional_kwargs["total_tokens"] == 16 + + @pytest.mark.parametrize( + ("get_details_result", "instance_context_window", "expected_context_window"), + CONTEXT_WINDOW_PARAMETRIZATION, + ) + @pytest.mark.parametrize( + ("instance_max_new_tokens", "expected_num_output"), + MAX_TOKENS_PARAMETRIZATION, + ) + @patch("llama_index.llms.ibm.base.ModelInference") + def test_model_metadata_with_provided_model_id( + self, + MockModelInference: MagicMock, + get_details_result, + instance_context_window, + instance_max_new_tokens, + expected_context_window, + expected_num_output, + ) -> None: + mock_instance = MockModelInference.return_value + mock_instance.get_details.return_value = get_details_result + + watson_llm = WatsonxLLM( + model_id=self.TEST_MODEL, + project_id=self.TEST_PROJECT_ID, + url=self.TEST_URL, + apikey=self.TEST_APIKEY, + context_window=instance_context_window, + max_new_tokens=instance_max_new_tokens, + ) + + metadata = watson_llm.metadata + + assert metadata == LLMMetadata( + context_window=expected_context_window, + num_output=expected_num_output, + model_name=self.TEST_MODEL, + ) + + @pytest.mark.parametrize( + ( + "get_model_specs_result", + "instance_context_window", + "expected_context_window", + ), + CONTEXT_WINDOW_PARAMETRIZATION, + ) + @pytest.mark.parametrize( + ("instance_max_new_tokens", "expected_num_output"), + MAX_TOKENS_PARAMETRIZATION, + ) + @pytest.mark.parametrize( + ("get_details_result", "expected_model_name"), + MODEL_ID_PARAMETRIZATION, + ) + @patch("llama_index.llms.ibm.base.ModelInference") + def test_model_metadata_with_provided_deployment_id( + self, + MockModelInference: MagicMock, + get_details_result, + get_model_specs_result, + instance_context_window, + instance_max_new_tokens, + expected_context_window, + expected_num_output, + expected_model_name, + ): + mock_instance = MockModelInference.return_value + mock_instance.deployment_id = self.TEST_DEPLOYMENT_ID + mock_instance.get_details.return_value = get_details_result + mock_instance._client.foundation_models.get_model_specs.return_value = ( + get_model_specs_result + ) + + watson_llm = WatsonxLLM( + deployment_id=self.TEST_DEPLOYMENT_ID, + project_id=self.TEST_PROJECT_ID, + url=self.TEST_URL, + apikey=self.TEST_APIKEY, + context_window=instance_context_window, + max_new_tokens=instance_max_new_tokens, + ) + + metadata = watson_llm.metadata + + assert metadata == LLMMetadata( + context_window=expected_context_window, + num_output=expected_num_output, + model_name=expected_model_name, + ) -- GitLab