From 7ef27694ce26321987e0ed1922b8997ad8c51df5 Mon Sep 17 00:00:00 2001
From: Alex Sherstinsky <alexsherstinsky@users.noreply.github.com>
Date: Thu, 11 Apr 2024 09:20:19 -0700
Subject: [PATCH] [BUGFIX] Update LlamaIndex-Predibase Integration (#12736)

---
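A minimal usage sketch of the updated integration, mirroring the docstring
example below (assumes your Predibase API key is configured, e.g. via the
environment; the adapter ID is the one used in the docs notebook, and any
HuggingFace-hosted adapter compatible with the base model should work):

    from llama_index.llms.predibase import PredibaseLLM

    # "mistral-7b" is a Predibase "serverless" base model ID;
    # adapter_id optionally names a HuggingFace-hosted fine-tuned adapter.
    llm = PredibaseLLM(
        model_name="mistral-7b",
        adapter_id="predibase/e2e_nlg",  # optional; must match the base model
        temperature=0.3,
        max_new_tokens=512,
    )
    response = llm.complete("Hello World!")
    print(str(response))
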
 docs/docs/examples/llm/predibase.ipynb        | 12 ++-
 .../llama_index/llms/predibase/base.py        | 78 ++++++++++++++++---
 .../llama-index-llms-predibase/pyproject.toml |  2 +-
 3 files changed, 77 insertions(+), 15 deletions(-)

diff --git a/docs/docs/examples/llm/predibase.ipynb b/docs/docs/examples/llm/predibase.ipynb
index dffa20a060..25902a1559 100644
--- a/docs/docs/examples/llm/predibase.ipynb
+++ b/docs/docs/examples/llm/predibase.ipynb
@@ -77,9 +77,14 @@
    "outputs": [],
    "source": [
     "llm = PredibaseLLM(\n",
-    "    model_name=\"llama-2-13b\", temperature=0.3, max_new_tokens=512\n",
+    "    model_name=\"mistral-7b\",\n",
+    "    adapter_id=\"predibase/e2e_nlg\",  # adapter_id is optional\n",
+    "    temperature=0.3,\n",
+    "    max_new_tokens=512,\n",
     ")\n",
-    "# You can query any HuggingFace or fine-tuned LLM that's hosted on Predibase"
+    "# The `model_name` parameter is the Predibase \"serverless\" base_model ID\n",
+    "# (see https://docs.predibase.com/user-guide/inference/models for the catalog).\n",
+    "# You can also optionally specify a fine-tuned adapter that's hosted on HuggingFace"
    ]
   },
   {
@@ -167,7 +172,8 @@
    "outputs": [],
    "source": [
     "llm = PredibaseLLM(\n",
-    "    model_name=\"llama-2-13b\",\n",
+    "    model_name=\"mistral-7b\",\n",
+    "    adapter_id=\"predibase/e2e_nlg\",  # adapter_id is optional\n",
     "    temperature=0.3,\n",
     "    max_new_tokens=400,\n",
     "    context_window=1024,\n",
diff --git a/llama-index-integrations/llms/llama-index-llms-predibase/llama_index/llms/predibase/base.py b/llama-index-integrations/llms/llama-index-llms-predibase/llama_index/llms/predibase/base.py
index d16576dbc5..ff8d7b8f0d 100644
--- a/llama-index-integrations/llms/llama-index-llms-predibase/llama_index/llms/predibase/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-predibase/llama_index/llms/predibase/base.py
@@ -1,5 +1,6 @@
+import copy
 import os
-from typing import Any, Callable, Optional, Sequence
+from typing import Any, Callable, Dict, Optional, Sequence, Union
 
 from llama_index.core.base.llms.types import (
     ChatMessage,
@@ -22,6 +23,16 @@ from llama_index.core.types import BaseOutputParser, PydanticProgramMode
 class PredibaseLLM(CustomLLM):
     """Predibase LLM.
 
+    To use, you should have the ``predibase`` Python package installed
+    and a Predibase API key.
+
+    The `model_name` parameter is the Predibase "serverless" base_model ID
+    (see https://docs.predibase.com/user-guide/inference/models for the catalog).
+
+    The optional `adapter_id` parameter is the HuggingFace ID of a fine-tuned
+    LLM adapter whose base model is the one given by `model_name`; the adapter
+    must be compatible with that base model, otherwise an error is raised.
+
     Examples:
         `pip install llama-index-llms-predibase`
 
@@ -33,15 +44,22 @@ class PredibaseLLM(CustomLLM):
         from llama_index.llms.predibase import PredibaseLLM
 
         llm = PredibaseLLM(
-            model_name="llama-2-13b", temperature=0.3, max_new_tokens=512
+            model_name="mistral-7b",
+            adapter_id="my-repo/my-adapter",  # optional parameter
+            temperature=0.3,
+            max_new_tokens=512,
         )
         response = llm.complete("Hello World!")
         print(str(response))
         ```
     """
 
-    model_name: str = Field(description="The Predibase model to use.")
+    model_name: str = Field(description="The Predibase base model to use.")
     predibase_api_key: str = Field(description="The Predibase API key to use.")
+    adapter_id: Optional[str] = Field(
+        default=None,
+        description="The optional HuggingFace ID of a fine-tuned adapter to use.",
+    )
     max_new_tokens: int = Field(
         default=DEFAULT_NUM_OUTPUTS,
         description="The number of tokens to generate.",
@@ -65,6 +83,7 @@ class PredibaseLLM(CustomLLM):
         self,
         model_name: str,
         predibase_api_key: Optional[str] = None,
+        adapter_id: Optional[str] = None,
         max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
         temperature: float = DEFAULT_TEMPERATURE,
         context_window: int = DEFAULT_CONTEXT_WINDOW,
@@ -82,10 +101,9 @@ class PredibaseLLM(CustomLLM):
         )
         assert predibase_api_key is not None
 
-        self._client = self.initialize_client(predibase_api_key)
-
         super().__init__(
             model_name=model_name,
+            adapter_id=adapter_id,
             predibase_api_key=predibase_api_key,
             max_new_tokens=max_new_tokens,
             temperature=temperature,
@@ -98,12 +116,21 @@ class PredibaseLLM(CustomLLM):
             output_parser=output_parser,
         )
 
+        self._client = self.initialize_client(predibase_api_key)
+
     @staticmethod
     def initialize_client(predibase_api_key: str) -> Any:
         try:
             from predibase import PredibaseClient
-
-            return PredibaseClient(token=predibase_api_key)
+            from predibase.pql import get_session
+            from predibase.pql.api import Session
+
+            session: Session = get_session(
+                token=predibase_api_key,
+                gateway="https://api.app.predibase.com/v1",
+                serving_endpoint="serving.app.predibase.com",
+            )
+            return PredibaseClient(session=session)
         except ImportError as e:
             raise ImportError(
                 "Could not import Predibase Python package. "
@@ -129,11 +156,40 @@ class PredibaseLLM(CustomLLM):
     def complete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
     ) -> "CompletionResponse":
-        llm = self._client.LLM(f"pb://deployments/{self.model_name}")
-        results = llm.prompt(
-            prompt, max_new_tokens=self.max_new_tokens, temperature=self.temperature
+        from predibase.resource.llm.interface import (
+            HuggingFaceLLM,
+            LLMDeployment,
+        )
+        from predibase.resource.llm.response import GeneratedResponse
+
+        base_llm_deployment: LLMDeployment = self._client.LLM(
+            uri=f"pb://deployments/{self.model_name}"
         )
-        return CompletionResponse(text=results.response)
+
+        options: Dict[str, Union[str, float]] = copy.deepcopy(kwargs)
+        options.update(
+            {
+                "max_new_tokens": self.max_new_tokens,
+                "temperature": self.temperature,
+            }
+        )
+
+        result: GeneratedResponse
+        if self.adapter_id:
+            adapter_model: HuggingFaceLLM = self._client.LLM(
+                uri=f"hf://{self.adapter_id}"
+            )
+            result = base_llm_deployment.with_adapter(model=adapter_model).generate(
+                prompt=prompt,
+                options=options,
+            )
+        else:
+            result = base_llm_deployment.generate(
+                prompt=prompt,
+                options=options,
+            )
+
+        return CompletionResponse(text=result.response)
 
     @llm_completion_callback()
     def stream_complete(
diff --git a/llama-index-integrations/llms/llama-index-llms-predibase/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-predibase/pyproject.toml
index 43015e6938..422215ee45 100644
--- a/llama-index-integrations/llms/llama-index-llms-predibase/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-predibase/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-predibase"
 readme = "README.md"
-version = "0.1.2"
+version = "0.1.3"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
-- 
GitLab