diff --git a/docs/docs/examples/llm/google_genai.ipynb b/docs/docs/examples/llm/google_genai.ipynb
index d4ec209e149ce975689bb8be8502cd421980a3a6..250a7becea08a39e7c2081d24fc4b729e637cadc 100644
--- a/docs/docs/examples/llm/google_genai.ipynb
+++ b/docs/docs/examples/llm/google_genai.ipynb
@@ -377,6 +377,9 @@
     "llm = GoogleGenAI(\n",
     "    model=\"gemini-2.0-flash\",\n",
     "    vertexai_config={\"project\": \"your-project-id\", \"location\": \"us-central1\"},\n",
+    "    # you should set the context window to the max input tokens for the model\n",
+    "    context_window=200000,\n",
+    "    max_tokens=512,\n",
     ")"
    ]
   },
diff --git a/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/base.py b/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/base.py
index df5d03e5b98c5a8c053a8e6b922973e0197a3253..8db162dc99109949ba27ea607618c2f872138c09 100644
--- a/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/base.py
@@ -33,7 +33,7 @@ from llama_index.core.base.llms.types import (
 )
 from llama_index.core.bridge.pydantic import BaseModel, Field, PrivateAttr
 from llama_index.core.callbacks import CallbackManager
-from llama_index.core.constants import DEFAULT_TEMPERATURE
+from llama_index.core.constants import DEFAULT_TEMPERATURE, DEFAULT_NUM_OUTPUTS
 from llama_index.core.llms.callbacks import llm_chat_callback, llm_completion_callback
 from llama_index.core.llms.function_calling import FunctionCallingLLM
 from llama_index.core.llms.llm import ToolSelection
@@ -88,8 +88,9 @@ class GoogleGenAI(FunctionCallingLLM):
         ge=0.0,
         le=2.0,
     )
-    generate_kwargs: dict = Field(
-        default_factory=dict, description="Kwargs for generation."
+    context_window: Optional[int] = Field(
+        default=None,
+        description="The context window of the model. If not provided, the default context window 200000 will be used.",
     )
     is_function_calling_model: bool = Field(
         default=True, description="Whether the model is a function calling model."
@@ -106,13 +107,14 @@ class GoogleGenAI(FunctionCallingLLM):
         api_key: Optional[str] = None,
         temperature: float = DEFAULT_TEMPERATURE,
         max_tokens: Optional[int] = None,
+        context_window: Optional[int] = None,
         vertexai_config: Optional[VertexAIConfig] = None,
         http_options: Optional[types.HttpOptions] = None,
         debug_config: Optional[google.genai.client.DebugConfig] = None,
         generation_config: Optional[types.GenerateContentConfig] = None,
         callback_manager: Optional[CallbackManager] = None,
         is_function_calling_model: bool = True,
-        **generate_kwargs: Any,
+        **kwargs: Any,
     ):
         # API keys are optional. The API can be authorised via OAuth (detected
         # environmentally) or by the GOOGLE_API_KEY environment variable.
@@ -149,18 +151,14 @@ class GoogleGenAI(FunctionCallingLLM):
         client = google.genai.Client(**config_params)
         model_meta = client.models.get(model=model)
 
-        if not max_tokens:
-            max_tokens = model_meta.output_token_limit
-        else:
-            max_tokens = min(max_tokens, model_meta.output_token_limit)
 
         super().__init__(
             model=model,
             temperature=temperature,
-            max_tokens=max_tokens,
-            generate_kwargs=generate_kwargs,
+            context_window=context_window,
             callback_manager=callback_manager,
             is_function_calling_model=is_function_calling_model,
+            **kwargs,
         )
 
         self.model = model
@@ -169,9 +167,16 @@ class GoogleGenAI(FunctionCallingLLM):
         # store this as a dict and not as a pydantic model so we can more easily
         # merge it later
         self._generation_config = (
-            generation_config.model_dump() if generation_config else {}
+            generation_config.model_dump()
+            if generation_config
+            else types.GenerateContentConfig(
+                temperature=temperature,
+                max_output_tokens=max_tokens,
+            ).model_dump()
+        )
+        self._max_tokens = (
+            max_tokens or model_meta.output_token_limit or DEFAULT_NUM_OUTPUTS
         )
-        self._max_tokens = max_tokens
 
     @classmethod
     def class_name(cls) -> str:
@@ -179,7 +184,12 @@ class GoogleGenAI(FunctionCallingLLM):
 
     @property
     def metadata(self) -> LLMMetadata:
-        total_tokens = (self._model_meta.input_token_limit or 0) + self._max_tokens
+        if self.context_window is None:
+            base = self._model_meta.input_token_limit or 200000
+            total_tokens = base + self._max_tokens
+        else:
+            total_tokens = self.context_window
+
         return LLMMetadata(
             context_window=total_tokens,
             num_output=self._max_tokens,
diff --git a/llama-index-integrations/llms/llama-index-llms-google-genai/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-google-genai/pyproject.toml
index a0a0b6406ddf35661b037c1193f6b9b16506e4c1..782bcf74ef2fccb0734e75a96b8ed6eba0506357 100644
--- a/llama-index-integrations/llms/llama-index-llms-google-genai/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-google-genai/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-google-genai"
 readme = "README.md"
-version = "0.1.2"
+version = "0.1.3"
 
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
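A minimal usage sketch of the change above, assuming this patch is applied; the model name, project id, and token budgets mirror the notebook example and are placeholders. With an explicit context_window, metadata.context_window reports that value directly instead of being derived from the model's input token limit.

from llama_index.llms.google_genai import GoogleGenAI

# Placeholder project/model values; requires valid Google GenAI / Vertex AI credentials.
llm = GoogleGenAI(
    model="gemini-2.0-flash",
    vertexai_config={"project": "your-project-id", "location": "us-central1"},
    context_window=200000,  # reported as-is by llm.metadata.context_window
    max_tokens=512,  # becomes num_output and the default max_output_tokens
)

print(llm.metadata.context_window)  # 200000
print(llm.metadata.num_output)  # 512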