Unverified commit bf358fa8 authored by Rashmi Pawar, committed by GitHub

Remove base url validation (#18031)

parent f18200d1
Showing changed files with 11 additions and 125 deletions
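For users of these packages, the practical effect is that the constructors no longer inspect, reject, or rewrite a custom `base_url`: whatever URL is supplied (or picked up from the environment) is handed to the underlying client unchanged. A minimal usage sketch of the caller-side implication, assuming a locally hosted NIM endpoint (host, port, and model name below are illustrative, not taken from this commit):

```python
# Sketch only: with validation removed, the caller is responsible for passing a
# well-formed base_url, typically the /v1 root of a NIM deployment.
from llama_index.embeddings.nvidia import NVIDIAEmbedding

# Previously "http://localhost:8080" would have been normalized to ".../v1"
# (with a warning); after this change the URL is used exactly as given.
embedder = NVIDIAEmbedding(
    model="nvidia/nv-embedqa-e5-v5",      # illustrative model name
    base_url="http://localhost:8080/v1",  # pass the /v1 root explicitly
)
```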
@@ -13,7 +13,6 @@ from llama_index.core.callbacks.base import CallbackManager
 from llama_index.core.base.llms.generic_utils import get_from_param_or_env
 from openai import OpenAI, AsyncOpenAI
-from urllib.parse import urlparse, urlunparse
 from .utils import (
     EMBEDDING_MODEL_TABLE,
     BASE_URL,
@@ -120,14 +119,10 @@ class NVIDIAEmbedding(BaseEmbedding):
         )

         self._is_hosted = self.base_url in KNOWN_URLS
-        if not self._is_hosted:
-            self.base_url = self._validate_url(self.base_url)

         if self._is_hosted:  # hosted on API Catalog (build.nvidia.com)
             if api_key == "NO_API_KEY_PROVIDED":
                 raise ValueError("An API key is required for hosted NIM.")
-        else:  # not hosted
-            self.base_url = self._validate_url(self.base_url)

         self._client = OpenAI(
             api_key=api_key,
@@ -176,38 +171,6 @@ class NVIDIAEmbedding(BaseEmbedding):
         else:
             self.model = self.model or DEFAULT_MODEL

-    def _validate_url(self, base_url):
-        """
-        validate the base_url.
-        if the base_url is not a url, raise an error
-        if the base_url does not end in /v1, e.g. /embeddings
-        emit a warning. old documentation told users to pass in the full
-        inference url, which is incorrect and prevents model listing from working.
-        normalize base_url to end in /v1.
-        """
-        if base_url is not None:
-            parsed = urlparse(base_url)
-            # Ensure scheme and netloc (domain name) are present
-            if not (parsed.scheme and parsed.netloc):
-                expected_format = "Expected format is: http://host:port"
-                raise ValueError(
-                    f"Invalid base_url format. {expected_format} Got: {base_url}"
-                )
-            normalized_path = parsed.path.rstrip("/")
-            if not normalized_path.endswith("/v1"):
-                warnings.warn(
-                    f"{base_url} does not end in /v1, you may "
-                    "have inference and listing issues"
-                )
-                normalized_path += "/v1"
-                base_url = urlunparse(
-                    (parsed.scheme, parsed.netloc, normalized_path, None, None, None)
-                )
-        return base_url
-
     def _validate_model(self, model_name: str) -> None:
         """
         Validates compatibility of the hosted model with the client.
......
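The helper removed above both validated the URL structure and appended `/v1` when it was missing. Callers who relied on that normalization can reproduce it in their own code; the following is a rough standalone equivalent of the deleted logic (the function name is ours, not part of the package):

```python
import warnings
from urllib.parse import urlparse, urlunparse


def normalize_nim_base_url(base_url: str) -> str:
    """Roughly mirror the removed _validate_url: require a scheme and host,
    then normalize the path to end in /v1."""
    parsed = urlparse(base_url)
    if not (parsed.scheme and parsed.netloc):
        raise ValueError(f"Invalid base_url format. Expected http://host:port, got: {base_url}")
    path = parsed.path.rstrip("/")
    if not path.endswith("/v1"):
        warnings.warn(f"{base_url} does not end in /v1; appending it")
        path += "/v1"
    return urlunparse((parsed.scheme, parsed.netloc, path, None, None, None))


print(normalize_nim_base_url("http://localhost:8080"))  # http://localhost:8080/v1
```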
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-embeddings-nvidia"
 readme = "README.md"
-version = "0.3.2"
+version = "0.3.3"

 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
......
@@ -56,6 +56,8 @@ def test_base_url_priority(public_class: type, monkeypatch) -> None:
     assert get_base_url(base_url=PARAM_URL) == PARAM_URL


+# marking as skip because base_url validation is removed
+@pytest.mark.skip(reason="base_url validation is removed")
 @pytest.mark.parametrize(
     "base_url",
     [
@@ -75,6 +77,8 @@ def test_param_base_url_negative(
     assert "Invalid base_url" in str(e.value)


+# marking as skip because base_url validation is removed
+@pytest.mark.skip(reason="base_url validation is removed")
 @pytest.mark.parametrize(
     "base_url",
     [
......
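For readers skimming the test changes: `@pytest.mark.skip` keeps these parametrized tests in the suite but prevents them from running now that there is no validation to exercise. A hedged sketch of roughly what one of the skipped negative tests looks like (parameter values are illustrative, not the repository's exact fixtures):

```python
import pytest
from llama_index.embeddings.nvidia import NVIDIAEmbedding


# marking as skip because base_url validation is removed
@pytest.mark.skip(reason="base_url validation is removed")
@pytest.mark.parametrize(
    "base_url",
    ["localhost", "localhost:8888"],  # no scheme, so these used to be rejected
)
def test_param_base_url_negative(base_url: str) -> None:
    with pytest.raises(ValueError) as e:
        NVIDIAEmbedding(base_url=base_url)
    assert "Invalid base_url" in str(e.value)
```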
...@@ -11,7 +11,6 @@ from llama_index.core.base.llms.generic_utils import ( ...@@ -11,7 +11,6 @@ from llama_index.core.base.llms.generic_utils import (
from llama_index.llms.openai_like import OpenAILike from llama_index.llms.openai_like import OpenAILike
from llama_index.core.llms.function_calling import FunctionCallingLLM from llama_index.core.llms.function_calling import FunctionCallingLLM
from urllib.parse import urlparse
from llama_index.core.base.llms.types import ( from llama_index.core.base.llms.types import (
ChatMessage, ChatMessage,
ChatResponse, ChatResponse,
@@ -128,27 +127,6 @@ class NVIDIA(OpenAILike, FunctionCallingLLM):
         else:
             self.model = DEFAULT_MODEL

-    def _validate_url(self, base_url):
-        """
-        validate the base_url.
-        if the base_url is not a url, raise an error
-        if the base_url does not end in /v1, e.g. /completions, /chat/completions,
-        emit a warning. old documentation told users to pass in the full
-        inference url, which is incorrect and prevents model listing from working.
-        normalize base_url to end in /v1.
-        """
-        if base_url is not None:
-            base_url = base_url.rstrip("/")
-            parsed = urlparse(base_url)
-            # Ensure scheme and netloc (domain name) are present
-            if not (parsed.scheme and parsed.netloc):
-                expected_format = "Expected format is: http://host:port"
-                raise ValueError(
-                    f"Invalid base_url format. {expected_format} Got: {base_url}"
-                )
-        return base_url
-
     def _validate_model(self, model_name: str) -> None:
         """
         Validates compatibility of the hosted model with the client.
......
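Unlike the embeddings variant, the helper removed here only stripped a trailing slash and checked that a scheme and host were present; it never appended `/v1`. If you want to keep that guard in your own code, a minimal standalone equivalent looks like this (the function name is ours):

```python
from urllib.parse import urlparse


def check_base_url(base_url: str) -> str:
    """Mirror the removed check: strip a trailing slash, require scheme and host."""
    base_url = base_url.rstrip("/")
    parsed = urlparse(base_url)
    if not (parsed.scheme and parsed.netloc):
        raise ValueError(f"Invalid base_url format. Expected http://host:port, got: {base_url}")
    return base_url


print(check_base_url("http://localhost:8000/v1/"))  # http://localhost:8000/v1
```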
...@@ -30,7 +30,7 @@ license = "MIT" ...@@ -30,7 +30,7 @@ license = "MIT"
name = "llama-index-llms-nvidia" name = "llama-index-llms-nvidia"
packages = [{include = "llama_index/"}] packages = [{include = "llama_index/"}]
readme = "README.md" readme = "README.md"
version = "0.3.2" version = "0.3.3"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.9,<4.0" python = ">=3.9,<4.0"
......
 from typing import Any, List, Optional, Generator, Literal
 import os
-from urllib.parse import urlparse, urlunparse
 import httpx
 from llama_index.core.bridge.pydantic import Field, PrivateAttr, ConfigDict
@@ -110,8 +109,6 @@ class NVIDIARerank(BaseNodePostprocessor):
         if self._is_hosted:  # hosted on API Catalog (build.nvidia.com)
             if (not self._api_key) or (self._api_key == "NO_API_KEY_PROVIDED"):
                 raise ValueError("An API key is required for hosted NIM.")
-        else:  # not hosted
-            self.base_url = self._validate_url(self.base_url)

         self.model = model
         if not self.model:
@@ -210,65 +207,6 @@ class NVIDIARerank(BaseNodePostprocessor):
         else:
             return RANKING_MODEL_TABLE

-    def _validate_url(self, base_url):
-        """
-        validate the base_url.
-        if the base_url is not a url, raise an error
-        if the base_url does not end in /v1, e.g. /embeddings
-        emit a warning. old documentation told users to pass in the full
-        inference url, which is incorrect and prevents model listing from working.
-        normalize base_url to end in /v1.
-        validate the base_url.
-        if the base_url is not a url, raise an error
-        if the base_url does not end in /v1, e.g. /embeddings
-        emit a warning. old documentation told users to pass in the full
-        inference url, which is incorrect and prevents model listing from working.
-        normalize base_url to end in /v1.
-        """
-        if base_url is not None:
-            parsed = urlparse(base_url)
-            # Ensure scheme and netloc (domain name) are present
-            if not (parsed.scheme and parsed.netloc):
-                expected_format = "Expected format is: http://host:port"
-                raise ValueError(
-                    f"Invalid base_url format. {expected_format} Got: {base_url}"
-                )
-            normalized_path = parsed.path.rstrip("/")
-            if not normalized_path.endswith("/v1"):
-                warnings.warn(
-                    f"{base_url} does not end in /v1, you may "
-                    "have inference and listing issues"
-                )
-                normalized_path += "/v1"
-                base_url = urlunparse(
-                    (parsed.scheme, parsed.netloc, normalized_path, None, None, None)
-                )
-        if base_url is not None:
-            parsed = urlparse(base_url)
-            # Ensure scheme and netloc (domain name) are present
-            if not (parsed.scheme and parsed.netloc):
-                expected_format = "Expected format is: http://host:port"
-                raise ValueError(
-                    f"Invalid base_url format. {expected_format} Got: {base_url}"
-                )
-            normalized_path = parsed.path.rstrip("/")
-            if not normalized_path.endswith("/v1"):
-                warnings.warn(
-                    f"{base_url} does not end in /v1, you may "
-                    "have inference and listing issues"
-                )
-                normalized_path += "/v1"
-                base_url = urlunparse(
-                    (parsed.scheme, parsed.netloc, normalized_path, None, None, None)
-                )
-        return base_url
-
     def _validate_model(self, model_name: str) -> None:
         """
         Validates compatibility of the hosted model with the client.
......
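With the (internally duplicated) validator gone, `NVIDIARerank` also forwards `base_url` untouched, so local deployments should point at the API root explicitly. A hedged usage sketch against a self-hosted reranker NIM (host, port, and model name are illustrative):

```python
from llama_index.core.schema import NodeWithScore, TextNode
from llama_index.postprocessor.nvidia_rerank import NVIDIARerank

# No normalization happens anymore: supply the full /v1 root of the NIM yourself.
reranker = NVIDIARerank(
    model="nvidia/nv-rerankqa-mistral-4b-v3",  # illustrative model name
    base_url="http://localhost:1976/v1",
    top_n=2,
)

nodes = [NodeWithScore(node=TextNode(text=t)) for t in ("NIM deployment docs", "unrelated text")]
reranked = reranker.postprocess_nodes(nodes, query_str="How do I deploy a NIM?")
```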
@@ -30,7 +30,7 @@ license = "MIT"
 name = "llama-index-postprocessor-nvidia-rerank"
 packages = [{include = "llama_index/"}]
 readme = "README.md"
-version = "0.4.2"
+version = "0.4.3"

 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
......
@@ -31,7 +31,8 @@ def mock_v1_local_models2(respx_mock: respx.MockRouter, base_url: str) -> None:
     )


-# Updated test for non-hosted URLs that may need normalization.
+# marking as skip because base_url validation is removed
+@pytest.mark.skip(reason="base_url validation is removed")
 @pytest.mark.parametrize(
     "base_url",
     [
@@ -98,6 +99,8 @@ def test_proxy_base_url(base_url: str, mock_v1_local_models2: None) -> None:
     assert client.base_url == base_url


+# marking as skip because base_url validation is removed
+@pytest.mark.skip(reason="base_url validation is removed")
 @pytest.mark.parametrize(
     "base_url",
     [
......