diff --git a/.gitignore b/.gitignore
index 2313f2f35e95acdfcda7fa166e82e621c415780f..990c18de229088f55c6c514fd0f2d49981d1b0e7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -142,6 +142,9 @@ dmypy.json
 modules/
 *.swp
 
+# VS Code
+.vscode
+
 # pipenv
 Pipfile
 Pipfile.lock
diff --git a/docs/examples/llm/nvidia_triton.ipynb b/docs/examples/llm/nvidia_triton.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..f0a8202da0426ed77ac7568eb877b63b56aa2f82
--- /dev/null
+++ b/docs/examples/llm/nvidia_triton.ipynb
@@ -0,0 +1,163 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "<a href=\"https://colab.research.google.com/github/jerryjliu/llama_index/blob/main/docs/examples/llm/nvidia_triton.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Nvidia Triton"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "\n",
+    "Nvidia's Triton is an inference server that provides API access to hosted LLM models. This connector allows for llama_index to remotely interact with a Triton inference server over GRPC to accelerate inference operations.\n",
+    "\n",
+    "[Triton Inference Server Github](https://github.com/triton-inference-server/server)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Install tritonclient\n",
+    "Since we are interacting with the Triton inference server we will need to install the `tritonclient` package. The `tritonclient` package.\n",
+    "\n",
+    "`tritonclient` can be easily installed using `pip3 install tritonclient`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip3 install tritonclient"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Basic Usage"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Call `complete` with a prompt"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llama_index.llms import NvidiaTriton\n",
+    "\n",
+    "# A Triton server instance must be running. Use the correct URL for your desired Triton server instance.\n",
+    "triton_url = \"localhost:8001\"\n",
+    "resp = NvidiaTriton().complete(\"The tallest mountain in North America is \")\n",
+    "print(resp)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Call `chat` with a list of messages"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llama_index.llms import ChatMessage, NvidiaTriton\n",
+    "\n",
+    "messages = [\n",
+    "    ChatMessage(\n",
+    "        role=\"system\",\n",
+    "        content=\"You are a clown named bozo that has had a rough day at the circus\",\n",
+    "    ),\n",
+    "    ChatMessage(role=\"user\", content=\"What has you down bozo?\"),\n",
+    "]\n",
+    "resp = NvidiaTriton().chat(messages)\n",
+    "print(resp)"
+   ]
+  },
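+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Customize the server and model\n",
+    "The examples above rely on the connector defaults (`localhost:8001` and the `ensemble` model). The sketch below shows how you might point the client at your own deployment; the server URL, model name, and sampling values are placeholders, so substitute the ones that match your Triton setup."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llama_index.llms import NvidiaTriton\n",
+    "\n",
+    "# Placeholder values -- replace them with the details of your own Triton deployment.\n",
+    "llm = NvidiaTriton(\n",
+    "    server_url=\"localhost:8001\",  # GRPC endpoint of your Triton server\n",
+    "    model=\"ensemble\",  # name of the Triton-hosted model to call\n",
+    "    tokens=256,\n",
+    "    temperature=0.5,\n",
+    ")\n",
+    "resp = llm.complete(\"The fastest land animal is \")\n",
+    "print(resp)"
+   ]
+  },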
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Further Examples\n",
+    "Remember that a Triton instance represents a running server instance therefore you should ensure you have a valid server configuration running and change the `localhost:8001` to the correct IP/hostname:port combination for your server.\n",
+    "\n",
+    "An example of setting up this environment can be found at Nvidia's (GenerativeAIExamples Github Repo)[https://github.com/NVIDIA/GenerativeAIExamples/tree/main/RetrievalAugmentedGeneration]"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3"
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "a0a0263b650d907a3bfe41c0f8d6a63a071b884df3cfdc1579f00cdc1aed6b03"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/docs/module_guides/models/llms/modules.md b/docs/module_guides/models/llms/modules.md
index 504ece86ba7d65bd48770700e70b6c1c19064345..05b3aaa92d24b6203b5dea9b9bf1c68f0410f134 100644
--- a/docs/module_guides/models/llms/modules.md
+++ b/docs/module_guides/models/llms/modules.md
@@ -149,6 +149,15 @@ maxdepth: 1
 /examples/llm/monsterapi.ipynb
 ```
 
+## Nvidia Triton
+
+```{toctree}
+---
+maxdepth: 1
+---
+/examples/llm/nvidia_triton.ipynb
+```
+
 ## Ollama
 
 ```{toctree}
diff --git a/llama_index/llms/__init__.py b/llama_index/llms/__init__.py
index 95c9ea25a7c534e9aad19369aee225adf338559e..bf4d304e502fc10ace28c9aae22ebe9d6ad042da 100644
--- a/llama_index/llms/__init__.py
+++ b/llama_index/llms/__init__.py
@@ -19,6 +19,7 @@ from llama_index.llms.localai import LOCALAI_DEFAULTS, LocalAI
 from llama_index.llms.mistral import MistralAI
 from llama_index.llms.mock import MockLLM
 from llama_index.llms.monsterapi import MonsterLLM
+from llama_index.llms.nvidia_triton import NvidiaTriton
 from llama_index.llms.ollama import Ollama
 from llama_index.llms.openai import OpenAI
 from llama_index.llms.openai_like import OpenAILike
@@ -78,6 +79,7 @@ __all__ = [
     "MessageRole",
     "MockLLM",
     "MonsterLLM",
+    "NvidiaTriton",
     "MistralAI",
     "Ollama",
     "OpenAI",
diff --git a/llama_index/llms/nvidia_triton.py b/llama_index/llms/nvidia_triton.py
new file mode 100644
index 0000000000000000000000000000000000000000..468892d520115639dc2112aac448fecd7384046a
--- /dev/null
+++ b/llama_index/llms/nvidia_triton.py
@@ -0,0 +1,249 @@
+import random
+from typing import (
+    Any,
+    Dict,
+    Optional,
+    Sequence,
+)
+
+from llama_index.bridge.pydantic import Field, PrivateAttr
+from llama_index.callbacks import CallbackManager
+from llama_index.llms.base import (
+    ChatMessage,
+    ChatResponse,
+    ChatResponseAsyncGen,
+    ChatResponseGen,
+    CompletionResponse,
+    CompletionResponseAsyncGen,
+    CompletionResponseGen,
+    LLMMetadata,
+    llm_chat_callback,
+)
+from llama_index.llms.generic_utils import (
+    completion_to_chat_decorator,
+)
+from llama_index.llms.llm import LLM
+from llama_index.llms.nvidia_triton_utils import GrpcTritonClient
+
+DEFAULT_SERVER_URL = "localhost:8001"
+DEFAULT_MAX_RETRIES = 3
+DEFAULT_TIMEOUT = 60.0
+DEFAULT_MODEL = "ensemble"
+DEFAULT_TEMPERATURE = 1.0
+DEFAULT_TOP_P = 0
+DEFAULT_TOP_K = 1.0
+DEFAULT_MAX_TOKENS = 100
+DEFAULT_BEAM_WIDTH = 1
+DEFAULT_REPETITION_PENALTY = 1.0
+DEFAULT_LENGTH_PENALTY = 1.0
+DEFAULT_REUSE_CLIENT = True
+DEFAULT_TRITON_LOAD_MODEL = True
+
+
+class NvidiaTriton(LLM):
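+    """Nvidia Triton LLM.
+
+    Connects to a Triton Inference Server over GRPC and exposes the hosted
+    model through the standard llama_index LLM interface. The streaming and
+    async entry points are not implemented yet and raise NotImplementedError.
+    """
+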
+    server_url: str = Field(
+        default=DEFAULT_SERVER_URL,
+        description="The URL of the Triton inference server to use.",
+    )
+    model_name: str = Field(
+        default=DEFAULT_MODEL,
+        description="The name of the Triton hosted model this client should use",
+    )
+    temperature: Optional[float] = Field(
+        default=DEFAULT_TEMPERATURE, description="Temperature to use for sampling"
+    )
+    top_p: Optional[float] = Field(
+        default=DEFAULT_TOP_P, description="The top-p value to use for sampling"
+    )
+    top_k: Optional[float] = Field(
+        default=DEFAULT_TOP_K, description="The top k value to use for sampling"
+    )
+    tokens: Optional[int] = Field(
+        default=DEFAULT_MAX_TOKENS,
+        description="The maximum number of tokens to generate.",
+    )
+    beam_width: Optional[int] = Field(
+        default=DEFAULT_BEAM_WIDTH, description="The number of beams to use for beam search"
+    )
+    repetition_penalty: Optional[float] = Field(
+        default=DEFAULT_REPETITION_PENALTY,
+        description="The penalty to apply to repeated tokens",
+    )
+    length_penalty: Optional[float] = Field(
+        default=DEFAULT_LENGTH_PENALTY,
+        description="The penalty to apply repeated tokens",
+    )
+    max_retries: Optional[int] = Field(
+        default=DEFAULT_MAX_RETRIES,
+        description="Maximum number of attempts to retry Triton client invocation before erroring",
+    )
+    timeout: Optional[float] = Field(
+        default=DEFAULT_TIMEOUT,
+        description="Maximum time (seconds) allowed for a Triton client call before erroring",
+    )
+    reuse_client: Optional[bool] = Field(
+        default=DEFAULT_REUSE_CLIENT,
+        description="True for reusing the same client instance between invocations",
+    )
+    triton_load_model_call: Optional[bool] = Field(
+        default=DEFAULT_TRITON_LOAD_MODEL,
+        description="True if a Triton load model API call should be made before using the client",
+    )
+
+    _client: Optional[GrpcTritonClient] = PrivateAttr()
+
+    def __init__(
+        self,
+        server_url: str = DEFAULT_SERVER_URL,
+        model: str = DEFAULT_MODEL,
+        temperature: float = DEFAULT_TEMPERATURE,
+        top_p: float = DEFAULT_TOP_P,
+        top_k: float = DEFAULT_TOP_K,
+        tokens: Optional[int] = DEFAULT_MAX_TOKENS,
+        beam_width: int = DEFAULT_BEAM_WIDTH,
+        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
+        length_penalty: float = DEFAULT_LENGTH_PENALTY,
+        max_retries: int = DEFAULT_MAX_RETRIES,
+        timeout: float = DEFAULT_TIMEOUT,
+        reuse_client: bool = DEFAULT_REUSE_CLIENT,
+        triton_load_model_call: bool = DEFAULT_TRITON_LOAD_MODEL,
+        callback_manager: Optional[CallbackManager] = None,
+        additional_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        additional_kwargs = additional_kwargs or {}
+
+        super().__init__(
+            server_url=server_url,
+            model_name=model,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            tokens=tokens,
+            beam_width=beam_width,
+            repetition_penalty=repetition_penalty,
+            length_penalty=length_penalty,
+            max_retries=max_retries,
+            timeout=timeout,
+            reuse_client=reuse_client,
+            triton_load_model_call=triton_load_model_call,
+            callback_manager=callback_manager,
+            additional_kwargs=additional_kwargs,
+            **kwargs,
+        )
+
+        try:
+            self._client = GrpcTritonClient(server_url)
+        except ImportError as err:
+            raise ImportError(
+                "Could not import triton client python package. "
+                "Please install it with `pip install tritonclient`."
+            ) from err
+
+    @property
+    def _get_model_default_parameters(self) -> Dict[str, Any]:
+        return {
+            "tokens": self.tokens,
+            "top_k": self.top_k,
+            "top_p": self.top_p,
+            "temperature": self.temperature,
+            "repetition_penalty": self.repetition_penalty,
+            "length_penalty": self.length_penalty,
+            "beam_width": self.beam_width,
+        }
+
+    def _invocation_params(self, **kwargs: Any) -> Dict[str, Any]:
+        """Merge the model's default parameters with any per-call overrides."""
+        return {**self._get_model_default_parameters, **kwargs}
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        """Get all the identifying parameters."""
+        return {
+            "server_url": self.server_url,
+            "model_name": self.model_name,
+        }
+
+    def _get_client(self) -> Any:
+        """Create or reuse a Triton client connection."""
+        if not self.reuse_client:
+            return GrpcTritonClient(self.server_url)
+
+        if self._client is None:
+            self._client = GrpcTritonClient(self.server_url)
+        return self._client
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        """Gather and return metadata about the user Triton configured LLM model."""
+        return LLMMetadata(
+            model_name=self.model_name,
+        )
+
+    @llm_chat_callback()
+    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
+        chat_fn = completion_to_chat_decorator(self.complete)
+        return chat_fn(messages, **kwargs)
+
+    def stream_chat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponseGen:
+        raise NotImplementedError
+
+    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+        from tritonclient.utils import InferenceServerException
+
+        client = self._get_client()
+
+        invocation_params = self._get_model_default_parameters
+        invocation_params.update(kwargs)
+        invocation_params["prompt"] = [[prompt]]
+        model_params = self._identifying_params
+        model_params.update(kwargs)
+        request_id = str(random.randint(1, 9999999))  # nosec
+
+        if self.triton_load_model_call:
+            client.load_model(model_params["model_name"])
+
+        result_queue = client.request_streaming(
+            model_params["model_name"], request_id, **invocation_params
+        )
+
+        response = ""
+        for token in result_queue:
+            if isinstance(token, InferenceServerException):
+                client.stop_stream(model_params["model_name"], request_id)
+                raise token
+            response = response + token
+
+        return CompletionResponse(
+            text=response,
+        )
+
+    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
+        raise NotImplementedError
+
+    async def achat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponse:
+        raise NotImplementedError
+
+    async def acomplete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+        raise NotImplementedError
+
+    async def astream_chat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponseAsyncGen:
+        raise NotImplementedError
+
+    async def astream_complete(
+        self, prompt: str, **kwargs: Any
+    ) -> CompletionResponseAsyncGen:
+        raise NotImplementedError
diff --git a/llama_index/llms/nvidia_triton_utils.py b/llama_index/llms/nvidia_triton_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1452ff5fc667eefa149f84e17607d6f99d76f98b
--- /dev/null
+++ b/llama_index/llms/nvidia_triton_utils.py
@@ -0,0 +1,346 @@
+import abc
+import json
+import random
+import time
+from functools import partial
+from queue import Queue
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    List,
+    Optional,
+    Type,
+    Union,
+)
+
+import numpy as np
+
+if TYPE_CHECKING:
+    import tritonclient.grpc as grpcclient
+    import tritonclient.http as httpclient
+
+STOP_WORDS = ["</s>"]
+RANDOM_SEED = 0
+
+
+class StreamingResponseGenerator(Queue):
+    """A Generator that provides the inference results from an LLM."""
+
+    def __init__(
+        self, client: "GrpcTritonClient", request_id: str, force_batch: bool
+    ) -> None:
+        """Instantiate the generator class."""
+        super().__init__()
+        self._client = client
+        self.request_id = request_id
+        self._batch = force_batch
+
+    def __iter__(self) -> "StreamingResponseGenerator":
+        """Return self as a generator."""
+        return self
+
+    def __next__(self) -> str:
+        """Return the next retrieved token."""
+        val = self.get()
+        if val is None or val in STOP_WORDS:
+            self._stop_stream()
+            raise StopIteration
+        return val
+
+    def _stop_stream(self) -> None:
+        """Drain and shutdown the Triton stream."""
+        self._client.stop_stream(
+            "tensorrt_llm", self.request_id, signal=not self._batch
+        )
+
+
+class _BaseTritonClient(abc.ABC):
+    """An abstraction of the connection to a triton inference server."""
+
+    def __init__(self, server_url: str) -> None:
+        """Initialize the client."""
+        self._server_url = server_url
+        self._client = self._inference_server_client(server_url)
+
+    @property
+    @abc.abstractmethod
+    def _inference_server_client(
+        self,
+    ) -> Union[
+        Type["grpcclient.InferenceServerClient"],
+        Type["httpclient.InferenceServerClient"],
+    ]:
+        """Return the preferred InferenceServerClient class."""
+
+    @property
+    @abc.abstractmethod
+    def _infer_input(
+        self,
+    ) -> Union[Type["grpcclient.InferInput"], Type["httpclient.InferInput"]]:
+        """Return the preferred InferInput."""
+
+    @property
+    @abc.abstractmethod
+    def _infer_output(
+        self,
+    ) -> Union[
+        Type["grpcclient.InferRequestedOutput"], Type["httpclient.InferRequestedOutput"]
+    ]:
+        """Return the preferred InferRequestedOutput."""
+
+    def load_model(self, model_name: str, timeout: int = 1000) -> None:
+        """Load a model into the server."""
+        if self._client.is_model_ready(model_name):
+            return
+
+        self._client.load_model(model_name)
+        t0 = time.perf_counter()
+        t1 = t0
+        while not self._client.is_model_ready(model_name) and t1 - t0 < timeout:
+            t1 = time.perf_counter()
+
+        if not self._client.is_model_ready(model_name):
+            raise RuntimeError(f"Failed to load {model_name} on Triton in {timeout}s")
+
+    def get_model_list(self) -> List[str]:
+        """Get a list of models loaded in the triton server."""
+        res = self._client.get_model_repository_index(as_json=True)
+        return [model["name"] for model in res["models"]]
+
+    def get_model_concurrency(self, model_name: str, timeout: int = 1000) -> int:
+        """Get the model concurrency."""
+        self.load_model(model_name, timeout)
+        instances = self._client.get_model_config(model_name, as_json=True)["config"][
+            "instance_group"
+        ]
+        return sum(instance["count"] * len(instance["gpus"]) for instance in instances)
+
+    def _generate_stop_signals(
+        self,
+    ) -> List[Union["grpcclient.InferInput", "httpclient.InferInput"]]:
+        """Generate the signal to stop the stream."""
+        inputs = [
+            self._infer_input("input_ids", [1, 1], "INT32"),
+            self._infer_input("input_lengths", [1, 1], "INT32"),
+            self._infer_input("request_output_len", [1, 1], "UINT32"),
+            self._infer_input("stop", [1, 1], "BOOL"),
+        ]
+        inputs[0].set_data_from_numpy(np.empty([1, 1], dtype=np.int32))
+        inputs[1].set_data_from_numpy(np.zeros([1, 1], dtype=np.int32))
+        inputs[2].set_data_from_numpy(np.array([[0]], dtype=np.uint32))
+        inputs[3].set_data_from_numpy(np.array([[True]], dtype="bool"))
+        return inputs
+
+    def _generate_outputs(
+        self,
+    ) -> List[
+        Union["grpcclient.InferRequestedOutput", "httpclient.InferRequestedOutput"]
+    ]:
+        """Generate the expected output structure."""
+        return [self._infer_output("text_output")]
+
+    def _prepare_tensor(
+        self, name: str, input_data: Any
+    ) -> Union["grpcclient.InferInput", "httpclient.InferInput"]:
+        """Prepare an input data structure."""
+        from tritonclient.utils import np_to_triton_dtype
+
+        t = self._infer_input(
+            name, input_data.shape, np_to_triton_dtype(input_data.dtype)
+        )
+        t.set_data_from_numpy(input_data)
+        return t
+
+    def _generate_inputs(  # pylint: disable=too-many-arguments,too-many-locals
+        self,
+        prompt: str,
+        tokens: int = 300,
+        temperature: float = 1.0,
+        top_k: float = 1,
+        top_p: float = 0,
+        beam_width: int = 1,
+        repetition_penalty: float = 1,
+        length_penalty: float = 1.0,
+        stream: bool = True,
+    ) -> List[Union["grpcclient.InferInput", "httpclient.InferInput"]]:
+        """Create the input for the triton inference server."""
+        query = np.array(prompt).astype(object)
+        request_output_len = np.array([tokens]).astype(np.uint32).reshape((1, -1))
+        runtime_top_k = np.array([top_k]).astype(np.uint32).reshape((1, -1))
+        runtime_top_p = np.array([top_p]).astype(np.float32).reshape((1, -1))
+        temperature_array = np.array([temperature]).astype(np.float32).reshape((1, -1))
+        len_penalty = np.array([length_penalty]).astype(np.float32).reshape((1, -1))
+        repetition_penalty_array = (
+            np.array([repetition_penalty]).astype(np.float32).reshape((1, -1))
+        )
+        random_seed = np.array([RANDOM_SEED]).astype(np.uint64).reshape((1, -1))
+        beam_width_array = np.array([beam_width]).astype(np.uint32).reshape((1, -1))
+        streaming_data = np.array([[stream]], dtype=bool)
+
+        return [
+            self._prepare_tensor("text_input", query),
+            self._prepare_tensor("max_tokens", request_output_len),
+            self._prepare_tensor("top_k", runtime_top_k),
+            self._prepare_tensor("top_p", runtime_top_p),
+            self._prepare_tensor("temperature", temperature_array),
+            self._prepare_tensor("length_penalty", len_penalty),
+            self._prepare_tensor("repetition_penalty", repetition_penalty_array),
+            self._prepare_tensor("random_seed", random_seed),
+            self._prepare_tensor("beam_width", beam_width_array),
+            self._prepare_tensor("stream", streaming_data),
+        ]
+
+    def _trim_batch_response(self, result_str: str) -> str:
+        """Trim the resulting response from a batch request by removing provided prompt and extra generated text."""
+        # extract the generated part of the prompt
+        split = result_str.split("[/INST]", 1)
+        generated = split[-1]
+        end_token = generated.find("</s>")
+        if end_token == -1:
+            return generated
+        return generated[:end_token].strip()
+
+
+class GrpcTritonClient(_BaseTritonClient):
+    """GRPC connection to a triton inference server."""
+
+    @property
+    def _inference_server_client(
+        self,
+    ) -> Type["grpcclient.InferenceServerClient"]:
+        """Return the preferred InferenceServerClient class."""
+        import tritonclient.grpc as grpcclient
+
+        return grpcclient.InferenceServerClient  # type: ignore
+
+    @property
+    def _infer_input(self) -> Type["grpcclient.InferInput"]:
+        """Return the preferred InferInput."""
+        import tritonclient.grpc as grpcclient
+
+        return grpcclient.InferInput  # type: ignore
+
+    @property
+    def _infer_output(
+        self,
+    ) -> Type["grpcclient.InferRequestedOutput"]:
+        """Return the preferred InferRequestedOutput."""
+        import tritonclient.grpc as grpcclient
+
+        return grpcclient.InferRequestedOutput  # type: ignore
+
+    def _send_stop_signals(self, model_name: str, request_id: str) -> None:
+        """Send the stop signal to the Triton Inference server."""
+        stop_inputs = self._generate_stop_signals()
+        self._client.async_stream_infer(
+            model_name,
+            stop_inputs,
+            request_id=request_id,
+            parameters={"Streaming": True},
+        )
+
+    @staticmethod
+    def _process_result(result: Dict[str, str]) -> str:
+        """Post-process the result from the server."""
+        import google.protobuf.json_format
+        import tritonclient.grpc as grpcclient
+        from tritonclient.grpc.service_pb2 import ModelInferResponse
+
+        message = ModelInferResponse()
+        generated_text: str = ""
+        google.protobuf.json_format.Parse(json.dumps(result), message)
+        infer_result = grpcclient.InferResult(message)
+        np_res = infer_result.as_numpy("text_output")
+
+        generated_text = ""
+        if np_res is not None:
+            generated_text = "".join([token.decode() for token in np_res])
+
+        return generated_text
+
+    def _stream_callback(
+        self,
+        result_queue: Queue,
+        force_batch: bool,
+        result: Any,
+        error: str,
+    ) -> None:
+        """Add streamed result to queue."""
+        if error:
+            result_queue.put(error)
+        else:
+            response_raw = result.get_response(as_json=True)
+            if "outputs" in response_raw:
+                # the very last response might have no output, just the final flag
+                response = self._process_result(response_raw)
+                if force_batch:
+                    response = self._trim_batch_response(response)
+
+                if response in STOP_WORDS:
+                    result_queue.put(None)
+                else:
+                    result_queue.put(response)
+
+            if response_raw["parameters"]["triton_final_response"]["bool_param"]:
+                # end of the generation
+                result_queue.put(None)
+
+    # pylint: disable-next=too-many-arguments
+    def _send_prompt_streaming(
+        self,
+        model_name: str,
+        request_inputs: Any,
+        request_outputs: Optional[Any],
+        request_id: str,
+        result_queue: StreamingResponseGenerator,
+        force_batch: bool = False,
+    ) -> None:
+        """Send the prompt and start streaming the result."""
+        self._client.start_stream(
+            callback=partial(self._stream_callback, result_queue, force_batch)
+        )
+        self._client.async_stream_infer(
+            model_name=model_name,
+            inputs=request_inputs,
+            outputs=request_outputs,
+            request_id=request_id,
+        )
+
+    def request_streaming(
+        self,
+        model_name: str,
+        request_id: Optional[str] = None,
+        force_batch: bool = False,
+        **params: Any,
+    ) -> StreamingResponseGenerator:
+        """Request a streaming connection."""
+        if not self._client.is_model_ready(model_name):
+            raise RuntimeError("Cannot request streaming, model is not loaded")
+
+        if not request_id:
+            request_id = str(random.randint(1, 9999999))  # nosec
+
+        result_queue = StreamingResponseGenerator(self, request_id, force_batch)
+        inputs = self._generate_inputs(stream=not force_batch, **params)
+        outputs = self._generate_outputs()
+        self._send_prompt_streaming(
+            model_name,
+            inputs,
+            outputs,
+            request_id,
+            result_queue,
+            force_batch,
+        )
+        return result_queue
+
+    def stop_stream(
+        self, model_name: str, request_id: str, signal: bool = True
+    ) -> None:
+        """Close the streaming connection."""
+        if signal:
+            self._send_stop_signals(model_name, request_id)
+        self._client.stop_stream()
diff --git a/poetry.lock b/poetry.lock
index 54b979c328c63c7f03cce94e35efc631082c2339..d148b69822800e43e02dd8aecfd2c7f16d21ea02 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
 
 [[package]]
 name = "accelerate"
@@ -2572,13 +2572,13 @@ files = [
 
 [[package]]
 name = "langchain"
-version = "0.0.351"
+version = "0.0.352"
 description = "Building applications with LLMs through composability"
 optional = true
 python-versions = ">=3.8.1,<4.0"
 files = [
-    {file = "langchain-0.0.351-py3-none-any.whl", hash = "sha256:90cdaee27db2b2aeeb7b0709a79cbfe3e858fc9536b6bc3ea262135a6affc70f"},
-    {file = "langchain-0.0.351.tar.gz", hash = "sha256:6bf2a8665a7a3ca2bbd4eea9889ecfd3d39ab23a505549a03860272474399b38"},
+    {file = "langchain-0.0.352-py3-none-any.whl", hash = "sha256:43ab580e1223e5d7c3495b3c0cb79e2f3a0ecb52caf8126271fb10d42cede2d0"},
+    {file = "langchain-0.0.352.tar.gz", hash = "sha256:8928d7b63d73af9681fe1b2a2b99b84238efef61ed537de666160fd001f41efd"},
 ]
 
 [package.dependencies]
@@ -2612,13 +2612,13 @@ text-helpers = ["chardet (>=5.1.0,<6.0.0)"]
 
 [[package]]
 name = "langchain-community"
-version = "0.0.5"
+version = "0.0.6"
 description = "Community contributed LangChain integrations."
 optional = true
 python-versions = ">=3.8.1,<4.0"
 files = [
-    {file = "langchain_community-0.0.5-py3-none-any.whl", hash = "sha256:7579ff28b3bbaa73dd17ee5e88b84d09c785691d2af2f00f0bb98e7478072af6"},
-    {file = "langchain_community-0.0.5.tar.gz", hash = "sha256:425953df8035b6d278fa724a5d1a33b95ced3787ffff5b9128b3c16f0474335e"},
+    {file = "langchain_community-0.0.6-py3-none-any.whl", hash = "sha256:13b16da0f89c328df456911ff03069e4d919f647c7dd3bfc5062525cf956ed82"},
+    {file = "langchain_community-0.0.6.tar.gz", hash = "sha256:b7deb63fd8205d54b51cf8b1702de15d1da77987f8465c356b158a65adff378c"},
 ]
 
 [package.dependencies]
@@ -2634,17 +2634,17 @@ tenacity = ">=8.1.0,<9.0.0"
 
 [package.extras]
 cli = ["typer (>=0.9.0,<0.10.0)"]
-extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cohere (>=4,<5)", "dashvector (>=1.0.1,<2.0.0)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "gradientai (>=1.4.0,<2.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.2,<5.0.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)"]
+extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cohere (>=4,<5)", "dashvector (>=1.0.1,<2.0.0)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "gradientai (>=1.4.0,<2.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.2,<5.0.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "oracle-ads (>=2.9.1,<3.0.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)"]
 
 [[package]]
 name = "langchain-core"
-version = "0.1.2"
+version = "0.1.3"
 description = "Building applications with LLMs through composability"
 optional = true
 python-versions = ">=3.8.1,<4.0"
 files = [
-    {file = "langchain_core-0.1.2-py3-none-any.whl", hash = "sha256:823de99910081be46b127ae4fd7066acea82b8ac48742fb34e6d3c7f5d1a03ce"},
-    {file = "langchain_core-0.1.2.tar.gz", hash = "sha256:6fd641ca776974d0adeb4aa390ebb173138d75e93a1c6928bb05b21f5e81cb1f"},
+    {file = "langchain_core-0.1.3-py3-none-any.whl", hash = "sha256:bfbbc5dfeb06cfe3fd078e7a12db3a4cfb9d28b715b200a64f7abb7ae1976b17"},
+    {file = "langchain_core-0.1.3.tar.gz", hash = "sha256:d8898254dfea1c4ab614f470db40909969604775f7524175f6d9167ea58050c9"},
 ]
 
 [package.dependencies]
@@ -2676,13 +2676,13 @@ data = ["language-data (>=1.1,<2.0)"]
 
 [[package]]
 name = "langsmith"
-version = "0.0.72"
+version = "0.0.73"
 description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
 optional = true
 python-versions = ">=3.8.1,<4.0"
 files = [
-    {file = "langsmith-0.0.72-py3-none-any.whl", hash = "sha256:2cddd49cd7d1477409c8785746acf42dbd6709a7d36e751247a3cab5e3eee20e"},
-    {file = "langsmith-0.0.72.tar.gz", hash = "sha256:505f517e2e67836a4e831917d8223a18cc59d51bdae1e4295fc6dff2266bab5d"},
+    {file = "langsmith-0.0.73-py3-none-any.whl", hash = "sha256:cb99cc10a70d882a72cc2884d75b95744bd931ee24b0466c28b522a91354c566"},
+    {file = "langsmith-0.0.73.tar.gz", hash = "sha256:f21984aa8bb0a7749f355d3fc95e3f8bd663fb90f465b1692c18420d89241a0f"},
 ]
 
 [package.dependencies]
@@ -2802,6 +2802,7 @@ files = [
     {file = "lxml-4.9.4-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e8f9f93a23634cfafbad6e46ad7d09e0f4a25a2400e4a64b1b7b7c0fbaa06d9d"},
     {file = "lxml-4.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3f3f00a9061605725df1816f5713d10cd94636347ed651abdbc75828df302b20"},
     {file = "lxml-4.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:953dd5481bd6252bd480d6ec431f61d7d87fdcbbb71b0d2bdcfc6ae00bb6fb10"},
+    {file = "lxml-4.9.4-cp312-cp312-win32.whl", hash = "sha256:266f655d1baff9c47b52f529b5f6bec33f66042f65f7c56adde3fcf2ed62ae8b"},
     {file = "lxml-4.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:f1faee2a831fe249e1bae9cbc68d3cd8a30f7e37851deee4d7962b17c410dd56"},
     {file = "lxml-4.9.4-cp35-cp35m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:23d891e5bdc12e2e506e7d225d6aa929e0a0368c9916c1fddefab88166e98b20"},
     {file = "lxml-4.9.4-cp35-cp35m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:e96a1788f24d03e8d61679f9881a883ecdf9c445a38f9ae3f3f193ab6c591c66"},
@@ -4612,7 +4613,6 @@ files = [
     {file = "pymongo-4.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8729dbf25eb32ad0dc0b9bd5e6a0d0b7e5c2dc8ec06ad171088e1896b522a74"},
     {file = "pymongo-4.6.1-cp312-cp312-win32.whl", hash = "sha256:3177f783ae7e08aaf7b2802e0df4e4b13903520e8380915e6337cdc7a6ff01d8"},
     {file = "pymongo-4.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:00c199e1c593e2c8b033136d7a08f0c376452bac8a896c923fcd6f419e07bdd2"},
-    {file = "pymongo-4.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6dcc95f4bb9ed793714b43f4f23a7b0c57e4ef47414162297d6f650213512c19"},
     {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:13552ca505366df74e3e2f0a4f27c363928f3dff0eef9f281eb81af7f29bc3c5"},
     {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:77e0df59b1a4994ad30c6d746992ae887f9756a43fc25dec2db515d94cf0222d"},
     {file = "pymongo-4.6.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:3a7f02a58a0c2912734105e05dedbee4f7507e6f1bd132ebad520be0b11d46fd"},