From e542d7ae61302c26fd64006fd47b5b25c8fb7a7d Mon Sep 17 00:00:00 2001
From: ex0ns <ex0ns@users.noreply.github.com>
Date: Thu, 27 Feb 2025 04:35:11 +0100
Subject: [PATCH] feat: add tests for gemini, fix function calling parameters
 (#17889)

---
 .../llama_index/llms/gemini/base.py           | 44 +++++----
 .../llama-index-llms-gemini/pyproject.toml    |  2 +-
 .../tests/test_llms_gemini.py                 | 90 ++++++++++++++++++-
 3 files changed, 119 insertions(+), 17 deletions(-)

diff --git a/llama-index-integrations/llms/llama-index-llms-gemini/llama_index/llms/gemini/base.py b/llama-index-integrations/llms/llama-index-llms-gemini/llama_index/llms/gemini/base.py
index 83899153ca..5318fb2e9a 100644
--- a/llama-index-integrations/llms/llama-index-llms-gemini/llama_index/llms/gemini/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-gemini/llama_index/llms/gemini/base.py
@@ -404,9 +404,11 @@ class Gemini(FunctionCallingLLM):
 
         return {
             "messages": messages,
-            "tools": ToolDict(function_declarations=tool_declarations)
-            if tool_declarations
-            else None,
+            "tools": (
+                ToolDict(function_declarations=tool_declarations)
+                if tool_declarations
+                else None
+            ),
             "tool_config": tool_config,
             **kwargs,
         }
@@ -451,9 +453,12 @@ class Gemini(FunctionCallingLLM):
         llm_kwargs = llm_kwargs or {}
         all_kwargs = {**llm_kwargs, **kwargs}
 
-        llm_kwargs["tool_choice"] = (
-            "required" if "tool_choice" not in all_kwargs else all_kwargs["tool_choice"]
-        )
+        if self._is_function_call_model:
+            llm_kwargs["tool_choice"] = (
+                "required"
+                if "tool_choice" not in all_kwargs
+                else all_kwargs["tool_choice"]
+            )
         # by default structured prediction uses function calling to extract structured outputs
         # here we force tool_choice to be required
         return super().structured_predict(*args, llm_kwargs=llm_kwargs, **kwargs)
@@ -466,9 +471,12 @@ class Gemini(FunctionCallingLLM):
         llm_kwargs = llm_kwargs or {}
         all_kwargs = {**llm_kwargs, **kwargs}
 
-        llm_kwargs["tool_choice"] = (
-            "required" if "tool_choice" not in all_kwargs else all_kwargs["tool_choice"]
-        )
+        if self._is_function_call_model:
+            llm_kwargs["tool_choice"] = (
+                "required"
+                if "tool_choice" not in all_kwargs
+                else all_kwargs["tool_choice"]
+            )
         # by default structured prediction uses function calling to extract structured outputs
         # here we force tool_choice to be required
         return await super().astructured_predict(*args, llm_kwargs=llm_kwargs, **kwargs)
@@ -481,9 +489,12 @@ class Gemini(FunctionCallingLLM):
         llm_kwargs = llm_kwargs or {}
         all_kwargs = {**llm_kwargs, **kwargs}
 
-        llm_kwargs["tool_choice"] = (
-            "required" if "tool_choice" not in all_kwargs else all_kwargs["tool_choice"]
-        )
+        if self._is_function_call_model:
+            llm_kwargs["tool_choice"] = (
+                "required"
+                if "tool_choice" not in all_kwargs
+                else all_kwargs["tool_choice"]
+            )
         # by default structured prediction uses function calling to extract structured outputs
         # here we force tool_choice to be required
         return super().stream_structured_predict(*args, llm_kwargs=llm_kwargs, **kwargs)
@@ -496,9 +507,12 @@ class Gemini(FunctionCallingLLM):
         llm_kwargs = llm_kwargs or {}
         all_kwargs = {**llm_kwargs, **kwargs}
 
-        llm_kwargs["tool_choice"] = (
-            "required" if "tool_choice" not in all_kwargs else all_kwargs["tool_choice"]
-        )
+        if self._is_function_call_model:
+            llm_kwargs["tool_choice"] = (
+                "required"
+                if "tool_choice" not in all_kwargs
+                else all_kwargs["tool_choice"]
+            )
         # by default structured prediction uses function calling to extract structured outputs
         # here we force tool_choice to be required
         return await super().astream_structured_predict(
diff --git a/llama-index-integrations/llms/llama-index-llms-gemini/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-gemini/pyproject.toml
index 15488244ef..873abca591 100644
--- a/llama-index-integrations/llms/llama-index-llms-gemini/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-gemini/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-gemini"
 readme = "README.md"
-version = "0.4.10"
+version = "0.4.11"
 
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
diff --git a/llama-index-integrations/llms/llama-index-llms-gemini/tests/test_llms_gemini.py b/llama-index-integrations/llms/llama-index-llms-gemini/tests/test_llms_gemini.py
index 5ea3aac33d..f9f1f7ca0d 100644
--- a/llama-index-integrations/llms/llama-index-llms-gemini/tests/test_llms_gemini.py
+++ b/llama-index-integrations/llms/llama-index-llms-gemini/tests/test_llms_gemini.py
@@ -11,7 +11,7 @@ from llama_index.core.prompts.base import ChatPromptTemplate
 from llama_index.core.tools.function_tool import FunctionTool
 from llama_index.llms.gemini import Gemini
 from llama_index.llms.gemini.utils import chat_message_to_gemini
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 
 def test_embedding_class() -> None:
@@ -128,3 +128,91 @@ def test_is_function_calling_model() -> None:
     manual_override._is_function_call_model = False
     assert not manual_override._is_function_call_model
     assert not manual_override.metadata.is_function_calling_model
+
+
+@pytest.mark.skipif(
+    os.environ.get("GOOGLE_API_KEY") is None, reason="GOOGLE_API_KEY not set"
+)
+def test_structure_gen_without_function_call() -> None:
+    class Test(BaseModel):
+        test: str
+
+    gemini_flash = Gemini(
+        model="models/gemini-2.0-flash-001",
+        api_key=os.environ["GOOGLE_API_KEY"],
+    )
+    gemini_flash._is_function_call_model = False
+    output = gemini_flash.as_structured_llm(Test).complete("test")
+    assert output.raw.test
+
+
+@pytest.mark.skipif(
+    os.environ.get("GOOGLE_API_KEY") is None, reason="GOOGLE_API_KEY not set"
+)
+def test_function_call_deeply_nested_structured_generation() -> None:
+    class Column(BaseModel):
+        name: str = Field(description="Column field")
+        data_type: str = Field(description="Data type field")
+
+    class Table(BaseModel):
+        name: str = Field(description="Table name field")
+        columns: list[Column] = Field(description="List of random Column objects")
+
+    class Schema(BaseModel):
+        schema_name: str = Field(description="Schema name")
+        columns: list[Table] = Field(description="List of random Table objects")
+
+    prompt = ChatPromptTemplate.from_messages(
+        [ChatMessage(role="user", content="Generate a simple database structure")]
+    )
+
+    gemini_flash = Gemini(
+        model="models/gemini-2.0-flash-001",
+        api_key=os.environ["GOOGLE_API_KEY"],
+    )
+    prompt = ChatPromptTemplate.from_messages(
+        [ChatMessage(role="user", content="Generate a simple database structure")]
+    )
+
+    gemini_flash._is_function_call_model = (
+        True  # this is the default, but let's be explicit
+    )
+    schema = gemini_flash.structured_predict(output_cls=Schema, prompt=prompt)
+    assert schema.columns
+    assert schema.columns[0].columns
+    assert schema.columns[0].columns[0].name
+
+
+# this is the same test as above, but with function call disabled
+@pytest.mark.skipif(
+    os.environ.get("GOOGLE_API_KEY") is None, reason="GOOGLE_API_KEY not set"
+)
+def test_deeply_nested_structured_generation() -> None:
+    class Column(BaseModel):
+        name: str = Field(description="Column field")
+        data_type: str = Field(description="Data type field")
+
+    class Table(BaseModel):
+        name: str = Field(description="Table name field")
+        columns: list[Column] = Field(description="List of random Column objects")
+
+    class Schema(BaseModel):
+        schema_name: str = Field(description="Schema name")
+        columns: list[Table] = Field(description="List of random Table objects")
+
+    prompt = ChatPromptTemplate.from_messages(
+        [ChatMessage(role="user", content="Generate a simple database structure")]
+    )
+
+    gemini_flash = Gemini(
+        model="models/gemini-2.0-flash-001",
+        api_key=os.environ["GOOGLE_API_KEY"],
+    )
+    prompt = ChatPromptTemplate.from_messages(
+        [ChatMessage(role="user", content="Generate a simple database structure")]
+    )
+    gemini_flash._is_function_call_model = False
+    schema = gemini_flash.structured_predict(output_cls=Schema, prompt=prompt)
+    assert schema.columns
+    assert schema.columns[0].columns
+    assert schema.columns[0].columns[0].name
-- 
GitLab