From 7ef27694ce26321987e0ed1922b8997ad8c51df5 Mon Sep 17 00:00:00 2001
From: Alex Sherstinsky <alexsherstinsky@users.noreply.github.com>
Date: Thu, 11 Apr 2024 09:20:19 -0700
Subject: [PATCH] [BUGFIX] Update LlamaIndex-Predibase Integration (#12736)

---
 docs/docs/examples/llm/predibase.ipynb        | 12 ++-
 .../llama_index/llms/predibase/base.py        | 78 ++++++++++++++++---
 .../llama-index-llms-predibase/pyproject.toml |  2 +-
 3 files changed, 77 insertions(+), 15 deletions(-)

diff --git a/docs/docs/examples/llm/predibase.ipynb b/docs/docs/examples/llm/predibase.ipynb
index dffa20a060..25902a1559 100644
--- a/docs/docs/examples/llm/predibase.ipynb
+++ b/docs/docs/examples/llm/predibase.ipynb
@@ -77,9 +77,14 @@
    "outputs": [],
    "source": [
     "llm = PredibaseLLM(\n",
-    "    model_name=\"llama-2-13b\", temperature=0.3, max_new_tokens=512\n",
+    "    model_name=\"mistral-7b\",\n",
+    "    adapter_id=\"predibase/e2e_nlg\",  # adapter_id is optional\n",
+    "    temperature=0.3,\n",
+    "    max_new_tokens=512,\n",
     ")\n",
-    "# You can query any HuggingFace or fine-tuned LLM that's hosted on Predibase"
+    "# The `model_name` parameter is the Predibase \"serverless\" base_model ID\n",
+    "# (see https://docs.predibase.com/user-guide/inference/models for the catalog).\n",
+    "# You can also optionally specify a fine-tuned adapter that's hosted on HuggingFace"
    ]
   },
   {
@@ -167,7 +172,8 @@
    "outputs": [],
    "source": [
     "llm = PredibaseLLM(\n",
-    "    model_name=\"llama-2-13b\",\n",
+    "    model_name=\"mistral-7b\",\n",
+    "    adapter_id=\"predibase/e2e_nlg\",  # adapter_id is optional\n",
     "    temperature=0.3,\n",
     "    max_new_tokens=400,\n",
     "    context_window=1024,\n",
diff --git a/llama-index-integrations/llms/llama-index-llms-predibase/llama_index/llms/predibase/base.py b/llama-index-integrations/llms/llama-index-llms-predibase/llama_index/llms/predibase/base.py
index d16576dbc5..ff8d7b8f0d 100644
--- a/llama-index-integrations/llms/llama-index-llms-predibase/llama_index/llms/predibase/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-predibase/llama_index/llms/predibase/base.py
@@ -1,5 +1,6 @@
 import os
-from typing import Any, Callable, Optional, Sequence
+from typing import Any, Callable, Dict, Optional, Sequence, Union
+import copy
 
 from llama_index.core.base.llms.types import (
     ChatMessage,
@@ -22,6 +23,16 @@ from llama_index.core.types import BaseOutputParser, PydanticProgramMode
 
 class PredibaseLLM(CustomLLM):
     """Predibase LLM.
 
+    To use, you should have the ``predibase`` python package installed,
+    and have your Predibase API key.
+
+    The `model_name` parameter is the Predibase "serverless" base_model ID
+    (see https://docs.predibase.com/user-guide/inference/models for the catalog).
+
+    An optional `adapter_id` parameter is the HuggingFace ID of a fine-tuned LLM
+    adapter, whose base model is the `model_name` parameter; the fine-tuned adapter
+    must be compatible with its base model; otherwise, an error is raised.
+
     Examples:
         `pip install llama-index-llms-predibase`
 
         ```python
         import os
 
         os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"
 
         from llama_index.llms.predibase import PredibaseLLM
 
         llm = PredibaseLLM(
-            model_name="llama-2-13b", temperature=0.3, max_new_tokens=512
+            model_name="mistral-7b",
+            adapter_id="my-repo/my-adapter",  # optional parameter
+            temperature=0.3,
+            max_new_tokens=512,
         )
 
         response = llm.complete("Hello World!")
         print(str(response))
         ```
     """
 
-    model_name: str = Field(description="The Predibase model to use.")
+    model_name: str = Field(description="The Predibase base model to use.")
     predibase_api_key: str = Field(description="The Predibase API key to use.")
+    adapter_id: Optional[str] = Field(
+        default=None,
+        description="The optional HuggingFace ID of a fine-tuned adapter to use.",
+    )
     max_new_tokens: int = Field(
         default=DEFAULT_NUM_OUTPUTS,
         description="The number of tokens to generate.",
@@ -65,6 +83,7 @@ class PredibaseLLM(CustomLLM):
         self,
         model_name: str,
         predibase_api_key: Optional[str] = None,
+        adapter_id: Optional[str] = None,
         max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
         temperature: float = DEFAULT_TEMPERATURE,
         context_window: int = DEFAULT_CONTEXT_WINDOW,
@@ -82,10 +101,9 @@
         )
         assert predibase_api_key is not None
 
-        self._client = self.initialize_client(predibase_api_key)
-
         super().__init__(
             model_name=model_name,
+            adapter_id=adapter_id,
             predibase_api_key=predibase_api_key,
             max_new_tokens=max_new_tokens,
             temperature=temperature,
@@ -98,12 +116,21 @@
             output_parser=output_parser,
         )
 
+        self._client = self.initialize_client(predibase_api_key)
+
     @staticmethod
     def initialize_client(predibase_api_key: str) -> Any:
         try:
             from predibase import PredibaseClient
-
-            return PredibaseClient(token=predibase_api_key)
+            from predibase.pql import get_session
+            from predibase.pql.api import Session
+
+            session: Session = get_session(
+                token=predibase_api_key,
+                gateway="https://api.app.predibase.com/v1",
+                serving_endpoint="serving.app.predibase.com",
+            )
+            return PredibaseClient(session=session)
         except ImportError as e:
             raise ImportError(
                 "Could not import Predibase Python package. "
" @@ -129,11 +156,40 @@ class PredibaseLLM(CustomLLM): def complete( self, prompt: str, formatted: bool = False, **kwargs: Any ) -> "CompletionResponse": - llm = self._client.LLM(f"pb://deployments/{self.model_name}") - results = llm.prompt( - prompt, max_new_tokens=self.max_new_tokens, temperature=self.temperature + from predibase.resource.llm.interface import ( + HuggingFaceLLM, + LLMDeployment, + ) + from predibase.resource.llm.response import GeneratedResponse + + base_llm_deployment: LLMDeployment = self._client.LLM( + uri=f"pb://deployments/{self.model_name}" ) - return CompletionResponse(text=results.response) + + options: Dict[str, Union[str, float]] = copy.deepcopy(kwargs) + options.update( + { + "max_new_tokens": self.max_new_tokens, + "temperature": self.temperature, + } + ) + + result: GeneratedResponse + if self.adapter_id: + adapter_model: HuggingFaceLLM = self._client.LLM( + uri=f"hf://{self.adapter_id}" + ) + result = base_llm_deployment.with_adapter(model=adapter_model).generate( + prompt=prompt, + options=options, + ) + else: + result = base_llm_deployment.generate( + prompt=prompt, + options=options, + ) + + return CompletionResponse(text=result.response) @llm_completion_callback() def stream_complete( diff --git a/llama-index-integrations/llms/llama-index-llms-predibase/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-predibase/pyproject.toml index 43015e6938..422215ee45 100644 --- a/llama-index-integrations/llms/llama-index-llms-predibase/pyproject.toml +++ b/llama-index-integrations/llms/llama-index-llms-predibase/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-llms-predibase" readme = "README.md" -version = "0.1.2" +version = "0.1.3" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" -- GitLab