From e542d7ae61302c26fd64006fd47b5b25c8fb7a7d Mon Sep 17 00:00:00 2001
From: ex0ns <ex0ns@users.noreply.github.com>
Date: Thu, 27 Feb 2025 04:35:11 +0100
Subject: [PATCH] feat: add tests for gemini, fix function calling parameters
 (#17889)

---
 .../llama_index/llms/gemini/base.py           | 44 +++++----
 .../llama-index-llms-gemini/pyproject.toml    |  2 +-
 .../tests/test_llms_gemini.py                 | 90 ++++++++++++++++++-
 3 files changed, 119 insertions(+), 17 deletions(-)

diff --git a/llama-index-integrations/llms/llama-index-llms-gemini/llama_index/llms/gemini/base.py b/llama-index-integrations/llms/llama-index-llms-gemini/llama_index/llms/gemini/base.py
index 83899153ca..5318fb2e9a 100644
--- a/llama-index-integrations/llms/llama-index-llms-gemini/llama_index/llms/gemini/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-gemini/llama_index/llms/gemini/base.py
@@ -404,9 +404,11 @@ class Gemini(FunctionCallingLLM):
 
         return {
             "messages": messages,
-            "tools": ToolDict(function_declarations=tool_declarations)
-            if tool_declarations
-            else None,
+            "tools": (
+                ToolDict(function_declarations=tool_declarations)
+                if tool_declarations
+                else None
+            ),
             "tool_config": tool_config,
             **kwargs,
         }
@@ -451,9 +453,12 @@ class Gemini(FunctionCallingLLM):
         llm_kwargs = llm_kwargs or {}
         all_kwargs = {**llm_kwargs, **kwargs}
 
-        llm_kwargs["tool_choice"] = (
-            "required" if "tool_choice" not in all_kwargs else all_kwargs["tool_choice"]
-        )
+        if self._is_function_call_model:
+            llm_kwargs["tool_choice"] = (
+                "required"
+                if "tool_choice" not in all_kwargs
+                else all_kwargs["tool_choice"]
+            )
         # by default structured prediction uses function calling to extract structured outputs
         # here we force tool_choice to be required
         return super().structured_predict(*args, llm_kwargs=llm_kwargs, **kwargs)
@@ -466,9 +471,12 @@ class Gemini(FunctionCallingLLM):
         llm_kwargs = llm_kwargs or {}
         all_kwargs = {**llm_kwargs, **kwargs}
 
-        llm_kwargs["tool_choice"] = (
-            "required" if "tool_choice" not in all_kwargs else all_kwargs["tool_choice"]
-        )
+        if self._is_function_call_model:
+            llm_kwargs["tool_choice"] = (
+                "required"
+                if "tool_choice" not in all_kwargs
+                else all_kwargs["tool_choice"]
+            )
         # by default structured prediction uses function calling to extract structured outputs
         # here we force tool_choice to be required
         return await super().astructured_predict(*args, llm_kwargs=llm_kwargs, **kwargs)
@@ -481,9 +489,12 @@ class Gemini(FunctionCallingLLM):
         llm_kwargs = llm_kwargs or {}
         all_kwargs = {**llm_kwargs, **kwargs}
 
-        llm_kwargs["tool_choice"] = (
-            "required" if "tool_choice" not in all_kwargs else all_kwargs["tool_choice"]
-        )
+        if self._is_function_call_model:
+            llm_kwargs["tool_choice"] = (
+                "required"
+                if "tool_choice" not in all_kwargs
+                else all_kwargs["tool_choice"]
+            )
         # by default structured prediction uses function calling to extract structured outputs
         # here we force tool_choice to be required
         return super().stream_structured_predict(*args, llm_kwargs=llm_kwargs, **kwargs)
@@ -496,9 +507,12 @@ class Gemini(FunctionCallingLLM):
         llm_kwargs = llm_kwargs or {}
         all_kwargs = {**llm_kwargs, **kwargs}
 
-        llm_kwargs["tool_choice"] = (
-            "required" if "tool_choice" not in all_kwargs else all_kwargs["tool_choice"]
-        )
+        if self._is_function_call_model:
+            llm_kwargs["tool_choice"] = (
+                "required"
+                if "tool_choice" not in all_kwargs
+                else all_kwargs["tool_choice"]
+            )
         # by default structured prediction uses function calling to extract structured outputs
         # here we force tool_choice to be required
         return await super().astream_structured_predict(
diff --git a/llama-index-integrations/llms/llama-index-llms-gemini/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-gemini/pyproject.toml
index 15488244ef..873abca591 100644
--- a/llama-index-integrations/llms/llama-index-llms-gemini/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-gemini/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-gemini"
 readme = "README.md"
-version = "0.4.10"
+version = "0.4.11"
 
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
diff --git a/llama-index-integrations/llms/llama-index-llms-gemini/tests/test_llms_gemini.py b/llama-index-integrations/llms/llama-index-llms-gemini/tests/test_llms_gemini.py
index 5ea3aac33d..f9f1f7ca0d 100644
--- a/llama-index-integrations/llms/llama-index-llms-gemini/tests/test_llms_gemini.py
+++ b/llama-index-integrations/llms/llama-index-llms-gemini/tests/test_llms_gemini.py
@@ -11,7 +11,7 @@ from llama_index.core.prompts.base import ChatPromptTemplate
 from llama_index.core.tools.function_tool import FunctionTool
 from llama_index.llms.gemini import Gemini
 from llama_index.llms.gemini.utils import chat_message_to_gemini
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 
 def test_embedding_class() -> None:
@@ -128,3 +128,91 @@ def test_is_function_calling_model() -> None:
     manual_override._is_function_call_model = False
     assert not manual_override._is_function_call_model
     assert not manual_override.metadata.is_function_calling_model
+
+
+@pytest.mark.skipif(
+    os.environ.get("GOOGLE_API_KEY") is None, reason="GOOGLE_API_KEY not set"
+)
+def test_structure_gen_without_function_call() -> None:
+    """Structured completion should still work with function calling disabled."""
+
+    class Test(BaseModel):
+        test: str
+
+    api_key = os.environ["GOOGLE_API_KEY"]
+    llm = Gemini(model="models/gemini-2.0-flash-001", api_key=api_key)
+    llm._is_function_call_model = False
+    result = llm.as_structured_llm(Test).complete("test")
+    assert result.raw.test
+
+
+@pytest.mark.skipif(
+    os.environ.get("GOOGLE_API_KEY") is None, reason="GOOGLE_API_KEY not set"
+)
+def test_function_call_deeply_nested_structured_generation() -> None:
+    """Check nested structured output when function calling is enabled.
+
+    Builds a three-level pydantic model (Schema -> Table -> Column), calls
+    ``structured_predict`` on it and asserts every nesting level is populated.
+    """
+
+    class Column(BaseModel):
+        name: str = Field(description="Column field")
+        data_type: str = Field(description="Data type field")
+
+    class Table(BaseModel):
+        name: str = Field(description="Table name field")
+        columns: list[Column] = Field(description="List of random Column objects")
+
+    class Schema(BaseModel):
+        schema_name: str = Field(description="Schema name")
+        columns: list[Table] = Field(description="List of random Table objects")
+
+    prompt = ChatPromptTemplate.from_messages(
+        [ChatMessage(role="user", content="Generate a simple database structure")]
+    )
+    gemini_flash = Gemini(
+        model="models/gemini-2.0-flash-001",
+        api_key=os.environ["GOOGLE_API_KEY"],
+    )
+    # this is the default, but let's be explicit
+    gemini_flash._is_function_call_model = True
+    schema = gemini_flash.structured_predict(output_cls=Schema, prompt=prompt)
+    assert schema.columns
+    assert schema.columns[0].columns
+    assert schema.columns[0].columns[0].name
+
+
+# same scenario as test_function_call_deeply_nested_structured_generation, but with function calling disabled
+@pytest.mark.skipif(
+    os.environ.get("GOOGLE_API_KEY") is None, reason="GOOGLE_API_KEY not set"
+)
+def test_deeply_nested_structured_generation() -> None:
+    """Check nested structured output when function calling is disabled."""
+
+    class Column(BaseModel):
+        name: str = Field(description="Column field")
+        data_type: str = Field(description="Data type field")
+
+    class Table(BaseModel):
+        name: str = Field(description="Table name field")
+        columns: list[Column] = Field(description="List of random Column objects")
+
+    class Schema(BaseModel):
+        schema_name: str = Field(description="Schema name")
+        columns: list[Table] = Field(description="List of random Table objects")
+
+    prompt = ChatPromptTemplate.from_messages(
+        [ChatMessage(role="user", content="Generate a simple database structure")]
+    )
+
+    gemini_flash = Gemini(
+        model="models/gemini-2.0-flash-001",
+        api_key=os.environ["GOOGLE_API_KEY"],
+    )
+    # with the flag off, structured_predict no longer injects tool_choice
+    gemini_flash._is_function_call_model = False
+    schema = gemini_flash.structured_predict(output_cls=Schema, prompt=prompt)
+    assert schema.columns
+    assert schema.columns[0].columns
+    assert schema.columns[0].columns[0].name
-- 
GitLab