From ee057786a889332aa2147d4c903c2a41fb6bc59f Mon Sep 17 00:00:00 2001
From: David Wiszowaty <wiszowatyd@gmail.com>
Date: Mon, 24 Feb 2025 04:43:47 -0600
Subject: [PATCH] fix: Chat messages with tool calls incorrectly mapping to
 Vertex message (#17893)

---
 .../llama_index/llms/vertex/gemini_utils.py   |  4 +-
 .../llama-index-llms-vertex/pyproject.toml    |  2 +-
 .../tests/test_gemini_utils.py                | 63 +++++++++++++++++++
 3 files changed, 67 insertions(+), 2 deletions(-)
 create mode 100644 llama-index-integrations/llms/llama-index-llms-vertex/tests/test_gemini_utils.py

diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_utils.py b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_utils.py
index dbb0a9e962..33d87eacd1 100644
--- a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_utils.py
@@ -46,7 +46,9 @@ def convert_chat_message_to_gemini_content(
             raise ValueError("Only text and image_url types are supported!")
         return Part.from_image(image)
 
-    if message.content == "" and "tool_calls" in message.additional_kwargs:
+    if (
+        message.content == "" or message.content is None
+    ) and "tool_calls" in message.additional_kwargs:
         tool_calls = message.additional_kwargs["tool_calls"]
         parts = [
             Part._from_gapic(raw_part=gapic_content_types.Part(function_call=tool_call))
diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml
index 9c2893a409..5447293c3b 100644
--- a/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-vertex"
 readme = "README.md"
-version = "0.4.2"
+version = "0.4.3"
 
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/tests/test_gemini_utils.py b/llama-index-integrations/llms/llama-index-llms-vertex/tests/test_gemini_utils.py
new file mode 100644
index 0000000000..52f627e465
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-vertex/tests/test_gemini_utils.py
@@ -0,0 +1,63 @@
+from google.cloud.aiplatform_v1beta1 import FunctionCall
+from llama_index.core.base.llms.types import ChatMessage, MessageRole
+
+from llama_index.llms.vertex.gemini_utils import (
+    convert_chat_message_to_gemini_content,
+    is_gemini_model,
+)
+
+
+def test_is_gemini_model():
+    assert is_gemini_model("gemini-2.0-flash") is True
+    assert is_gemini_model("chat-bison") is False
+
+
+def test_convert_chat_message_to_gemini_content_with_function_call():
+    message = ChatMessage(
+        role=MessageRole.ASSISTANT,
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                FunctionCall(
+                    name="test_fn",
+                    args={"arg1": "val1"},
+                )
+            ]
+        },
+    )
+
+    result = convert_chat_message_to_gemini_content(message=message, is_history=True)
+
+    assert result.role == "model"
+    assert len(result.parts) == 1
+    assert result.parts[0].function_call is not None
+    assert result.parts[0].function_call.name == "test_fn"
+    assert result.parts[0].function_call.args == {"arg1": "val1"}
+
+
+def test_convert_chat_message_to_gemini_content_with_content():
+    message = ChatMessage(
+        role=MessageRole.USER,
+        content="test content",
+    )
+
+    result = convert_chat_message_to_gemini_content(message=message, is_history=True)
+
+    assert result.role == "user"
+    assert result.text == "test content"
+    assert len(result.parts) == 1
+    assert result.parts[0].text == "test content"
+    assert result.parts[0].function_call is None
+
+
+def test_convert_chat_message_to_gemini_content_no_history():
+    message = ChatMessage(
+        role=MessageRole.USER,
+        content="test content",
+    )
+
+    result = convert_chat_message_to_gemini_content(message=message, is_history=False)
+
+    assert len(result) == 1
+    assert result[0].text == "test content"
+    assert result[0].function_call is None
-- 
GitLab