Skip to content
Snippets Groups Projects
Unverified Commit 6a0c06ce authored by Denny Lee's avatar Denny Lee Committed by GitHub
Browse files

Update DataBricks to Databricks and renamed environment variables (#12595)

parent c30af7f6
No related branches found
No related tags found
No related merge requests found
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
## Overview ## Overview
Integrate with DataBricks LLMs APIs. Integrate with Databricks LLMs APIs.
## Installation ## Installation
...@@ -15,15 +15,15 @@ pip install llama-index-llms-databricks ...@@ -15,15 +15,15 @@ pip install llama-index-llms-databricks
With environmental variables. With environmental variables.
```.env ```.env
DATABRICKS_API_KEY=your_api_key DATABRICKS_TOKEN=your_api_key
DATABRICKS_API_BASE=https://[your-work-space].cloud.databricks.com/serving-endpoints/[your-serving-endpoint] DATABRICKS_SERVING_ENDPOINT=https://[your-work-space].cloud.databricks.com/serving-endpoints
``` ```
```python ```python
from llama_index.llms.databricks import DataBricks from llama_index.llms.databricks import Databricks
# Initialize DataBricks LLM without explicitly passing the API key and base # Initialize Databricks LLM without explicitly passing the API key and base
llm = DataBricks(model="databricks-dbrx-instruct") llm = Databricks(model="databricks-dbrx-instruct")
# Make a query to the LLM # Make a query to the LLM
response = llm.complete("Explain the importance of open source LLMs") response = llm.complete("Explain the importance of open source LLMs")
...@@ -34,13 +34,13 @@ print(response) ...@@ -34,13 +34,13 @@ print(response)
Without environmental variables Without environmental variables
```python ```python
from llama_index.llms.databricks import DataBricks from llama_index.llms.databricks import Databricks
# Set up the DataBricks class with the required model, API key and serving endpoint # Set up the Databricks class with the required model, API key and serving endpoint
llm = DataBricks( llm = Databricks(
model="databricks-dbrx-instruct", model="databricks-dbrx-instruct",
api_key="your_api_key", api_key="your_api_key",
api_base="https://[your-work-space].cloud.databricks.com/serving-endpoints/[your-serving-endpoint]", api_base="https://[your-work-space].cloud.databricks.com/serving-endpoints",
) )
# Call the complete method with a query # Call the complete method with a query
......
from llama_index.llms.databricks.base import DataBricks from llama_index.llms.databricks.base import Databricks
__all__ = ["DataBricks"] __all__ = ["Databricks"]
...@@ -4,17 +4,17 @@ from typing import Any, Optional ...@@ -4,17 +4,17 @@ from typing import Any, Optional
from llama_index.llms.openai_like import OpenAILike from llama_index.llms.openai_like import OpenAILike
class DataBricks(OpenAILike): class Databricks(OpenAILike):
"""DataBricks LLM. """Databricks LLM.
Examples: Examples:
`pip install llama-index-llms-databricks` `pip install llama-index-llms-databricks`
```python ```python
from llama_index.llms.databricks import DataBricks from llama_index.llms.databricks import Databricks
# Set up the DataBricks class with the required model, API key and serving endpoint # Set up the Databricks class with the required model, API key and serving endpoint
llm = DataBricks(model="databricks-dbrx-instruct", api_key="your_api_key", api_base="https://[your-work-space].cloud.databricks.com/serving-endpoints/[your-serving-endpoint]") llm = Databricks(model="databricks-dbrx-instruct", api_key="your_api_key", api_base="https://[your-work-space].cloud.databricks.com/serving-endpoints")
# Call the complete method with a query # Call the complete method with a query
response = llm.complete("Explain the importance of open source LLMs") response = llm.complete("Explain the importance of open source LLMs")
...@@ -31,8 +31,8 @@ class DataBricks(OpenAILike): ...@@ -31,8 +31,8 @@ class DataBricks(OpenAILike):
is_chat_model: bool = True, is_chat_model: bool = True,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
api_key = api_key or os.environ.get("DATABRICKS_API_KEY", None) api_key = api_key or os.environ.get("DATABRICKS_TOKEN", None)
api_base = api_base or os.environ.get("DATABRICKS_API_BASE", None) api_base = api_base or os.environ.get("DATABRICKS_SERVING_ENDPOINT", None)
super().__init__( super().__init__(
model=model, model=model,
api_key=api_key, api_key=api_key,
...@@ -44,4 +44,4 @@ class DataBricks(OpenAILike): ...@@ -44,4 +44,4 @@ class DataBricks(OpenAILike):
@classmethod @classmethod
def class_name(cls) -> str: def class_name(cls) -> str:
"""Get class name.""" """Get class name."""
return "DataBricks" return "Databricks"
...@@ -30,7 +30,7 @@ license = "MIT" ...@@ -30,7 +30,7 @@ license = "MIT"
name = "llama-index-llms-databricks" name = "llama-index-llms-databricks"
packages = [{include = "llama_index/"}] packages = [{include = "llama_index/"}]
readme = "README.md" readme = "README.md"
version = "0.1.0" version = "0.1.1"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.8.1,<4.0" python = ">=3.8.1,<4.0"
......
...@@ -2,15 +2,16 @@ import os ...@@ -2,15 +2,16 @@ import os
import pytest import pytest
from llama_index.llms.databricks import DataBricks from llama_index.llms.databricks import Databricks
@pytest.mark.skipif( @pytest.mark.skipif(
"DATABRICKS_API_KEY" not in os.environ or "DATABRICKS_API_BASE" not in os.environ, "DATABRICKS_TOKEN" not in os.environ
reason="DATABRICKS_API_KEY or DATABRICKS_API_BASE not set in environment", or "DATABRICKS_SERVING_ENDPOINT" not in os.environ,
reason="DATABRICKS_TOKEN or DATABRICKS_SERVING_ENDPOINT not set in environment",
) )
def test_completion(): def test_completion():
databricks = DataBricks( databricks = Databricks(
model="databricks-dbrx-instruct", temperature=0, max_tokens=2 model="databricks-dbrx-instruct", temperature=0, max_tokens=2
) )
resp = databricks.complete("hello") resp = databricks.complete("hello")
...@@ -18,11 +19,12 @@ def test_completion(): ...@@ -18,11 +19,12 @@ def test_completion():
@pytest.mark.skipif( @pytest.mark.skipif(
"DATABRICKS_API_KEY" not in os.environ or "DATABRICKS_API_BASE" not in os.environ, "DATABRICKS_TOKEN" not in os.environ
reason="DATABRICKS_API_KEY or DATABRICKS_API_BASE not set in environment", or "DATABRICKS_SERVING_ENDPOINT" not in os.environ,
reason="DATABRICKS_TOKEN or DATABRICKS_SERVING_ENDPOINT not set in environment",
) )
def test_stream_completion(): def test_stream_completion():
databricks = DataBricks( databricks = Databricks(
model="databricks-dbrx-instruct", temperature=0, max_tokens=2 model="databricks-dbrx-instruct", temperature=0, max_tokens=2
) )
stream = databricks.stream_complete("hello") stream = databricks.stream_complete("hello")
......
from llama_index.core.base.llms.base import BaseLLM from llama_index.core.base.llms.base import BaseLLM
from llama_index.llms.databricks import DataBricks from llama_index.llms.databricks import Databricks
def test_embedding_class(): def test_embedding_class():
names_of_base_classes = [b.__name__ for b in DataBricks.__mro__] names_of_base_classes = [b.__name__ for b in Databricks.__mro__]
assert BaseLLM.__name__ in names_of_base_classes assert BaseLLM.__name__ in names_of_base_classes
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment