Unverified Commit cee23c9d authored by Alex Sherstinsky, committed by GitHub

[FEATURE] Support both Predibase SDK-v1 and SDK-v2 (#13066)

parent cf384839
%% Cell type:markdown id:4ec7cd6e tags:
<a href="https://colab.research.google.com/github/jerryjliu/llama_index/blob/main/docs/docs/examples/llm/predibase.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
%% Cell type:markdown id:bf9f19f3 tags:
# Predibase
This notebook shows how to use Predibase-hosted LLMs within LlamaIndex. You can add [Predibase](https://predibase.com) to your existing LlamaIndex workflow to:
1. Deploy and query pre-trained or custom open-source LLMs without the hassle
2. Operationalize an end-to-end Retrieval Augmented Generation (RAG) system
3. Fine-tune your own LLM in just a few lines of code
## Getting Started
1. Sign up for a free Predibase account [here](https://predibase.com/free-trial)
2. Create an account
3. Go to Settings > My profile and generate a new API token.
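Once the token exists, a small guard can catch the most common setup mistake: running the notebook with the placeholder value still in place. The sketch below is an illustration only. `read_predibase_token` is a hypothetical helper, not part of `llama-index-llms-predibase`, and it assumes the token is stored in the `PREDIBASE_API_TOKEN` environment variable, which is the variable this notebook sets.

```python
import os
from typing import Mapping


def read_predibase_token(env: Mapping[str, str] = os.environ) -> str:
    """Return the Predibase API token from `env`, or raise if it is unusable.

    Hypothetical helper for illustration; not part of any Predibase or
    LlamaIndex package.
    """
    token = env.get("PREDIBASE_API_TOKEN", "").strip()
    # The notebook ships with the literal placeholder "{PREDIBASE_API_TOKEN}".
    if not token or (token.startswith("{") and token.endswith("}")):
        raise RuntimeError(
            "Set PREDIBASE_API_TOKEN to the token generated under "
            "Settings > My profile before running the cells below."
        )
    return token
```

Passing the environment as a parameter keeps the check easy to exercise with a plain dictionary instead of the real process environment.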
%% Cell type:code id:72d6eb5b tags:
``` python
%pip install llama-index-llms-predibase
```
%% Cell type:code id:79a726c5 tags:
``` python
!pip install llama-index --quiet
!pip install predibase --quiet
!pip install sentence-transformers --quiet
```
%% Cell type:code id:1c2b0d5d tags:
``` python
import os
os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"
from llama_index.llms.predibase import PredibaseLLM
```
%% Cell type:markdown id:9a602a2a tags:
## Flow 1: Query Predibase LLM directly
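The two configuration cells in this flow differ only in how the adapter is addressed. In the newer-SDK code path added by this commit, an `adapter_id` given together with an `adapter_version` is looked up in the Predibase adapter repository as `"<adapter_id>/<adapter_version>"`, while an `adapter_id` without a version is looked up on HuggingFace. The sketch below restates that routing as a standalone function. `adapter_generate_kwargs` is a hypothetical name used only for illustration; the returned keys mirror the keyword arguments the integration passes to the LoRAX client's `generate` call.

```python
from typing import Dict, Optional


def adapter_generate_kwargs(
    adapter_id: Optional[str] = None,
    adapter_version: Optional[int] = None,
) -> Dict[str, str]:
    """Return the adapter-related keyword arguments for a generate call.

    Hypothetical helper: it mirrors the branching in the newer-SDK code path
    of `PredibaseLLM.complete`, but is not part of the package.
    """
    if not adapter_id:
        # No adapter: query the base model deployment directly.
        return {}
    if adapter_version:
        # Version given: address a Predibase-hosted adapter as "<id>/<version>".
        return {"adapter_id": f"{adapter_id}/{adapter_version}"}
    # Version omitted: treat the ID as a HuggingFace repository.
    return {"adapter_id": adapter_id, "adapter_source": "hub"}


print(adapter_generate_kwargs("e2e_nlg", 1))
print(adapter_generate_kwargs("predibase/e2e_nlg"))
print(adapter_generate_kwargs())
```

Because the check is on truthiness, an `adapter_version` of `0` would be routed to HuggingFace as well.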
%% Cell type:code id:4baffaa2 tags:
``` python
# Predibase-hosted fine-tuned adapter example
llm = PredibaseLLM(
    model_name="mistral-7b",
    predibase_sdk_version=None,  # optional parameter (defaults to the latest Predibase SDK version if omitted)
    adapter_id="e2e_nlg",  # adapter_id is optional
    adapter_version=1,  # optional parameter (applies to Predibase only)
    temperature=0.3,
    max_new_tokens=512,
)
# The `model_name` parameter is the Predibase "serverless" base_model ID
# (see https://docs.predibase.com/user-guide/inference/models for the catalog).
# You can also optionally specify a fine-tuned adapter that's hosted on Predibase or HuggingFace
# For Predibase-hosted adapters, also specify the adapter_version: the newer SDK code path requires it
# (without it, the adapter ID is looked up on HuggingFace), while the older SDK path assumed the latest version if omitted.
```
%% Cell type:code id:69713553 tags:
``` python
# HuggingFace-hosted fine-tuned adapter example
llm = PredibaseLLM(
    model_name="mistral-7b",
    predibase_sdk_version=None,  # optional parameter (defaults to the latest Predibase SDK version if omitted)
    adapter_id="predibase/e2e_nlg",  # adapter_id is optional
    temperature=0.3,
    max_new_tokens=512,
)
# The `model_name` parameter is the Predibase "serverless" base_model ID
# (see https://docs.predibase.com/user-guide/inference/models for the catalog).
# You can also optionally specify a fine-tuned adapter that's hosted on Predibase or HuggingFace
# In the case of Predibase-hosted adapters, you can also specify the adapter_version (assumed latest if omitted)
```
%% Cell type:code id:e7039a65 tags:
``` python
result = llm.complete("Can you recommend me a nice dry white wine?")
print(result)
```
%% Cell type:markdown id:1112e828 tags:
## Flow 2: Retrieval Augmented Generation (RAG) with Predibase LLM
%% Cell type:code id:cacff36a tags:
``` python
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from llama_index.core.embeddings import resolve_embed_model
from llama_index.core.node_parser import SentenceSplitter
```
%% Cell type:markdown id:c8f6fef1 tags:
### Download Data
%% Cell type:code id:65930e7e tags:
``` python
!mkdir -p 'data/paul_graham/'
!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt' -O 'data/paul_graham/paul_graham_essay.txt'
```
%% Cell type:markdown id:1edd41d1 tags:
### Load Documents
%% Cell type:code id:c5941151 tags:
``` python
documents = SimpleDirectoryReader("./data/paul_graham/").load_data()
```
%% Cell type:markdown id:7df4407f tags:
### Configure Predibase LLM
%% Cell type:code id:3f67e975-3cb5-4ddc-98e8-eae7892315ca tags:
``` python
# Predibase-hosted fine-tuned adapter
llm = PredibaseLLM(
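    model_name="mistral-7b",
    predibase_sdk_version=None,  # optional parameter (defaults to the latest Predibase SDK version if omitted)
    adapter_id="e2e_nlg",  # adapter_id is optional
    adapter_version=1,  # identifies a Predibase-hosted adapter (as in Flow 1); without it, the newer SDK path looks the ID up on HuggingFace
    temperature=0.3,
    context_window=1024,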
)
```
%% Cell type:code id:4a44defc tags:
``` python
# HuggingFace-hosted fine-tuned adapter
llm = PredibaseLLM(
    model_name="mistral-7b",
    predibase_sdk_version=None,  # optional parameter (defaults to the latest Predibase SDK version if omitted)
    adapter_id="predibase/e2e_nlg",  # adapter_id is optional
    temperature=0.3,
    context_window=1024,
)
```
%% Cell type:code id:b3d527b7-5110-4d9c-97df-3926d3db0772 tags:
``` python
embed_model = resolve_embed_model("local:BAAI/bge-small-en-v1.5")
splitter = SentenceSplitter(chunk_size=1024)
```
%% Cell type:markdown id:7a131a8e tags:
### Setup and Query Index
%% Cell type:code id:c9b10269 tags:
``` python
index = VectorStoreIndex.from_documents(
    documents, transformations=[splitter], embed_model=embed_model
)
query_engine = index.as_query_engine(llm=llm)
response = query_engine.query("What did the author do growing up?")
```
%% Cell type:code id:ac73eb65 tags:
``` python
print(response)
```
@@ -33,7 +33,7 @@ class PredibaseLLM(CustomLLM):
of a fine-tuned LLM adapter, whose base model is the `model` parameter; the
fine-tuned adapter must be compatible with its base model; otherwise, an
error is raised. If the fine-tuned adapter is hosted at Predibase,
`adapter_version` can be specified (omitting it gives the latest version).
`adapter_version` must be specified.
Examples:
`pip install llama-index-llms-predibase`
@@ -47,6 +47,7 @@ class PredibaseLLM(CustomLLM):
llm = PredibaseLLM(
model_name="mistral-7b",
predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)
adapter_id="my-adapter-id", # optional parameter
adapter_version=3, # optional parameter (applies to Predibase only)
temperature=0.3,
@@ -59,6 +60,10 @@ class PredibaseLLM(CustomLLM):
model_name: str = Field(description="The Predibase base model to use.")
predibase_api_key: str = Field(description="The Predibase API key to use.")
predibase_sdk_version: str = Field(
default=None,
description="The optional version (string) of the Predibase SDK (defaults to the latest if not specified).",
)
adapter_id: str = Field(
default=None,
description="The optional Predibase ID or HuggingFace ID of a fine-tuned adapter to use.",
@@ -90,6 +95,7 @@ class PredibaseLLM(CustomLLM):
self,
model_name: str,
predibase_api_key: Optional[str] = None,
predibase_sdk_version: Optional[str] = None,
adapter_id: Optional[str] = None,
adapter_version: Optional[int] = None,
max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
@@ -111,9 +117,10 @@ class PredibaseLLM(CustomLLM):
super().__init__(
model_name=model_name,
predibase_api_key=predibase_api_key,
predibase_sdk_version=predibase_sdk_version,
adapter_id=adapter_id,
adapter_version=adapter_version,
predibase_api_key=predibase_api_key,
max_new_tokens=max_new_tokens,
temperature=temperature,
context_window=context_window,
@@ -125,26 +132,28 @@ class PredibaseLLM(CustomLLM):
output_parser=output_parser,
)
self._client = self.initialize_client(predibase_api_key)
self._client: Union["PredibaseClient", "Predibase"] = self.initialize_client()
@staticmethod
def initialize_client(predibase_api_key: str) -> Any:
def initialize_client(
self,
) -> Union["PredibaseClient", "Predibase"]:
try:
from predibase import PredibaseClient
from predibase.pql import get_session
from predibase.pql.api import Session
session: Session = get_session(
token=predibase_api_key,
gateway="https://api.app.predibase.com/v1",
serving_endpoint="serving.app.predibase.com",
)
return PredibaseClient(session=session)
except ImportError as e:
raise ImportError(
"Could not import Predibase Python package. "
"Please install it with `pip install predibase`."
) from e
if self._is_deprecated_sdk_version():
from predibase import PredibaseClient
from predibase.pql import get_session
from predibase.pql.api import Session
session: Session = get_session(
token=self.predibase_api_key,
gateway="https://api.app.predibase.com/v1",
serving_endpoint="serving.app.predibase.com",
)
return PredibaseClient(session=session)
from predibase import Predibase
os.environ["PREDIBASE_GATEWAY"] = "https://api.app.predibase.com"
return Predibase(api_token=self.predibase_api_key)
except ValueError as e:
raise ValueError("Your API key is not correct. Please try again") from e
@@ -165,18 +174,6 @@ class PredibaseLLM(CustomLLM):
def complete(
self, prompt: str, formatted: bool = False, **kwargs: Any
) -> "CompletionResponse":
from predibase.pql.api import ServerResponseError
from predibase.resource.llm.interface import (
HuggingFaceLLM,
LLMDeployment,
)
from predibase.resource.llm.response import GeneratedResponse
from predibase.resource.model import Model
base_llm_deployment: LLMDeployment = self._client.LLM(
uri=f"pb://deployments/{self.model_name}"
)
options: Dict[str, Union[str, float]] = copy.deepcopy(kwargs)
options.update(
{
@@ -185,36 +182,133 @@ class PredibaseLLM(CustomLLM):
}
)
result: GeneratedResponse
if self.adapter_id:
"""
Attempt to retrieve the fine-tuned adapter from a Predibase repository.
If absent, then load the fine-tuned adapter from a HuggingFace repository.
"""
adapter_model: Union[Model, HuggingFaceLLM]
try:
adapter_model = self._client.get_model(
name=self.adapter_id,
version=self.adapter_version,
model_id=None,
)
except ServerResponseError:
# Predibase does not recognize the adapter ID (query HuggingFace).
adapter_model = self._client.LLM(uri=f"hf://{self.adapter_id}")
result = base_llm_deployment.with_adapter(model=adapter_model).generate(
prompt=prompt,
options=options,
response_text: str
if self._is_deprecated_sdk_version():
from predibase.pql.api import ServerResponseError
from predibase.resource.llm.interface import (
HuggingFaceLLM,
LLMDeployment,
)
from predibase.resource.llm.response import GeneratedResponse
from predibase.resource.model import Model
base_llm_deployment: LLMDeployment = self._client.LLM(
uri=f"pb://deployments/{self.model_name}"
)
result: GeneratedResponse
if self.adapter_id:
"""
Attempt to retrieve the fine-tuned adapter from a Predibase repository.
If absent, then load the fine-tuned adapter from a HuggingFace repository.
"""
adapter_model: Union[Model, HuggingFaceLLM]
try:
adapter_model = self._client.get_model(
name=self.adapter_id,
version=self.adapter_version,
model_id=None,
)
except ServerResponseError:
# Predibase does not recognize the adapter ID (query HuggingFace).
adapter_model = self._client.LLM(uri=f"hf://{self.adapter_id}")
result = base_llm_deployment.with_adapter(model=adapter_model).generate(
prompt=prompt,
options=options,
)
else:
result = base_llm_deployment.generate(
prompt=prompt,
options=options,
)
response_text = result.response
else:
result = base_llm_deployment.generate(
prompt=prompt,
options=options,
import requests
from lorax.client import Client as LoraxClient
from lorax.errors import GenerationError
from lorax.types import Response
lorax_client: LoraxClient = self._client.deployments.client(
deployment_ref=self.model_name
)
return CompletionResponse(text=result.response)
response: Response
if self.adapter_id:
"""
Attempt to retrieve the fine-tuned adapter from a Predibase repository.
If absent, then load the fine-tuned adapter from a HuggingFace repository.
"""
if self.adapter_version:
# Since the adapter version is provided, query the Predibase repository.
pb_adapter_id: str = f"{self.adapter_id}/{self.adapter_version}"
try:
response = lorax_client.generate(
prompt=prompt,
adapter_id=pb_adapter_id,
**options,
)
except GenerationError as ge:
raise ValueError(
f'An adapter with the ID "{pb_adapter_id}" cannot be found in the Predibase repository of fine-tuned adapters.'
) from ge
else:
# The adapter version is omitted, hence look for the adapter ID in the HuggingFace repository.
try:
response = lorax_client.generate(
prompt=prompt,
adapter_id=self.adapter_id,
adapter_source="hub",
**options,
)
except GenerationError as ge:
raise ValueError(
f"""Either an adapter with the ID "{self.adapter_id}" cannot be found in a HuggingFace repository, \
or it is incompatible with the base model (please make sure that the adapter configuration is consistent).
"""
) from ge
else:
try:
response = lorax_client.generate(
prompt=prompt,
**options,
)
except requests.JSONDecodeError as jde:
raise ValueError(
f"""An LLM with the deployment ID "{self.model_name}" cannot be found at Predibase \
(please refer to "https://docs.predibase.com/user-guide/inference/models" for the list of supported models).
"""
) from jde
response_text = response.generated_text
return CompletionResponse(text=response_text)
@llm_completion_callback()
def stream_complete(
self, prompt: str, formatted: bool = False, **kwargs: Any
) -> "CompletionResponseGen":
raise NotImplementedError
def _is_deprecated_sdk_version(self) -> bool:
try:
import semantic_version
from semantic_version.base import Version
from predibase.version import __version__ as current_version
sdk_semver_deprecated: Version = semantic_version.Version(
version_string="2024.4.8"
)
actual_current_version: str = self.predibase_sdk_version or current_version
sdk_semver_current: Version = semantic_version.Version(
version_string=actual_current_version
)
return not (
(sdk_semver_current > sdk_semver_deprecated)
or ("+dev" in actual_current_version)
)
except ImportError as e:
raise ImportError(
"Could not import Predibase Python package. "
"Please install it with `pip install semantic_version predibase`."
) from e
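The version gate above can be summarised as follows: the legacy client is used when the effective SDK version is `2024.4.8` or older, unless the version string contains `+dev`. The following is a minimal sketch of the same comparison using only the standard library, so it runs without `semantic_version` or `predibase` installed. The function name `is_deprecated_sdk_version`, the constant `LAST_DEPRECATED_VERSION`, and the tuple-based parsing are assumptions made for illustration; the integration itself relies on `semantic_version.Version`, which handles more of the semver grammar (for example pre-release tags) than this sketch does.

```python
from typing import Tuple

# Last SDK version served by the legacy client, as used in the diff above.
LAST_DEPRECATED_VERSION = "2024.4.8"


def _release_tuple(version: str) -> Tuple[int, ...]:
    """Parse "MAJOR.MINOR.PATCH[+build]" into a tuple of ints (simplified)."""
    release = version.split("+", 1)[0]
    return tuple(int(part) for part in release.split("."))


def is_deprecated_sdk_version(version: str) -> bool:
    """Return True when `version` should use the legacy (SDK-v1) client."""
    if "+dev" in version:
        # Development builds are always treated as the newer SDK.
        return False
    return _release_tuple(version) <= _release_tuple(LAST_DEPRECATED_VERSION)
```

Comparing integer tuples rather than raw strings matters here: as strings, `"2024.10.0"` would sort before `"2024.4.8"`, which would wrongly route a newer SDK to the legacy client.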
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-llms-predibase"
readme = "README.md"
version = "0.1.4"
version = "0.1.5"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"