From 51fc6d0e83f59287d5cd1cff1c9a80eb341e3383 Mon Sep 17 00:00:00 2001 From: Danipulok <45077699+Danipulok@users.noreply.github.com> Date: Tue, 5 Dec 2023 00:49:19 +0200 Subject: [PATCH] Improve `Vertex` hints (#9296) --- docs/examples/llm/vertex.ipynb | 31 +++++++++++++++++++++++++++++-- llama_index/llms/vertex.py | 16 ++++++++-------- llama_index/llms/vertex_utils.py | 2 +- 3 files changed, 38 insertions(+), 11 deletions(-) diff --git a/docs/examples/llm/vertex.ipynb b/docs/examples/llm/vertex.ipynb index 93e722cbf..6384f7ba0 100644 --- a/docs/examples/llm/vertex.ipynb +++ b/docs/examples/llm/vertex.ipynb @@ -9,7 +9,34 @@ "## Installing Vertex AI \n", "To Install Vertex AI you need to follow the following steps\n", "* Install Vertex Cloud SDK (https://googleapis.dev/python/aiplatform/latest/index.html)\n", - "* Setup your Default Project , credentials , region\n", + "* Setup your Default Project, credentials, region\n", + "# Basic auth example for service account" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d42f4996210bdc7", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.llms.vertex import Vertex\n", + "from google.oauth2 import service_account\n", + "\n", + "filename = \"vertex-407108-37495ce6c303.json\"\n", + "credentials: service_account.Credentials = (\n", + " service_account.Credentials.from_service_account_file(filename)\n", + ")\n", + "Vertex(\n", + " model=\"text-bison\", project=credentials.project_id, credentials=credentials\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "119bbfb7d84a593d", + "metadata": {}, + "source": [ "## Basic Usage\n", "a Basic call to the text-bison model" ] @@ -33,7 +60,7 @@ ], "source": [ "from llama_index.llms.vertex import Vertex\n", - "from llama_index.llms.base import ChatMessage, MessageRole, CompletionResponse\n", + "from llama_index.llms.base import ChatMessage, MessageRole\n", "\n", "llm = Vertex(model=\"text-bison\", temperature=0, additional_kwargs={})\n", "llm.complete(\"Hello this is a sample text\").text" diff --git a/llama_index/llms/vertex.py b/llama_index/llms/vertex.py index 7ef8c2e33..e74964c3d 100644 --- a/llama_index/llms/vertex.py +++ b/llama_index/llms/vertex.py @@ -45,14 +45,14 @@ class Vertex(LLM): default=False, description="Flag to determine if current model is a Code Model" ) _client: Any = PrivateAttr() - _chatclient: Any = PrivateAttr() + _chat_client: Any = PrivateAttr() def __init__( self, model: str = "text-bison", project: Optional[str] = None, location: Optional[str] = None, - credential: Optional[str] = None, + credentials: Optional[Any] = None, examples: Optional[Sequence[ChatMessage]] = None, temperature: float = 0.1, max_tokens: int = 512, @@ -61,7 +61,7 @@ class Vertex(LLM): additional_kwargs: Optional[Dict[str, Any]] = None, callback_manager: Optional[CallbackManager] = None, ) -> None: - init_vertexai(project=project, location=location, credentials=credential) + init_vertexai(project=project, location=location, credentials=credentials) additional_kwargs = additional_kwargs or {} callback_manager = callback_manager or CallbackManager([]) @@ -69,11 +69,11 @@ class Vertex(LLM): if model in CHAT_MODELS: from vertexai.language_models import ChatModel - self._chatclient = ChatModel.from_pretrained(model) + self._chat_client = ChatModel.from_pretrained(model) elif model in CODE_CHAT_MODELS: from vertexai.language_models import CodeChatModel - self._chatclient = CodeChatModel.from_pretrained(model) + self._chat_client = CodeChatModel.from_pretrained(model) iscode = True elif model in CODE_MODELS: from vertexai.language_models import CodeGenerationModel @@ -148,7 +148,7 @@ class Vertex(LLM): ) generation = completion_with_retry( - client=self._chatclient, + client=self._chat_client, prompt=question, chat=True, stream=False, @@ -195,7 +195,7 @@ class Vertex(LLM): ) response = completion_with_retry( - client=self._chatclient, + client=self._chat_client, prompt=question, chat=True, stream=True, @@ -267,7 +267,7 @@ class Vertex(LLM): ) ) generation = await acompletion_with_retry( - client=self._chatclient, + client=self._chat_client, prompt=question, chat=True, params=chat_params, diff --git a/llama_index/llms/vertex_utils.py b/llama_index/llms/vertex_utils.py index 39103741d..7f5e8fba5 100644 --- a/llama_index/llms/vertex_utils.py +++ b/llama_index/llms/vertex_utils.py @@ -97,7 +97,7 @@ async def acompletion_with_retry( def init_vertexai( project: Optional[str] = None, location: Optional[str] = None, - credentials: Optional[str] = None, + credentials: Optional[Any] = None, ) -> None: """Init vertexai. -- GitLab