Unverified commit bf358fa8, authored by Rashmi Pawar, committed by GitHub

Remove base url validation (#18031)

parent f18200d1
Showing with 11 additions and 125 deletions
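The user-visible effect of the change: the NVIDIA connectors stop rejecting or rewriting a custom `base_url` and hand it to the underlying OpenAI-compatible client verbatim. A minimal sketch of the post-change behavior, assuming a locally hosted NIM (the model name and port are illustrative, not taken from this diff):

```python
from llama_index.embeddings.nvidia import NVIDIAEmbedding

# Before this commit, a base_url without a scheme and host raised ValueError,
# and a path not ending in /v1 triggered a warning and was rewritten to /v1.
# After it, the value below reaches the OpenAI client exactly as written.
embedder = NVIDIAEmbedding(
    model="nvidia/nv-embedqa-e5-v5",      # illustrative model name
    base_url="http://localhost:8080/v1",  # passed through without validation
)
```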
@@ -13,7 +13,6 @@ from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.base.llms.generic_utils import get_from_param_or_env
from openai import OpenAI, AsyncOpenAI
-from urllib.parse import urlparse, urlunparse
from .utils import (
    EMBEDDING_MODEL_TABLE,
    BASE_URL,
@@ -120,14 +119,10 @@ class NVIDIAEmbedding(BaseEmbedding):
        )

        self._is_hosted = self.base_url in KNOWN_URLS

-        if not self._is_hosted:
-            self.base_url = self._validate_url(self.base_url)

        if self._is_hosted:  # hosted on API Catalog (build.nvidia.com)
            if api_key == "NO_API_KEY_PROVIDED":
                raise ValueError("An API key is required for hosted NIM.")
-        else:  # not hosted
-            self.base_url = self._validate_url(self.base_url)

        self._client = OpenAI(
            api_key=api_key,
@@ -176,38 +171,6 @@ class NVIDIAEmbedding(BaseEmbedding):
        else:
            self.model = self.model or DEFAULT_MODEL

-    def _validate_url(self, base_url):
-        """
-        validate the base_url.
-        if the base_url is not a url, raise an error
-        if the base_url does not end in /v1, e.g. /embeddings
-        emit a warning. old documentation told users to pass in the full
-        inference url, which is incorrect and prevents model listing from working.
-        normalize base_url to end in /v1.
-        """
-        if base_url is not None:
-            parsed = urlparse(base_url)
-            # Ensure scheme and netloc (domain name) are present
-            if not (parsed.scheme and parsed.netloc):
-                expected_format = "Expected format is: http://host:port"
-                raise ValueError(
-                    f"Invalid base_url format. {expected_format} Got: {base_url}"
-                )
-            normalized_path = parsed.path.rstrip("/")
-            if not normalized_path.endswith("/v1"):
-                warnings.warn(
-                    f"{base_url} does not end in /v1, you may "
-                    "have inference and listing issues"
-                )
-                normalized_path += "/v1"
-                base_url = urlunparse(
-                    (parsed.scheme, parsed.netloc, normalized_path, None, None, None)
-                )
-        return base_url
-
    def _validate_model(self, model_name: str) -> None:
        """
        Validates compatibility of the hosted model with the client.
......
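For reference, the deleted `NVIDIAEmbedding._validate_url` collapses to the standalone sketch below; it reproduces the normalization behavior that callers lose with this commit, and is not part of the package afterwards:

```python
import warnings
from urllib.parse import urlparse, urlunparse

def normalize_base_url(base_url: str) -> str:
    """Mirror of the removed check: require scheme+host, nudge the path to /v1."""
    parsed = urlparse(base_url)
    if not (parsed.scheme and parsed.netloc):
        raise ValueError(
            f"Invalid base_url format. Expected format is: http://host:port Got: {base_url}"
        )
    path = parsed.path.rstrip("/")
    if not path.endswith("/v1"):
        warnings.warn(
            f"{base_url} does not end in /v1, you may have inference and listing issues"
        )
        # Rebuild the URL with the normalized path, dropping params/query/fragment.
        return urlunparse((parsed.scheme, parsed.netloc, path + "/v1", None, None, None))
    return base_url

print(normalize_base_url("http://localhost:8000/embeddings"))
# http://localhost:8000/embeddings/v1  (plus a UserWarning)
```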
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-embeddings-nvidia"
readme = "README.md"
-version = "0.3.2"
+version = "0.3.3"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
......
@@ -56,6 +56,8 @@ def test_base_url_priority(public_class: type, monkeypatch) -> None:
    assert get_base_url(base_url=PARAM_URL) == PARAM_URL

+# marking as skip because base_url validation is removed
+@pytest.mark.skip(reason="base_url validation is removed")
@pytest.mark.parametrize(
    "base_url",
    [
@@ -75,6 +77,8 @@ def test_param_base_url_negative(
    assert "Invalid base_url" in str(e.value)

+# marking as skip because base_url validation is removed
+@pytest.mark.skip(reason="base_url validation is removed")
@pytest.mark.parametrize(
    "base_url",
    [
......
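Both test files use the same idiom: instead of deleting the now-obsolete validation tests, a `skip` marker is stacked above the existing `parametrize` decorator, so every case is still collected but reported as skipped. A generic sketch of the pattern (the body is a stand-in, not the repository's test):

```python
import pytest

@pytest.mark.skip(reason="base_url validation is removed")
@pytest.mark.parametrize("base_url", ["localhost", "http://localhost:8888/embeddings"])
def test_param_base_url_negative(base_url: str) -> None:
    # Never runs while the marker is present; pytest reports each
    # parametrized case as skipped with the given reason.
    raise AssertionError("unreachable")
```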
@@ -11,7 +11,6 @@ from llama_index.core.base.llms.generic_utils import (
from llama_index.llms.openai_like import OpenAILike
from llama_index.core.llms.function_calling import FunctionCallingLLM
-from urllib.parse import urlparse
from llama_index.core.base.llms.types import (
    ChatMessage,
    ChatResponse,
@@ -128,27 +127,6 @@ class NVIDIA(OpenAILike, FunctionCallingLLM):
        else:
            self.model = DEFAULT_MODEL

-    def _validate_url(self, base_url):
-        """
-        validate the base_url.
-        if the base_url is not a url, raise an error
-        if the base_url does not end in /v1, e.g. /completions, /chat/completions,
-        emit a warning. old documentation told users to pass in the full
-        inference url, which is incorrect and prevents model listing from working.
-        normalize base_url to end in /v1.
-        """
-        if base_url is not None:
-            base_url = base_url.rstrip("/")
-            parsed = urlparse(base_url)
-            # Ensure scheme and netloc (domain name) are present
-            if not (parsed.scheme and parsed.netloc):
-                expected_format = "Expected format is: http://host:port"
-                raise ValueError(
-                    f"Invalid base_url format. {expected_format} Got: {base_url}"
-                )
-        return base_url
-
    def _validate_model(self, model_name: str) -> None:
        """
        Validates compatibility of the hosted model with the client.
......
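Note that the LLM connector's deleted `_validate_url` was laxer than the embedding one: it trimmed trailing slashes and required a scheme and host, but never appended `/v1`. A standalone mirror of that removed check, for reference only:

```python
from urllib.parse import urlparse

def check_base_url(base_url: str) -> str:
    """Mirror of the removed NVIDIA LLM check: trim trailing '/', require scheme and host."""
    base_url = base_url.rstrip("/")
    parsed = urlparse(base_url)
    if not (parsed.scheme and parsed.netloc):
        raise ValueError(
            f"Invalid base_url format. Expected format is: http://host:port Got: {base_url}"
        )
    return base_url

assert check_base_url("http://localhost:8000/v1/") == "http://localhost:8000/v1"
```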
@@ -30,7 +30,7 @@ license = "MIT"
name = "llama-index-llms-nvidia"
packages = [{include = "llama_index/"}]
readme = "README.md"
-version = "0.3.2"
+version = "0.3.3"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
......
from typing import Any, List, Optional, Generator, Literal
import os
-from urllib.parse import urlparse, urlunparse
import httpx

from llama_index.core.bridge.pydantic import Field, PrivateAttr, ConfigDict
@@ -110,8 +109,6 @@ class NVIDIARerank(BaseNodePostprocessor):
        if self._is_hosted:  # hosted on API Catalog (build.nvidia.com)
            if (not self._api_key) or (self._api_key == "NO_API_KEY_PROVIDED"):
                raise ValueError("An API key is required for hosted NIM.")
-        else:  # not hosted
-            self.base_url = self._validate_url(self.base_url)

        self.model = model
        if not self.model:
@@ -210,65 +207,6 @@ class NVIDIARerank(BaseNodePostprocessor):
        else:
            return RANKING_MODEL_TABLE

-    def _validate_url(self, base_url):
-        """
-        validate the base_url.
-        if the base_url is not a url, raise an error
-        if the base_url does not end in /v1, e.g. /embeddings
-        emit a warning. old documentation told users to pass in the full
-        inference url, which is incorrect and prevents model listing from working.
-        normalize base_url to end in /v1.
-        validate the base_url.
-        if the base_url is not a url, raise an error
-        if the base_url does not end in /v1, e.g. /embeddings
-        emit a warning. old documentation told users to pass in the full
-        inference url, which is incorrect and prevents model listing from working.
-        normalize base_url to end in /v1.
-        """
-        if base_url is not None:
-            parsed = urlparse(base_url)
-            # Ensure scheme and netloc (domain name) are present
-            if not (parsed.scheme and parsed.netloc):
-                expected_format = "Expected format is: http://host:port"
-                raise ValueError(
-                    f"Invalid base_url format. {expected_format} Got: {base_url}"
-                )
-            normalized_path = parsed.path.rstrip("/")
-            if not normalized_path.endswith("/v1"):
-                warnings.warn(
-                    f"{base_url} does not end in /v1, you may "
-                    "have inference and listing issues"
-                )
-                normalized_path += "/v1"
-                base_url = urlunparse(
-                    (parsed.scheme, parsed.netloc, normalized_path, None, None, None)
-                )
-        if base_url is not None:
-            parsed = urlparse(base_url)
-            # Ensure scheme and netloc (domain name) are present
-            if not (parsed.scheme and parsed.netloc):
-                expected_format = "Expected format is: http://host:port"
-                raise ValueError(
-                    f"Invalid base_url format. {expected_format} Got: {base_url}"
-                )
-            normalized_path = parsed.path.rstrip("/")
-            if not normalized_path.endswith("/v1"):
-                warnings.warn(
-                    f"{base_url} does not end in /v1, you may "
-                    "have inference and listing issues"
-                )
-                normalized_path += "/v1"
-                base_url = urlunparse(
-                    (parsed.scheme, parsed.netloc, normalized_path, None, None, None)
-                )
-        return base_url
-
    def _validate_model(self, model_name: str) -> None:
        """
        Validates compatibility of the hosted model with the client.
......
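With the validator gone (note that its docstring and body were duplicated verbatim in the source, which is why this hunk removes 59 lines), `NVIDIARerank` likewise forwards `base_url` untouched. A hedged usage sketch; the model name, port, and `top_n` value are illustrative:

```python
from llama_index.postprocessor.nvidia_rerank import NVIDIARerank

reranker = NVIDIARerank(
    model="nvidia/nv-rerankqa-mistral-4b-v3",  # illustrative model name
    base_url="http://localhost:9000/v1",       # no longer validated or rewritten
    top_n=5,
)
```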
@@ -30,7 +30,7 @@ license = "MIT"
name = "llama-index-postprocessor-nvidia-rerank"
packages = [{include = "llama_index/"}]
readme = "README.md"
-version = "0.4.2"
+version = "0.4.3"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
......
@@ -31,7 +31,8 @@ def mock_v1_local_models2(respx_mock: respx.MockRouter, base_url: str) -> None:
    )

-# Updated test for non-hosted URLs that may need normalization.
+# marking as skip because base_url validation is removed
+@pytest.mark.skip(reason="base_url validation is removed")
@pytest.mark.parametrize(
    "base_url",
    [
@@ -98,6 +99,8 @@ def test_proxy_base_url(base_url: str, mock_v1_local_models2: None) -> None:
    assert client.base_url == base_url

+# marking as skip because base_url validation is removed
+@pytest.mark.skip(reason="base_url validation is removed")
@pytest.mark.parametrize(
    "base_url",
    [
......