Skip to content
Snippets Groups Projects
Unverified Commit a10af2c3 authored by Matthew Farrellee's avatar Matthew Farrellee Committed by GitHub
Browse files

change available_models to return List[Model], previously List[str] (#16968)

parent 018eaca8
No related branches found
No related tags found
No related merge requests found
......@@ -39,6 +39,12 @@ from llama_index.multi_modal_llms.nvidia.utils import (
import aiohttp
import json
from llama_index.core.bridge.pydantic import BaseModel
class Model(BaseModel):
id: str
class NVIDIAClient:
def __init__(
......@@ -58,14 +64,14 @@ class NVIDIAClient:
headers["accept"] = "text/event-stream" if stream else "application/json"
return headers
def get_model_details(self) -> List[str]:
def get_model_details(self) -> List[Model]:
"""
Get model details.
Returns:
List of models
"""
return list(NVIDIA_MULTI_MODAL_MODELS.keys())
return [Model(id=model) for model in NVIDIA_MULTI_MODAL_MODELS]
def request(
self,
......@@ -198,7 +204,7 @@ class NVIDIAMultiModal(MultiModalLLM):
)
@property
def available_models(self):
def available_models(self) -> List[Model]:
return self._client.get_model_details()
def _get_credential_kwargs(self) -> Dict[str, Any]:
......
......@@ -27,7 +27,7 @@ license = "MIT"
name = "llama-index-multi-modal-llms-nvidia"
packages = [{include = "llama_index/"}]
readme = "README.md"
version = "0.1.0"
version = "0.2.0"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
......
......@@ -8,4 +8,4 @@ def test_available_models() -> None:
models = NVIDIAMultiModal().available_models
assert models
assert isinstance(models, list)
assert all(isinstance(model, str) for model in models)
assert all(isinstance(model.id, str) for model in models)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment