From e9aa2a8569b1bd0982b142a16d86b15df47dd639 Mon Sep 17 00:00:00 2001
From: Logan <logan.markewich@live.com>
Date: Wed, 12 Mar 2025 14:30:44 +0000
Subject: [PATCH] fix null model meta references (#18109)

* fix null model meta references

* notebook
---
 docs/docs/examples/llm/google_genai.ipynb    |  3 ++
 .../llama_index/llms/google_genai/base.py    | 36 ++++++++++++-------
 .../pyproject.toml                           |  2 +-
 3 files changed, 27 insertions(+), 14 deletions(-)

diff --git a/docs/docs/examples/llm/google_genai.ipynb b/docs/docs/examples/llm/google_genai.ipynb
index d4ec209e14..250a7becea 100644
--- a/docs/docs/examples/llm/google_genai.ipynb
+++ b/docs/docs/examples/llm/google_genai.ipynb
@@ -377,6 +377,9 @@
     "llm = GoogleGenAI(\n",
     "    model=\"gemini-2.0-flash\",\n",
     "    vertexai_config={\"project\": \"your-project-id\", \"location\": \"us-central1\"},\n",
+    "    # you should set the context window to the max input tokens for the model\n",
+    "    context_window=200000,\n",
+    "    max_tokens=512,\n",
     ")"
    ]
   },
diff --git a/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/base.py b/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/base.py
index df5d03e5b9..8db162dc99 100644
--- a/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/base.py
@@ -33,7 +33,7 @@ from llama_index.core.base.llms.types import (
 )
 from llama_index.core.bridge.pydantic import BaseModel, Field, PrivateAttr
 from llama_index.core.callbacks import CallbackManager
-from llama_index.core.constants import DEFAULT_TEMPERATURE
+from llama_index.core.constants import DEFAULT_TEMPERATURE, DEFAULT_NUM_OUTPUTS
 from llama_index.core.llms.callbacks import llm_chat_callback, llm_completion_callback
 from llama_index.core.llms.function_calling import FunctionCallingLLM
 from llama_index.core.llms.llm import ToolSelection
@@ -88,8 +88,9 @@ class GoogleGenAI(FunctionCallingLLM):
         ge=0.0,
         le=2.0,
     )
-    generate_kwargs: dict = Field(
-        default_factory=dict, description="Kwargs for generation."
+    context_window: Optional[int] = Field(
+        default=None,
+        description="The context window of the model. If not provided, the default context window 200000 will be used.",
     )
     is_function_calling_model: bool = Field(
         default=True, description="Whether the model is a function calling model."
@@ -106,13 +107,14 @@ class GoogleGenAI(FunctionCallingLLM):
         api_key: Optional[str] = None,
         temperature: float = DEFAULT_TEMPERATURE,
         max_tokens: Optional[int] = None,
+        context_window: Optional[int] = None,
         vertexai_config: Optional[VertexAIConfig] = None,
         http_options: Optional[types.HttpOptions] = None,
         debug_config: Optional[google.genai.client.DebugConfig] = None,
         generation_config: Optional[types.GenerateContentConfig] = None,
         callback_manager: Optional[CallbackManager] = None,
         is_function_calling_model: bool = True,
-        **generate_kwargs: Any,
+        **kwargs: Any,
     ):
         # API keys are optional. The API can be authorised via OAuth (detected
         # environmentally) or by the GOOGLE_API_KEY environment variable.
@@ -149,18 +151,14 @@ class GoogleGenAI(FunctionCallingLLM):
         client = google.genai.Client(**config_params)
         model_meta = client.models.get(model=model)

-        if not max_tokens:
-            max_tokens = model_meta.output_token_limit
-        else:
-            max_tokens = min(max_tokens, model_meta.output_token_limit)

         super().__init__(
             model=model,
             temperature=temperature,
-            max_tokens=max_tokens,
-            generate_kwargs=generate_kwargs,
+            context_window=context_window,
             callback_manager=callback_manager,
             is_function_calling_model=is_function_calling_model,
+            **kwargs,
         )

         self.model = model
@@ -169,9 +167,16 @@ class GoogleGenAI(FunctionCallingLLM):
         # store this as a dict and not as a pydantic model so we can more easily
         # merge it later
         self._generation_config = (
-            generation_config.model_dump() if generation_config else {}
+            generation_config.model_dump()
+            if generation_config
+            else types.GenerateContentConfig(
+                temperature=temperature,
+                max_output_tokens=max_tokens,
+            ).model_dump()
+        )
+        self._max_tokens = (
+            max_tokens or model_meta.output_token_limit or DEFAULT_NUM_OUTPUTS
         )
-        self._max_tokens = max_tokens

     @classmethod
     def class_name(cls) -> str:
@@ -179,7 +184,12 @@

     @property
     def metadata(self) -> LLMMetadata:
-        total_tokens = (self._model_meta.input_token_limit or 0) + self._max_tokens
+        if self.context_window is None:
+            base = self._model_meta.input_token_limit or 200000
+            total_tokens = base + self._max_tokens
+        else:
+            total_tokens = self.context_window
+
         return LLMMetadata(
             context_window=total_tokens,
             num_output=self._max_tokens,
diff --git a/llama-index-integrations/llms/llama-index-llms-google-genai/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-google-genai/pyproject.toml
index a0a0b6406d..782bcf74ef 100644
--- a/llama-index-integrations/llms/llama-index-llms-google-genai/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-google-genai/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-google-genai"
 readme = "README.md"
-version = "0.1.2"
+version = "0.1.3"

 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
--
GitLab
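
Editor's note: the sketch below shows how the patched constructor is intended
to be called. It is a minimal usage sketch, not part of the commit: it assumes
llama-index-llms-google-genai 0.1.3 is installed and that a GOOGLE_API_KEY
environment variable (or Vertex AI OAuth credentials) is available, since the
constructor still fetches model metadata. The model name and token values
mirror the notebook example in the diff and are illustrative.

    from llama_index.llms.google_genai import GoogleGenAI

    # With this patch, context_window overrides the context size that would
    # otherwise be derived from model metadata (which can be null for some
    # models), and max_tokens falls back to output_token_limit and then
    # DEFAULT_NUM_OUTPUTS instead of hitting a null limit.
    llm = GoogleGenAI(
        model="gemini-2.0-flash",
        context_window=200000,  # set to the model's max input tokens
        max_tokens=512,
    )

    # metadata.context_window is exactly 200000 because context_window was
    # set explicitly; when it is omitted, the patched metadata property uses
    # input_token_limit (defaulting to 200000) plus the resolved max_tokens.
    print(llm.metadata.context_window, llm.metadata.num_output)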