From 6a0c06ce44e126c5ec1ff18123ced27f81e4e98c Mon Sep 17 00:00:00 2001
From: Denny Lee <denny.g.lee@gmail.com>
Date: Fri, 5 Apr 2024 11:58:33 -0700
Subject: [PATCH] Update DataBricks to Databricks and renamed environment
 variables  (#12595)

---
 .../llama-index-llms-databricks/README.md     | 20 +++++++++----------
 .../llama_index/llms/databricks/__init__.py   |  4 ++--
 .../llama_index/llms/databricks/base.py       | 16 +++++++--------
 .../pyproject.toml                            |  2 +-
 .../tests/test_integration_databricks.py      | 16 ++++++++-------
 .../tests/test_llms_databricks.py             |  4 ++--
 6 files changed, 32 insertions(+), 30 deletions(-)

diff --git a/llama-index-integrations/llms/llama-index-llms-databricks/README.md b/llama-index-integrations/llms/llama-index-llms-databricks/README.md
index 4bf0b5a038..cd5e60ba67 100644
--- a/llama-index-integrations/llms/llama-index-llms-databricks/README.md
+++ b/llama-index-integrations/llms/llama-index-llms-databricks/README.md
@@ -2,7 +2,7 @@
 
 ## Overview
 
-Integrate with DataBricks LLMs APIs.
+Integrate with Databricks LLMs APIs.
 
 ## Installation
 
@@ -15,15 +15,15 @@ pip install llama-index-llms-databricks
 With environmental variables.
 
 ```.env
-DATABRICKS_API_KEY=your_api_key
-DATABRICKS_API_BASE=https://[your-work-space].cloud.databricks.com/serving-endpoints/[your-serving-endpoint]
+DATABRICKS_TOKEN=your_api_key
+DATABRICKS_SERVING_ENDPOINT=https://[your-work-space].cloud.databricks.com/serving-endpoints
 ```
 
 ```python
-from llama_index.llms.databricks import DataBricks
+from llama_index.llms.databricks import Databricks
 
-# Initialize DataBricks LLM without explicitly passing the API key and base
-llm = DataBricks(model="databricks-dbrx-instruct")
+# Initialize Databricks LLM without explicitly passing the API key and base
+llm = Databricks(model="databricks-dbrx-instruct")
 
 # Make a query to the LLM
 response = llm.complete("Explain the importance of open source LLMs")
@@ -34,13 +34,13 @@ print(response)
 Without environmental variables
 
 ```python
-from llama_index.llms.databricks import DataBricks
+from llama_index.llms.databricks import Databricks
 
-# Set up the DataBricks class with the required model, API key and serving endpoint
-llm = DataBricks(
+# Set up the Databricks class with the required model, API key and serving endpoint
+llm = Databricks(
     model="databricks-dbrx-instruct",
     api_key="your_api_key",
-    api_base="https://[your-work-space].cloud.databricks.com/serving-endpoints/[your-serving-endpoint]",
+    api_base="https://[your-work-space].cloud.databricks.com/serving-endpoints",
 )
 
 # Call the complete method with a query
diff --git a/llama-index-integrations/llms/llama-index-llms-databricks/llama_index/llms/databricks/__init__.py b/llama-index-integrations/llms/llama-index-llms-databricks/llama_index/llms/databricks/__init__.py
index a46ccbff70..5872eab54b 100644
--- a/llama-index-integrations/llms/llama-index-llms-databricks/llama_index/llms/databricks/__init__.py
+++ b/llama-index-integrations/llms/llama-index-llms-databricks/llama_index/llms/databricks/__init__.py
@@ -1,3 +1,3 @@
-from llama_index.llms.databricks.base import DataBricks
+from llama_index.llms.databricks.base import Databricks
 
-__all__ = ["DataBricks"]
+__all__ = ["Databricks"]
diff --git a/llama-index-integrations/llms/llama-index-llms-databricks/llama_index/llms/databricks/base.py b/llama-index-integrations/llms/llama-index-llms-databricks/llama_index/llms/databricks/base.py
index dc99f7feb1..2d0e35a27e 100644
--- a/llama-index-integrations/llms/llama-index-llms-databricks/llama_index/llms/databricks/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-databricks/llama_index/llms/databricks/base.py
@@ -4,17 +4,17 @@ from typing import Any, Optional
 from llama_index.llms.openai_like import OpenAILike
 
 
-class DataBricks(OpenAILike):
-    """DataBricks LLM.
+class Databricks(OpenAILike):
+    """Databricks LLM.
 
     Examples:
         `pip install llama-index-llms-databricks`
 
         ```python
-        from llama_index.llms.databricks import DataBricks
+        from llama_index.llms.databricks import Databricks
 
-        # Set up the DataBricks class with the required model, API key and serving endpoint
-        llm = DataBricks(model="databricks-dbrx-instruct", api_key="your_api_key", api_base="https://[your-work-space].cloud.databricks.com/serving-endpoints/[your-serving-endpoint]")
+        # Set up the Databricks class with the required model, API key and serving endpoint
+        llm = Databricks(model="databricks-dbrx-instruct", api_key="your_api_key", api_base="https://[your-work-space].cloud.databricks.com/serving-endpoints")
 
         # Call the complete method with a query
         response = llm.complete("Explain the importance of open source LLMs")
@@ -31,8 +31,8 @@ class DataBricks(OpenAILike):
         is_chat_model: bool = True,
         **kwargs: Any,
     ) -> None:
-        api_key = api_key or os.environ.get("DATABRICKS_API_KEY", None)
-        api_base = api_base or os.environ.get("DATABRICKS_API_BASE", None)
+        api_key = api_key or os.environ.get("DATABRICKS_TOKEN", None)
+        api_base = api_base or os.environ.get("DATABRICKS_SERVING_ENDPOINT", None)
         super().__init__(
             model=model,
             api_key=api_key,
@@ -44,4 +44,4 @@ class DataBricks(OpenAILike):
     @classmethod
     def class_name(cls) -> str:
         """Get class name."""
-        return "DataBricks"
+        return "Databricks"
diff --git a/llama-index-integrations/llms/llama-index-llms-databricks/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-databricks/pyproject.toml
index 631af5be36..c5d980fcc1 100644
--- a/llama-index-integrations/llms/llama-index-llms-databricks/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-databricks/pyproject.toml
@@ -30,7 +30,7 @@ license = "MIT"
 name = "llama-index-llms-databricks"
 packages = [{include = "llama_index/"}]
 readme = "README.md"
-version = "0.1.0"
+version = "0.1.1"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
diff --git a/llama-index-integrations/llms/llama-index-llms-databricks/tests/test_integration_databricks.py b/llama-index-integrations/llms/llama-index-llms-databricks/tests/test_integration_databricks.py
index fbf6af8af5..8f0a1ca161 100644
--- a/llama-index-integrations/llms/llama-index-llms-databricks/tests/test_integration_databricks.py
+++ b/llama-index-integrations/llms/llama-index-llms-databricks/tests/test_integration_databricks.py
@@ -2,15 +2,16 @@ import os
 
 import pytest
 
-from llama_index.llms.databricks import DataBricks
+from llama_index.llms.databricks import Databricks
 
 
 @pytest.mark.skipif(
-    "DATABRICKS_API_KEY" not in os.environ or "DATABRICKS_API_BASE" not in os.environ,
-    reason="DATABRICKS_API_KEY or DATABRICKS_API_BASE not set in environment",
+    "DATABRICKS_TOKEN" not in os.environ
+    or "DATABRICKS_SERVING_ENDPOINT" not in os.environ,
+    reason="DATABRICKS_TOKEN or DATABRICKS_SERVING_ENDPOINT not set in environment",
 )
 def test_completion():
-    databricks = DataBricks(
+    databricks = Databricks(
         model="databricks-dbrx-instruct", temperature=0, max_tokens=2
     )
     resp = databricks.complete("hello")
@@ -18,11 +19,12 @@ def test_completion():
 
 
 @pytest.mark.skipif(
-    "DATABRICKS_API_KEY" not in os.environ or "DATABRICKS_API_BASE" not in os.environ,
-    reason="DATABRICKS_API_KEY or DATABRICKS_API_BASE not set in environment",
+    "DATABRICKS_TOKEN" not in os.environ
+    or "DATABRICKS_SERVING_ENDPOINT" not in os.environ,
+    reason="DATABRICKS_TOKEN or DATABRICKS_SERVING_ENDPOINT not set in environment",
 )
 def test_stream_completion():
-    databricks = DataBricks(
+    databricks = Databricks(
         model="databricks-dbrx-instruct", temperature=0, max_tokens=2
     )
     stream = databricks.stream_complete("hello")
diff --git a/llama-index-integrations/llms/llama-index-llms-databricks/tests/test_llms_databricks.py b/llama-index-integrations/llms/llama-index-llms-databricks/tests/test_llms_databricks.py
index 1bfe9b0022..e724776328 100644
--- a/llama-index-integrations/llms/llama-index-llms-databricks/tests/test_llms_databricks.py
+++ b/llama-index-integrations/llms/llama-index-llms-databricks/tests/test_llms_databricks.py
@@ -1,7 +1,7 @@
 from llama_index.core.base.llms.base import BaseLLM
-from llama_index.llms.databricks import DataBricks
+from llama_index.llms.databricks import Databricks
 
 
 def test_embedding_class():
-    names_of_base_classes = [b.__name__ for b in DataBricks.__mro__]
+    names_of_base_classes = [b.__name__ for b in Databricks.__mro__]
     assert BaseLLM.__name__ in names_of_base_classes
-- 
GitLab