Skip to content
Snippets Groups Projects
Unverified Commit 27f76910 authored by Yuzhong Zhang's avatar Yuzhong Zhang Committed by GitHub
Browse files

Fix api_base undefined bug in Gemini embeddings (#11393)


* Support Gemini "transport" configuration

Added Gemini transportation method configuration support.

* Sync updates in multi_modal_llms\gemini

* Updated Dashscope qwen llm defaults

Setting qwen default num_outputs and temperature

* cr

* support gemini embedding configuration

support configuring api_base, api_key, transport method

* fix gptrepo data connector encoding issue

reading a file in default encoding(GBK) will cause error characters problem. Added encoding configuration

* sync latest repo

* sync latest repo

* cr

* cr

* Fix api_base undefined bug in Gemini embeddings

* add comments

* fix linter test

* sync fix in integrations/embeddings

* fix unit test

---------

Co-authored-by: default avatarHaotian Zhang <socool.king@gmail.com>
parent e6e9abd9
Branches
Tags
No related merge requests found
"""Gemini embeddings file."""
import os
from typing import Any, List, Optional
import google.generativeai as gemini
......@@ -19,6 +20,8 @@ class GeminiEmbedding(BaseEmbedding):
Defaults to "models/embedding-001".
api_key (Optional[str]): API key to access the model. Defaults to None.
api_base (Optional[str]): API base to access the model. Defaults to Official Base.
transport (Optional[str]): Transport to access the model.
"""
_model: Any = PrivateAttr()
......@@ -36,12 +39,24 @@ class GeminiEmbedding(BaseEmbedding):
model_name: str = "models/embedding-001",
task_type: Optional[str] = "retrieval_document",
api_key: Optional[str] = None,
api_base: Optional[str] = None,
transport: Optional[str] = None,
title: Optional[str] = None,
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
callback_manager: Optional[CallbackManager] = None,
**kwargs: Any,
):
gemini.configure(api_key=api_key)
# API keys are optional. The API can be authorised via OAuth (detected
# environmentally) or by the GOOGLE_API_KEY environment variable.
config_params: Dict[str, Any] = {
"api_key": api_key or os.getenv("GOOGLE_API_KEY"),
}
if api_base:
config_params["client_options"] = {"api_endpoint": api_base}
if transport:
config_params["transport"] = transport
# transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`].
gemini.configure(**config_params)
self._model = gemini
super().__init__(
......
......@@ -18,6 +18,8 @@ class GeminiEmbedding(BaseEmbedding):
Defaults to "models/embedding-001".
api_key (Optional[str]): API key to access the model. Defaults to None.
api_base (Optional[str]): API base to access the model. Defaults to Official Base.
transport (Optional[str]): Transport to access the model.
"""
_model: Any = PrivateAttr()
......@@ -35,6 +37,8 @@ class GeminiEmbedding(BaseEmbedding):
model_name: str = "models/embedding-001",
task_type: Optional[str] = "retrieval_document",
api_key: Optional[str] = None,
api_base: Optional[str] = None,
transport: Optional[str] = None,
title: Optional[str] = None,
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
callback_manager: Optional[CallbackManager] = None,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment