diff --git a/docs/docs/examples/llm/predibase.ipynb b/docs/docs/examples/llm/predibase.ipynb index 26e3cad282bbf0195f27a38ae6968666c2ba555b..9725c27beb1591525b5e2afa404f3aeaaed9c84d 100644 --- a/docs/docs/examples/llm/predibase.ipynb +++ b/docs/docs/examples/llm/predibase.ipynb @@ -79,6 +79,7 @@ "# Predibase-hosted fine-tuned adapter example\n", "llm = PredibaseLLM(\n", " model_name=\"mistral-7b\",\n", + " predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n", " adapter_id=\"e2e_nlg\", # adapter_id is optional\n", " adapter_version=1, # optional parameter (applies to Predibase only)\n", " temperature=0.3,\n", @@ -87,7 +88,7 @@ "# The `model_name` parameter is the Predibase \"serverless\" base_model ID\n", "# (see https://docs.predibase.com/user-guide/inference/models for the catalog).\n", "# You can also optionally specify a fine-tuned adapter that's hosted on Predibase or HuggingFace\n", - "# In the case of Predibase-hosted adapters, you can also specify the adapter_version (assumed latest if omitted)" + "# In the case of Predibase-hosted adapters, you must also specify the adapter_version" ] }, { @@ -100,6 +101,7 @@ "# HuggingFace-hosted fine-tuned adapter example\n", "llm = PredibaseLLM(\n", " model_name=\"mistral-7b\",\n", + " predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n", " adapter_id=\"predibase/e2e_nlg\", # adapter_id is optional\n", " temperature=0.3,\n", " max_new_tokens=512,\n", @@ -197,6 +199,7 @@ "# Predibase-hosted fine-tuned adapter\n", "llm = PredibaseLLM(\n", " model_name=\"mistral-7b\",\n", + " predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n", " adapter_id=\"e2e_nlg\", # adapter_id is optional\n", " temperature=0.3,\n", " context_window=1024,\n", @@ -213,6 +216,7 @@ "# HuggingFace-hosted fine-tuned adapter\n", "llm = PredibaseLLM(\n", " model_name=\"mistral-7b\",\n", + " predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n", " adapter_id=\"predibase/e2e_nlg\", # adapter_id is optional\n", " temperature=0.3,\n", " context_window=1024,\n", diff --git a/llama-index-integrations/llms/llama-index-llms-predibase/llama_index/llms/predibase/base.py b/llama-index-integrations/llms/llama-index-llms-predibase/llama_index/llms/predibase/base.py index cfa1dd717e7a39f3f24515972658da0979b8c9e6..8e28e7e5fdd2b86dd0af4ab257de0020a083deef 100644 --- a/llama-index-integrations/llms/llama-index-llms-predibase/llama_index/llms/predibase/base.py +++ b/llama-index-integrations/llms/llama-index-llms-predibase/llama_index/llms/predibase/base.py @@ -33,7 +33,7 @@ class PredibaseLLM(CustomLLM): of a fine-tuned LLM adapter, whose base model is the `model` parameter; the fine-tuned adapter must be compatible with its base model; otherwise, an error is raised. If the fine-tuned adapter is hosted at Predibase, - `adapter_version` can be specified (omitting it gives the latest version). + `adapter_version` must be specified. 
Examples: `pip install llama-index-llms-predibase` @@ -47,6 +47,7 @@ class PredibaseLLM(CustomLLM): llm = PredibaseLLM( model_name="mistral-7b", + predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted) adapter_id="my-adapter-id", # optional parameter adapter_version=3, # optional parameter (applies to Predibase only) temperature=0.3, @@ -59,6 +60,10 @@ class PredibaseLLM(CustomLLM): model_name: str = Field(description="The Predibase base model to use.") predibase_api_key: str = Field(description="The Predibase API key to use.") + predibase_sdk_version: str = Field( + default=None, + description="The optional version (string) of the Predibase SDK (defaults to the latest if not specified).", + ) adapter_id: str = Field( default=None, description="The optional Predibase ID or HuggingFace ID of a fine-tuned adapter to use.", @@ -90,6 +95,7 @@ class PredibaseLLM(CustomLLM): self, model_name: str, predibase_api_key: Optional[str] = None, + predibase_sdk_version: Optional[str] = None, adapter_id: Optional[str] = None, adapter_version: Optional[int] = None, max_new_tokens: int = DEFAULT_NUM_OUTPUTS, @@ -111,9 +117,10 @@ class PredibaseLLM(CustomLLM): super().__init__( model_name=model_name, + predibase_api_key=predibase_api_key, + predibase_sdk_version=predibase_sdk_version, adapter_id=adapter_id, adapter_version=adapter_version, - predibase_api_key=predibase_api_key, max_new_tokens=max_new_tokens, temperature=temperature, context_window=context_window, @@ -125,26 +132,28 @@ class PredibaseLLM(CustomLLM): output_parser=output_parser, ) - self._client = self.initialize_client(predibase_api_key) + self._client: Union["PredibaseClient", "Predibase"] = self.initialize_client() - @staticmethod - def initialize_client(predibase_api_key: str) -> Any: + def initialize_client( + self, + ) -> Union["PredibaseClient", "Predibase"]: try: - from predibase import PredibaseClient - from predibase.pql import get_session - from predibase.pql.api import Session - - session: Session = get_session( - token=predibase_api_key, - gateway="https://api.app.predibase.com/v1", - serving_endpoint="serving.app.predibase.com", - ) - return PredibaseClient(session=session) - except ImportError as e: - raise ImportError( - "Could not import Predibase Python package. " - "Please install it with `pip install predibase`." - ) from e + if self._is_deprecated_sdk_version(): + from predibase import PredibaseClient + from predibase.pql import get_session + from predibase.pql.api import Session + + session: Session = get_session( + token=self.predibase_api_key, + gateway="https://api.app.predibase.com/v1", + serving_endpoint="serving.app.predibase.com", + ) + return PredibaseClient(session=session) + + from predibase import Predibase + + os.environ["PREDIBASE_GATEWAY"] = "https://api.app.predibase.com" + return Predibase(api_token=self.predibase_api_key) except ValueError as e: raise ValueError("Your API key is not correct. 
Please try again") from e @@ -165,18 +174,6 @@ class PredibaseLLM(CustomLLM): def complete( self, prompt: str, formatted: bool = False, **kwargs: Any ) -> "CompletionResponse": - from predibase.pql.api import ServerResponseError - from predibase.resource.llm.interface import ( - HuggingFaceLLM, - LLMDeployment, - ) - from predibase.resource.llm.response import GeneratedResponse - from predibase.resource.model import Model - - base_llm_deployment: LLMDeployment = self._client.LLM( - uri=f"pb://deployments/{self.model_name}" - ) - options: Dict[str, Union[str, float]] = copy.deepcopy(kwargs) options.update( { @@ -185,36 +182,133 @@ class PredibaseLLM(CustomLLM): } ) - result: GeneratedResponse - if self.adapter_id: - """ - Attempt to retrieve the fine-tuned adapter from a Predibase repository. - If absent, then load the fine-tuned adapter from a HuggingFace repository. - """ - adapter_model: Union[Model, HuggingFaceLLM] - try: - adapter_model = self._client.get_model( - name=self.adapter_id, - version=self.adapter_version, - model_id=None, - ) - except ServerResponseError: - # Predibase does not recognize the adapter ID (query HuggingFace). - adapter_model = self._client.LLM(uri=f"hf://{self.adapter_id}") - result = base_llm_deployment.with_adapter(model=adapter_model).generate( - prompt=prompt, - options=options, + response_text: str + + if self._is_deprecated_sdk_version(): + from predibase.pql.api import ServerResponseError + from predibase.resource.llm.interface import ( + HuggingFaceLLM, + LLMDeployment, + ) + from predibase.resource.llm.response import GeneratedResponse + from predibase.resource.model import Model + + base_llm_deployment: LLMDeployment = self._client.LLM( + uri=f"pb://deployments/{self.model_name}" ) + + result: GeneratedResponse + if self.adapter_id: + """ + Attempt to retrieve the fine-tuned adapter from a Predibase repository. + If absent, then load the fine-tuned adapter from a HuggingFace repository. + """ + adapter_model: Union[Model, HuggingFaceLLM] + try: + adapter_model = self._client.get_model( + name=self.adapter_id, + version=self.adapter_version, + model_id=None, + ) + except ServerResponseError: + # Predibase does not recognize the adapter ID (query HuggingFace). + adapter_model = self._client.LLM(uri=f"hf://{self.adapter_id}") + result = base_llm_deployment.with_adapter(model=adapter_model).generate( + prompt=prompt, + options=options, + ) + else: + result = base_llm_deployment.generate( + prompt=prompt, + options=options, + ) + response_text = result.response else: - result = base_llm_deployment.generate( - prompt=prompt, - options=options, + import requests + from lorax.client import Client as LoraxClient + from lorax.errors import GenerationError + from lorax.types import Response + + lorax_client: LoraxClient = self._client.deployments.client( + deployment_ref=self.model_name ) - return CompletionResponse(text=result.response) + response: Response + if self.adapter_id: + """ + Attempt to retrieve the fine-tuned adapter from a Predibase repository. + If absent, then load the fine-tuned adapter from a HuggingFace repository. + """ + if self.adapter_version: + # Since the adapter version is provided, query the Predibase repository. 
+ pb_adapter_id: str = f"{self.adapter_id}/{self.adapter_version}" + try: + response = lorax_client.generate( + prompt=prompt, + adapter_id=pb_adapter_id, + **options, + ) + except GenerationError as ge: + raise ValueError( + f'An adapter with the ID "{pb_adapter_id}" cannot be found in the Predibase repository of fine-tuned adapters.' + ) from ge + else: + # The adapter version is omitted, hence look for the adapter ID in the HuggingFace repository. + try: + response = lorax_client.generate( + prompt=prompt, + adapter_id=self.adapter_id, + adapter_source="hub", + **options, + ) + except GenerationError as ge: + raise ValueError( + f"""Either an adapter with the ID "{self.adapter_id}" cannot be found in a HuggingFace repository, \ +or it is incompatible with the base model (please make sure that the adapter configuration is consistent). +""" + ) from ge + else: + try: + response = lorax_client.generate( + prompt=prompt, + **options, + ) + except requests.JSONDecodeError as jde: + raise ValueError( + f"""An LLM with the deployment ID "{self.model_name}" cannot be found at Predibase \ +(please refer to "https://docs.predibase.com/user-guide/inference/models" for the list of supported models). +""" + ) from jde + response_text = response.generated_text + + return CompletionResponse(text=response_text) @llm_completion_callback() def stream_complete( self, prompt: str, formatted: bool = False, **kwargs: Any ) -> "CompletionResponseGen": raise NotImplementedError + + def _is_deprecated_sdk_version(self) -> bool: + try: + import semantic_version + from semantic_version.base import Version + + from predibase.version import __version__ as current_version + + sdk_semver_deprecated: Version = semantic_version.Version( + version_string="2024.4.8" + ) + actual_current_version: str = self.predibase_sdk_version or current_version + sdk_semver_current: Version = semantic_version.Version( + version_string=actual_current_version + ) + return not ( + (sdk_semver_current > sdk_semver_deprecated) + or ("+dev" in actual_current_version) + ) + except ImportError as e: + raise ImportError( + "Could not import Predibase Python package. " + "Please install it with `pip install semantic_version predibase`." + ) from e diff --git a/llama-index-integrations/llms/llama-index-llms-predibase/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-predibase/pyproject.toml index 3cda9f5e11e11a273578ccb54c0811f8b23548ad..7855014ada3a0c32aa890994a443a95bd65a4b3d 100644 --- a/llama-index-integrations/llms/llama-index-llms-predibase/pyproject.toml +++ b/llama-index-integrations/llms/llama-index-llms-predibase/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-llms-predibase" readme = "README.md" -version = "0.1.4" +version = "0.1.5" [tool.poetry.dependencies] python = ">=3.8.1,<4.0"
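

Taken together, the notebook changes above reduce to the following usage pattern. This is a minimal sketch, not part of the patch: it assumes PREDIBASE_API_TOKEN is set in the environment (as the notebook does) and reuses the adapter names from the notebook cells; the token placeholder and prompt string are illustrative.

    import os

    from llama_index.llms.predibase import PredibaseLLM

    os.environ["PREDIBASE_API_TOKEN"] = "<your-predibase-api-token>"  # placeholder

    # Predibase-hosted adapter: with the new SDK, adapter_version is required,
    # and the adapter is resolved as "e2e_nlg/1" in the Predibase repository.
    pb_llm = PredibaseLLM(
        model_name="mistral-7b",
        predibase_sdk_version=None,  # None falls back to the installed SDK version
        adapter_id="e2e_nlg",
        adapter_version=1,
        temperature=0.3,
        max_new_tokens=512,
    )

    # HuggingFace-hosted adapter: adapter_version is omitted, so the adapter is
    # fetched from the HuggingFace Hub instead.
    hf_llm = PredibaseLLM(
        model_name="mistral-7b",
        adapter_id="predibase/e2e_nlg",
        temperature=0.3,
        max_new_tokens=512,
    )

    print(pb_llm.complete("What is the capital of France?"))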
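
Under the hood, the non-deprecated branch of complete() amounts to the direct SDK calls below. A sketch assuming a post-2024.4.8 predibase package (which bundles the lorax client) and the same environment token as above; the client calls and identifiers are taken from the patch, the prompt is illustrative.

    import os

    from predibase import Predibase

    os.environ["PREDIBASE_GATEWAY"] = "https://api.app.predibase.com"
    pb = Predibase(api_token=os.environ["PREDIBASE_API_TOKEN"])

    # deployments.client() returns a LoRAX client bound to the base deployment.
    lorax_client = pb.deployments.client(deployment_ref="mistral-7b")

    # Predibase-hosted adapters are addressed as "<adapter_id>/<adapter_version>";
    # a HuggingFace-hosted adapter would instead pass adapter_id="predibase/e2e_nlg"
    # together with adapter_source="hub".
    response = lorax_client.generate(
        prompt="What is the capital of France?",
        adapter_id="e2e_nlg/1",
        max_new_tokens=512,
        temperature=0.3,
    )
    print(response.generated_text)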
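
The routing between the legacy and current branches is decided entirely by _is_deprecated_sdk_version(). A standalone sketch of that check, with illustrative version strings:

    import semantic_version

    # Releases at or below 2024.4.8 take the legacy PredibaseClient path;
    # anything newer, or a "+dev" build, takes the new Predibase/LoRAX path.
    _CUTOFF = semantic_version.Version("2024.4.8")


    def is_deprecated_sdk(version: str) -> bool:
        current = semantic_version.Version(version)
        # Semver comparisons ignore build metadata such as "+dev", hence the
        # explicit substring check, mirroring the implementation in the patch.
        return not (current > _CUTOFF or "+dev" in version)


    assert is_deprecated_sdk("2024.4.8")  # the cutoff itself is deprecated
    assert not is_deprecated_sdk("2024.4.8+dev")  # dev builds take the new path
    assert not is_deprecated_sdk("2024.5.1")  # newer releases take the new path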