Skip to content
Snippets Groups Projects
Unverified Commit a10af2c3 authored by Matthew Farrellee's avatar Matthew Farrellee Committed by GitHub
Browse files

change available_models to return List[Model], previously List[str] (#16968)

parent 018eaca8
No related branches found
No related tags found
No related merge requests found
...@@ -39,6 +39,12 @@ from llama_index.multi_modal_llms.nvidia.utils import ( ...@@ -39,6 +39,12 @@ from llama_index.multi_modal_llms.nvidia.utils import (
import aiohttp import aiohttp
import json import json
from llama_index.core.bridge.pydantic import BaseModel
class Model(BaseModel):
id: str
class NVIDIAClient: class NVIDIAClient:
def __init__( def __init__(
...@@ -58,14 +64,14 @@ class NVIDIAClient: ...@@ -58,14 +64,14 @@ class NVIDIAClient:
headers["accept"] = "text/event-stream" if stream else "application/json" headers["accept"] = "text/event-stream" if stream else "application/json"
return headers return headers
def get_model_details(self) -> List[str]: def get_model_details(self) -> List[Model]:
""" """
Get model details. Get model details.
Returns: Returns:
List of models List of models
""" """
return list(NVIDIA_MULTI_MODAL_MODELS.keys()) return [Model(id=model) for model in NVIDIA_MULTI_MODAL_MODELS]
def request( def request(
self, self,
...@@ -198,7 +204,7 @@ class NVIDIAMultiModal(MultiModalLLM): ...@@ -198,7 +204,7 @@ class NVIDIAMultiModal(MultiModalLLM):
) )
@property @property
def available_models(self): def available_models(self) -> List[Model]:
return self._client.get_model_details() return self._client.get_model_details()
def _get_credential_kwargs(self) -> Dict[str, Any]: def _get_credential_kwargs(self) -> Dict[str, Any]:
......
...@@ -27,7 +27,7 @@ license = "MIT" ...@@ -27,7 +27,7 @@ license = "MIT"
name = "llama-index-multi-modal-llms-nvidia" name = "llama-index-multi-modal-llms-nvidia"
packages = [{include = "llama_index/"}] packages = [{include = "llama_index/"}]
readme = "README.md" readme = "README.md"
version = "0.1.0" version = "0.2.0"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.8.1,<4.0" python = ">=3.8.1,<4.0"
......
...@@ -8,4 +8,4 @@ def test_available_models() -> None: ...@@ -8,4 +8,4 @@ def test_available_models() -> None:
models = NVIDIAMultiModal().available_models models = NVIDIAMultiModal().available_models
assert models assert models
assert isinstance(models, list) assert isinstance(models, list)
assert all(isinstance(model, str) for model in models) assert all(isinstance(model.id, str) for model in models)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment