From cee23c9da888b6bbd09a05e544686aa3c7ded689 Mon Sep 17 00:00:00 2001
From: Alex Sherstinsky <alexsherstinsky@users.noreply.github.com>
Date: Tue, 23 Apr 2024 17:35:18 -0700
Subject: [PATCH] [FEATURE] Support both Predibase SDK-v1 and SDK-v2 (#13066)

---
 docs/docs/examples/llm/predibase.ipynb        |   6 +-
 .../llama_index/llms/predibase/base.py        | 204 +++++++++++++-----
 .../llama-index-llms-predibase/pyproject.toml |   2 +-
 3 files changed, 155 insertions(+), 57 deletions(-)

diff --git a/docs/docs/examples/llm/predibase.ipynb b/docs/docs/examples/llm/predibase.ipynb
index 26e3cad282..9725c27beb 100644
--- a/docs/docs/examples/llm/predibase.ipynb
+++ b/docs/docs/examples/llm/predibase.ipynb
@@ -79,6 +79,7 @@
     "# Predibase-hosted fine-tuned adapter example\n",
     "llm = PredibaseLLM(\n",
     "    model_name=\"mistral-7b\",\n",
+    "    predibase_sdk_version=None,  # optional parameter (defaults to the latest Predibase SDK version if omitted)\n",
     "    adapter_id=\"e2e_nlg\",  # adapter_id is optional\n",
     "    adapter_version=1,  # optional parameter (applies to Predibase only)\n",
     "    temperature=0.3,\n",
@@ -87,7 +88,7 @@
     "# The `model_name` parameter is the Predibase \"serverless\" base_model ID\n",
     "# (see https://docs.predibase.com/user-guide/inference/models for the catalog).\n",
     "# You can also optionally specify a fine-tuned adapter that's hosted on Predibase or HuggingFace\n",
-    "# In the case of Predibase-hosted adapters, you can also specify the adapter_version (assumed latest if omitted)"
+    "# In the case of Predibase-hosted adapters, you must also specify the adapter_version"
    ]
   },
   {
@@ -100,6 +101,7 @@
     "# HuggingFace-hosted fine-tuned adapter example\n",
     "llm = PredibaseLLM(\n",
     "    model_name=\"mistral-7b\",\n",
+    "    predibase_sdk_version=None,  # optional parameter (defaults to the latest Predibase SDK version if omitted)\n",
     "    adapter_id=\"predibase/e2e_nlg\",  # adapter_id is optional\n",
     "    temperature=0.3,\n",
     "    max_new_tokens=512,\n",
@@ -197,6 +199,7 @@
     "# Predibase-hosted fine-tuned adapter\n",
     "llm = PredibaseLLM(\n",
     "    model_name=\"mistral-7b\",\n",
+    "    predibase_sdk_version=None,  # optional parameter (defaults to the latest Predibase SDK version if omitted)\n",
     "    adapter_id=\"e2e_nlg\",  # adapter_id is optional\n",
     "    temperature=0.3,\n",
     "    context_window=1024,\n",
@@ -213,6 +216,7 @@
     "# HuggingFace-hosted fine-tuned adapter\n",
     "llm = PredibaseLLM(\n",
     "    model_name=\"mistral-7b\",\n",
+    "    predibase_sdk_version=None,  # optional parameter (defaults to the latest Predibase SDK version if omitted)\n",
     "    adapter_id=\"predibase/e2e_nlg\",  # adapter_id is optional\n",
     "    temperature=0.3,\n",
     "    context_window=1024,\n",
diff --git a/llama-index-integrations/llms/llama-index-llms-predibase/llama_index/llms/predibase/base.py b/llama-index-integrations/llms/llama-index-llms-predibase/llama_index/llms/predibase/base.py
index cfa1dd717e..8e28e7e5fd 100644
--- a/llama-index-integrations/llms/llama-index-llms-predibase/llama_index/llms/predibase/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-predibase/llama_index/llms/predibase/base.py
@@ -33,7 +33,7 @@ class PredibaseLLM(CustomLLM):
     of a fine-tuned LLM adapter, whose base model is the `model` parameter; the
     fine-tuned adapter must be compatible with its base model; otherwise, an
     error is raised.  If the fine-tuned adapter is hosted at Predibase,
-    `adapter_version` can be specified (omitting it gives the latest version).
+    `adapter_version` must be specified.
 
     Examples:
         `pip install llama-index-llms-predibase`
@@ -47,6 +47,7 @@ class PredibaseLLM(CustomLLM):
 
         llm = PredibaseLLM(
             model_name="mistral-7b",
+            predibase_sdk_version=None,  # optional parameter (defaults to the latest Predibase SDK version if omitted)
             adapter_id="my-adapter-id",  # optional parameter
             adapter_version=3,  # optional parameter (applies to Predibase only)
             temperature=0.3,
@@ -59,6 +60,10 @@ class PredibaseLLM(CustomLLM):
 
     model_name: str = Field(description="The Predibase base model to use.")
     predibase_api_key: str = Field(description="The Predibase API key to use.")
+    predibase_sdk_version: str = Field(
+        default=None,
+        description="The optional version (string) of the Predibase SDK (defaults to the latest if not specified).",
+    )
     adapter_id: str = Field(
         default=None,
         description="The optional Predibase ID or HuggingFace ID of a fine-tuned adapter to use.",
@@ -90,6 +95,7 @@ class PredibaseLLM(CustomLLM):
         self,
         model_name: str,
         predibase_api_key: Optional[str] = None,
+        predibase_sdk_version: Optional[str] = None,
         adapter_id: Optional[str] = None,
         adapter_version: Optional[int] = None,
         max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
@@ -111,9 +117,10 @@ class PredibaseLLM(CustomLLM):
 
         super().__init__(
             model_name=model_name,
+            predibase_api_key=predibase_api_key,
+            predibase_sdk_version=predibase_sdk_version,
             adapter_id=adapter_id,
             adapter_version=adapter_version,
-            predibase_api_key=predibase_api_key,
             max_new_tokens=max_new_tokens,
             temperature=temperature,
             context_window=context_window,
@@ -125,26 +132,28 @@ class PredibaseLLM(CustomLLM):
             output_parser=output_parser,
         )
 
-        self._client = self.initialize_client(predibase_api_key)
+        self._client: Union["PredibaseClient", "Predibase"] = self.initialize_client()
 
-    @staticmethod
-    def initialize_client(predibase_api_key: str) -> Any:
+    def initialize_client(
+        self,
+    ) -> Union["PredibaseClient", "Predibase"]:
         try:
-            from predibase import PredibaseClient
-            from predibase.pql import get_session
-            from predibase.pql.api import Session
-
-            session: Session = get_session(
-                token=predibase_api_key,
-                gateway="https://api.app.predibase.com/v1",
-                serving_endpoint="serving.app.predibase.com",
-            )
-            return PredibaseClient(session=session)
-        except ImportError as e:
-            raise ImportError(
-                "Could not import Predibase Python package. "
-                "Please install it with `pip install predibase`."
-            ) from e
+            if self._is_deprecated_sdk_version():
+                from predibase import PredibaseClient
+                from predibase.pql import get_session
+                from predibase.pql.api import Session
+
+                session: Session = get_session(
+                    token=self.predibase_api_key,
+                    gateway="https://api.app.predibase.com/v1",
+                    serving_endpoint="serving.app.predibase.com",
+                )
+                return PredibaseClient(session=session)
+
+            from predibase import Predibase
+
+            os.environ["PREDIBASE_GATEWAY"] = "https://api.app.predibase.com"
+            return Predibase(api_token=self.predibase_api_key)
         except ValueError as e:
             raise ValueError("Your API key is not correct. Please try again") from e
 
@@ -165,18 +174,6 @@ class PredibaseLLM(CustomLLM):
     def complete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
     ) -> "CompletionResponse":
-        from predibase.pql.api import ServerResponseError
-        from predibase.resource.llm.interface import (
-            HuggingFaceLLM,
-            LLMDeployment,
-        )
-        from predibase.resource.llm.response import GeneratedResponse
-        from predibase.resource.model import Model
-
-        base_llm_deployment: LLMDeployment = self._client.LLM(
-            uri=f"pb://deployments/{self.model_name}"
-        )
-
         options: Dict[str, Union[str, float]] = copy.deepcopy(kwargs)
         options.update(
             {
@@ -185,36 +182,133 @@ class PredibaseLLM(CustomLLM):
             }
         )
 
-        result: GeneratedResponse
-        if self.adapter_id:
-            """
-            Attempt to retrieve the fine-tuned adapter from a Predibase repository.
-            If absent, then load the fine-tuned adapter from a HuggingFace repository.
-            """
-            adapter_model: Union[Model, HuggingFaceLLM]
-            try:
-                adapter_model = self._client.get_model(
-                    name=self.adapter_id,
-                    version=self.adapter_version,
-                    model_id=None,
-                )
-            except ServerResponseError:
-                # Predibase does not recognize the adapter ID (query HuggingFace).
-                adapter_model = self._client.LLM(uri=f"hf://{self.adapter_id}")
-            result = base_llm_deployment.with_adapter(model=adapter_model).generate(
-                prompt=prompt,
-                options=options,
+        response_text: str
+
+        if self._is_deprecated_sdk_version():
+            from predibase.pql.api import ServerResponseError
+            from predibase.resource.llm.interface import (
+                HuggingFaceLLM,
+                LLMDeployment,
+            )
+            from predibase.resource.llm.response import GeneratedResponse
+            from predibase.resource.model import Model
+
+            base_llm_deployment: LLMDeployment = self._client.LLM(
+                uri=f"pb://deployments/{self.model_name}"
             )
+
+            result: GeneratedResponse
+            if self.adapter_id:
+                """
+                Attempt to retrieve the fine-tuned adapter from a Predibase repository.
+                If absent, then load the fine-tuned adapter from a HuggingFace repository.
+                """
+                adapter_model: Union[Model, HuggingFaceLLM]
+                try:
+                    adapter_model = self._client.get_model(
+                        name=self.adapter_id,
+                        version=self.adapter_version,
+                        model_id=None,
+                    )
+                except ServerResponseError:
+                    # Predibase does not recognize the adapter ID (query HuggingFace).
+                    adapter_model = self._client.LLM(uri=f"hf://{self.adapter_id}")
+                result = base_llm_deployment.with_adapter(model=adapter_model).generate(
+                    prompt=prompt,
+                    options=options,
+                )
+            else:
+                result = base_llm_deployment.generate(
+                    prompt=prompt,
+                    options=options,
+                )
+            response_text = result.response
         else:
-            result = base_llm_deployment.generate(
-                prompt=prompt,
-                options=options,
+            import requests
+            from lorax.client import Client as LoraxClient
+            from lorax.errors import GenerationError
+            from lorax.types import Response
+
+            lorax_client: LoraxClient = self._client.deployments.client(
+                deployment_ref=self.model_name
             )
 
-        return CompletionResponse(text=result.response)
+            response: Response
+            if self.adapter_id:
+                """
+                Attempt to retrieve the fine-tuned adapter from a Predibase repository.
+                If absent, then load the fine-tuned adapter from a HuggingFace repository.
+                """
+                if self.adapter_version:
+                    # Since the adapter version is provided, query the Predibase repository.
+                    pb_adapter_id: str = f"{self.adapter_id}/{self.adapter_version}"
+                    try:
+                        response = lorax_client.generate(
+                            prompt=prompt,
+                            adapter_id=pb_adapter_id,
+                            **options,
+                        )
+                    except GenerationError as ge:
+                        raise ValueError(
+                            f'An adapter with the ID "{pb_adapter_id}" cannot be found in the Predibase repository of fine-tuned adapters.'
+                        ) from ge
+                else:
+                    # The adapter version is omitted, hence look for the adapter ID in the HuggingFace repository.
+                    try:
+                        response = lorax_client.generate(
+                            prompt=prompt,
+                            adapter_id=self.adapter_id,
+                            adapter_source="hub",
+                            **options,
+                        )
+                    except GenerationError as ge:
+                        raise ValueError(
+                            f"""Either an adapter with the ID "{self.adapter_id}" cannot be found in a HuggingFace repository, \
+or it is incompatible with the base model (please make sure that the adapter configuration is consistent).
+"""
+                        ) from ge
+            else:
+                try:
+                    response = lorax_client.generate(
+                        prompt=prompt,
+                        **options,
+                    )
+                except requests.JSONDecodeError as jde:
+                    raise ValueError(
+                        f"""An LLM with the deployment ID "{self.model_name}" cannot be found at Predibase \
+(please refer to "https://docs.predibase.com/user-guide/inference/models" for the list of supported models).
+"""
+                    ) from jde
+            response_text = response.generated_text
+
+        return CompletionResponse(text=response_text)
 
     @llm_completion_callback()
     def stream_complete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
     ) -> "CompletionResponseGen":
         raise NotImplementedError
+
+    def _is_deprecated_sdk_version(self) -> bool:
+        try:
+            import semantic_version
+            from semantic_version.base import Version
+
+            from predibase.version import __version__ as current_version
+
+            sdk_semver_deprecated: Version = semantic_version.Version(
+                version_string="2024.4.8"
+            )
+            actual_current_version: str = self.predibase_sdk_version or current_version
+            sdk_semver_current: Version = semantic_version.Version(
+                version_string=actual_current_version
+            )
+            return not (
+                (sdk_semver_current > sdk_semver_deprecated)
+                or ("+dev" in actual_current_version)
+            )
+        except ImportError as e:
+            raise ImportError(
+                "Could not import Predibase Python package. "
+                "Please install it with `pip install semantic_version predibase`."
+            ) from e
diff --git a/llama-index-integrations/llms/llama-index-llms-predibase/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-predibase/pyproject.toml
index 3cda9f5e11..7855014ada 100644
--- a/llama-index-integrations/llms/llama-index-llms-predibase/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-predibase/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-predibase"
 readme = "README.md"
-version = "0.1.4"
+version = "0.1.5"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
-- 
GitLab