From 51fc6d0e83f59287d5cd1cff1c9a80eb341e3383 Mon Sep 17 00:00:00 2001
From: Danipulok <45077699+Danipulok@users.noreply.github.com>
Date: Tue, 5 Dec 2023 00:49:19 +0200
Subject: [PATCH] Improve `Vertex` hints (#9296)

---
 docs/examples/llm/vertex.ipynb   | 31 +++++++++++++++++++++++++++++--
 llama_index/llms/vertex.py       | 16 ++++++++--------
 llama_index/llms/vertex_utils.py |  2 +-
 3 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/docs/examples/llm/vertex.ipynb b/docs/examples/llm/vertex.ipynb
index 93e722cbf..6384f7ba0 100644
--- a/docs/examples/llm/vertex.ipynb
+++ b/docs/examples/llm/vertex.ipynb
@@ -9,7 +9,34 @@
     "## Installing Vertex AI \n",
     "To Install Vertex AI you need to follow the following steps\n",
     "* Install Vertex Cloud SDK (https://googleapis.dev/python/aiplatform/latest/index.html)\n",
-    "* Setup your Default Project , credentials , region\n",
+    "* Setup your Default Project, credentials, region\n",
+    "# Basic auth example for service account"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3d42f4996210bdc7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llama_index.llms.vertex import Vertex\n",
+    "from google.oauth2 import service_account\n",
+    "\n",
+    "filename = \"vertex-407108-37495ce6c303.json\"\n",
+    "credentials: service_account.Credentials = (\n",
+    "    service_account.Credentials.from_service_account_file(filename)\n",
+    ")\n",
+    "Vertex(\n",
+    "    model=\"text-bison\", project=credentials.project_id, credentials=credentials\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "119bbfb7d84a593d",
+   "metadata": {},
+   "source": [
     "## Basic Usage\n",
     "a Basic call to the text-bison model"
    ]
@@ -33,7 +60,7 @@
    ],
    "source": [
     "from llama_index.llms.vertex import Vertex\n",
-    "from llama_index.llms.base import ChatMessage, MessageRole, CompletionResponse\n",
+    "from llama_index.llms.base import ChatMessage, MessageRole\n",
     "\n",
     "llm = Vertex(model=\"text-bison\", temperature=0, additional_kwargs={})\n",
     "llm.complete(\"Hello this is a sample text\").text"
diff --git a/llama_index/llms/vertex.py b/llama_index/llms/vertex.py
index 7ef8c2e33..e74964c3d 100644
--- a/llama_index/llms/vertex.py
+++ b/llama_index/llms/vertex.py
@@ -45,14 +45,14 @@ class Vertex(LLM):
         default=False, description="Flag to determine if current model is a Code Model"
     )
     _client: Any = PrivateAttr()
-    _chatclient: Any = PrivateAttr()
+    _chat_client: Any = PrivateAttr()
 
     def __init__(
         self,
         model: str = "text-bison",
         project: Optional[str] = None,
         location: Optional[str] = None,
-        credential: Optional[str] = None,
+        credentials: Optional[Any] = None,
         examples: Optional[Sequence[ChatMessage]] = None,
         temperature: float = 0.1,
         max_tokens: int = 512,
@@ -61,7 +61,7 @@ class Vertex(LLM):
         additional_kwargs: Optional[Dict[str, Any]] = None,
         callback_manager: Optional[CallbackManager] = None,
     ) -> None:
-        init_vertexai(project=project, location=location, credentials=credential)
+        init_vertexai(project=project, location=location, credentials=credentials)
 
         additional_kwargs = additional_kwargs or {}
         callback_manager = callback_manager or CallbackManager([])
@@ -69,11 +69,11 @@ class Vertex(LLM):
         if model in CHAT_MODELS:
             from vertexai.language_models import ChatModel
 
-            self._chatclient = ChatModel.from_pretrained(model)
+            self._chat_client = ChatModel.from_pretrained(model)
         elif model in CODE_CHAT_MODELS:
             from vertexai.language_models import CodeChatModel
 
-            self._chatclient = CodeChatModel.from_pretrained(model)
+            self._chat_client = CodeChatModel.from_pretrained(model)
             iscode = True
         elif model in CODE_MODELS:
             from vertexai.language_models import CodeGenerationModel
@@ -148,7 +148,7 @@ class Vertex(LLM):
             )
 
         generation = completion_with_retry(
-            client=self._chatclient,
+            client=self._chat_client,
             prompt=question,
             chat=True,
             stream=False,
@@ -195,7 +195,7 @@ class Vertex(LLM):
             )
 
         response = completion_with_retry(
-            client=self._chatclient,
+            client=self._chat_client,
             prompt=question,
             chat=True,
             stream=True,
@@ -267,7 +267,7 @@ class Vertex(LLM):
                 )
             )
         generation = await acompletion_with_retry(
-            client=self._chatclient,
+            client=self._chat_client,
             prompt=question,
             chat=True,
             params=chat_params,
diff --git a/llama_index/llms/vertex_utils.py b/llama_index/llms/vertex_utils.py
index 39103741d..7f5e8fba5 100644
--- a/llama_index/llms/vertex_utils.py
+++ b/llama_index/llms/vertex_utils.py
@@ -97,7 +97,7 @@ async def acompletion_with_retry(
 def init_vertexai(
     project: Optional[str] = None,
     location: Optional[str] = None,
-    credentials: Optional[str] = None,
+    credentials: Optional[Any] = None,
 ) -> None:
     """Init vertexai.
 
-- 
GitLab