diff --git a/docs/docs/examples/llm/nvidia_triton.ipynb b/docs/docs/examples/llm/nvidia_triton.ipynb
index ab33d0f7dc668614dd27b2936eb0868ed4aa24fd..66be41c51934362b6c13fb83ffb88586a011175e 100644
--- a/docs/docs/examples/llm/nvidia_triton.ipynb
+++ b/docs/docs/examples/llm/nvidia_triton.ipynb
@@ -144,6 +144,34 @@
     "```\n"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Call `stream_complete` with a prompt"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "```python\n",
+    "resp = NvidiaTriton(server_url=triton_url, model_name=model_name, tokens=32).stream_complete(\"The tallest mountain in North America is \")\n",
+    "for delta in resp:\n",
+    "    print(delta.delta, end=\" \")\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "You should expect the following response as a stream\n",
+    "```\n",
+    "the Great Pyramid of Giza, which is about 1,000 feet high. The Great Pyramid of Giza is the tallest mountain in North America.\n",
+    "```\n"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
diff --git a/llama-index-integrations/llms/llama-index-llms-nvidia-triton/llama_index/llms/nvidia_triton/base.py b/llama-index-integrations/llms/llama-index-llms-nvidia-triton/llama_index/llms/nvidia_triton/base.py
index b15793c7810c2588a00bc6db87aa01a5257c339c..e970ceaaa2385c1ba793c3f27808fa4132831951 100644
--- a/llama-index-integrations/llms/llama-index-llms-nvidia-triton/llama_index/llms/nvidia_triton/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-nvidia-triton/llama_index/llms/nvidia_triton/base.py
@@ -47,6 +47,7 @@ from llama_index.core.base.llms.types import (
 from llama_index.core.llms.callbacks import llm_chat_callback
 from llama_index.core.base.llms.generic_utils import (
     completion_to_chat_decorator,
+    stream_completion_to_chat_decorator,
 )
 from llama_index.core.llms.llm import LLM
 from llama_index.llms.nvidia_triton.utils import GrpcTritonClient
@@ -236,10 +237,12 @@ class NvidiaTriton(LLM):
         chat_fn = completion_to_chat_decorator(self.complete)
         return chat_fn(messages, **kwargs)

+    @llm_chat_callback()
     def stream_chat(
         self, messages: Sequence[ChatMessage], **kwargs: Any
     ) -> ChatResponseGen:
-        raise NotImplementedError
+        chat_stream_fn = stream_completion_to_chat_decorator(self.stream_complete)
+        return chat_stream_fn(messages, **kwargs)

     def complete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
@@ -266,7 +269,7 @@
             if isinstance(token, InferenceServerException):
                 client.stop_stream(model_params["model_name"], request_id)
                 raise token
-            response = response + token
+            response += token

         return CompletionResponse(
             text=response,
@@ -275,7 +278,34 @@
     def stream_complete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
     ) -> CompletionResponseGen:
-        raise NotImplementedError
+        from tritonclient.utils import InferenceServerException
+
+        client = self._get_client()
+
+        invocation_params = self._get_model_default_parameters
+        invocation_params.update(kwargs)
+        invocation_params["prompt"] = [[prompt]]
+        model_params = self._identifying_params
+        model_params.update(kwargs)
+        request_id = str(random.randint(1, 9999999))  # nosec
+
+        if self.triton_load_model_call:
+            client.load_model(model_params["model_name"])
+
+        result_queue = client.request_streaming(
+            model_params["model_name"], request_id, **invocation_params
+        )
+
+        def gen() -> CompletionResponseGen:
+            text = ""
+            for token in result_queue:
+                if isinstance(token, InferenceServerException):
+                    client.stop_stream(model_params["model_name"], request_id)
+                    raise token
+                text += token
+                yield CompletionResponse(text=text, delta=token)
+
+        return gen()

     async def achat(
         self, messages: Sequence[ChatMessage], **kwargs: Any
diff --git a/llama-index-integrations/llms/llama-index-llms-nvidia-triton/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-nvidia-triton/pyproject.toml
index d5f681d3496abf9d68a4a9d8e5bc1a0ce0377f52..67cfd96f7b9bf21ce0c842bb58a80b596096c41c 100644
--- a/llama-index-integrations/llms/llama-index-llms-nvidia-triton/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-nvidia-triton/pyproject.toml
@@ -21,13 +21,13 @@ ignore_missing_imports = true
 python_version = "3.8"

 [tool.poetry]
-authors = ["Your Name <you@example.com>"]
+authors = ["Rohith Ramakrishnan <rrohith2001@gmail.com>"]
 description = "llama-index llms nvidia triton integration"
 exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-nvidia-triton"
 readme = "README.md"
-version = "0.1.4"
+version = "0.1.5"

 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
diff --git a/llama-index-integrations/llms/llama-index-llms-nvidia-triton/tests/BUILD b/llama-index-integrations/llms/llama-index-llms-nvidia-triton/tests/BUILD
new file mode 100644
index 0000000000000000000000000000000000000000..dabf212d7e7162849c24a733909ac4f645d75a31
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-nvidia-triton/tests/BUILD
@@ -0,0 +1 @@
+python_tests()
diff --git a/llama-index-integrations/llms/llama-index-llms-nvidia-triton/tests/__init__.py b/llama-index-integrations/llms/llama-index-llms-nvidia-triton/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/llama-index-integrations/llms/llama-index-llms-nvidia-triton/tests/test_llms_nvidia_triton.py b/llama-index-integrations/llms/llama-index-llms-nvidia-triton/tests/test_llms_nvidia_triton.py
new file mode 100644
index 0000000000000000000000000000000000000000..85b014c85a39e21a3d7bac2b0a7857f1845ba28e
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-nvidia-triton/tests/test_llms_nvidia_triton.py
@@ -0,0 +1,7 @@
+from llama_index.core.base.llms.base import BaseLLM
+from llama_index.llms.nvidia_triton import NvidiaTriton
+
+
+def test_text_inference_embedding_class():
+    names_of_base_classes = [b.__name__ for b in NvidiaTriton.__mro__]
+    assert BaseLLM.__name__ in names_of_base_classes
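
For reference, a minimal sketch (not part of the patch) of driving the newly implemented `stream_chat` path; the `triton_url` and `model_name` values below are hypothetical placeholders for a real Triton deployment, mirroring the notebook cells above:

```python
from llama_index.core.llms import ChatMessage
from llama_index.llms.nvidia_triton import NvidiaTriton

# Hypothetical endpoint and model name -- substitute your own Triton gRPC
# address and deployed model, as in the notebook example above.
triton_url = "localhost:8001"
model_name = "ensemble"

llm = NvidiaTriton(server_url=triton_url, model_name=model_name, tokens=32)

# stream_chat is backed by stream_complete via stream_completion_to_chat_decorator,
# so it yields ChatResponse objects whose .delta carries each newly generated token.
for chunk in llm.stream_chat(
    [ChatMessage(role="user", content="The tallest mountain in North America is ")]
):
    print(chunk.delta, end="")
```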