diff --git a/.gitignore b/.gitignore index 2313f2f35e95acdfcda7fa166e82e621c415780f..990c18de229088f55c6c514fd0f2d49981d1b0e7 100644 --- a/.gitignore +++ b/.gitignore @@ -142,6 +142,9 @@ dmypy.json modules/ *.swp +# VSCode +.vscode + # pipenv Pipfile Pipfile.lock diff --git a/docs/examples/llm/nvidia_triton.ipynb b/docs/examples/llm/nvidia_triton.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..f0a8202da0426ed77ac7568eb877b63b56aa2f82 --- /dev/null +++ b/docs/examples/llm/nvidia_triton.ipynb @@ -0,0 +1,136 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<a href=\"https://colab.research.google.com/github/jerryjliu/llama_index/blob/main/docs/examples/llm/nvidia_triton.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Nvidia Triton" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "Nvidia's Triton is an inference server that provides API access to hosted LLM models. This connector allows llama_index to remotely interact with a Triton inference server over gRPC to accelerate inference operations.\n", + "\n", + "[Triton Inference Server Github](https://github.com/triton-inference-server/server)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Install tritonclient\n", + "Since we are interacting with the Triton inference server, we will need to install the `tritonclient` package.\n", + "\n", + "`tritonclient` can be easily installed using `pip3 install tritonclient`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip3 install tritonclient" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basic Usage" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Call `complete` with a prompt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.llms import NvidiaTriton\n", + "\n", + "# A Triton server instance must be running. 
Use the correct URL for your desired Triton server instance.\n", + "triton_url = \"localhost:8001\"\n", + "resp = NvidiaTriton(server_url=triton_url).complete(\"The tallest mountain in North America is \")\n", + "print(resp)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Call `chat` with a list of messages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.llms import ChatMessage, NvidiaTriton\n", + "\n", + "messages = [\n", + " ChatMessage(\n", + " role=\"system\",\n", + " content=\"You are a clown named bozo that has had a rough day at the circus\",\n", + " ),\n", + " ChatMessage(role=\"user\", content=\"What has you down bozo?\"),\n", + "]\n", + "resp = NvidiaTriton().chat(messages)\n", + "print(resp)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Further Examples\n", + "Remember that a Triton instance represents a running server, so ensure you have a valid server configuration running and change `localhost:8001` to the correct IP/hostname:port combination for your server.\n", + "\n", + "An example of setting up this environment can be found at Nvidia's [GenerativeAIExamples Github Repo](https://github.com/NVIDIA/GenerativeAIExamples/tree/main/RetrievalAugmentedGeneration)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + }, + "vscode": { + "interpreter": { + "hash": "a0a0263b650d907a3bfe41c0f8d6a63a071b884df3cfdc1579f00cdc1aed6b03" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/module_guides/models/llms/modules.md b/docs/module_guides/models/llms/modules.md index 504ece86ba7d65bd48770700e70b6c1c19064345..05b3aaa92d24b6203b5dea9b9bf1c68f0410f134 100644 --- a/docs/module_guides/models/llms/modules.md +++ b/docs/module_guides/models/llms/modules.md @@ -149,6 +149,15 @@ maxdepth: 1 /examples/llm/monsterapi.ipynb ``` +## Nvidia Triton + +```{toctree} +--- +maxdepth: 1 +--- +/examples/llm/nvidia_triton.ipynb +``` + ## Ollama ```{toctree} diff --git a/llama_index/llms/__init__.py b/llama_index/llms/__init__.py index 95c9ea25a7c534e9aad19369aee225adf338559e..bf4d304e502fc10ace28c9aae22ebe9d6ad042da 100644 --- a/llama_index/llms/__init__.py +++ b/llama_index/llms/__init__.py @@ -19,6 +19,7 @@ from llama_index.llms.localai import LOCALAI_DEFAULTS, LocalAI from llama_index.llms.mistral import MistralAI from llama_index.llms.mock import MockLLM from llama_index.llms.monsterapi import MonsterLLM +from llama_index.llms.nvidia_triton import NvidiaTriton from llama_index.llms.ollama import Ollama from llama_index.llms.openai import OpenAI from llama_index.llms.openai_like import OpenAILike @@ -78,6 +79,7 @@ __all__ = [ "MessageRole", "MockLLM", "MonsterLLM", + "NvidiaTriton", "MistralAI", "Ollama", "OpenAI", diff --git a/llama_index/llms/nvidia_triton.py b/llama_index/llms/nvidia_triton.py new file mode 100644 index 0000000000000000000000000000000000000000..468892d520115639dc2112aac448fecd7384046a --- /dev/null +++ b/llama_index/llms/nvidia_triton.py @@ -0,0 +1,242 @@ +import random +from typing import ( + Any, + Dict, + Optional, + Sequence, +) + +from llama_index.bridge.pydantic import Field, PrivateAttr
+from llama_index.callbacks import CallbackManager +from llama_index.llms.base import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, + llm_chat_callback, +) +from llama_index.llms.generic_utils import ( + completion_to_chat_decorator, +) +from llama_index.llms.llm import LLM +from llama_index.llms.nvidia_triton_utils import GrpcTritonClient + +DEFAULT_SERVER_URL = "localhost:8001" +DEFAULT_MAX_RETRIES = 3 +DEFAULT_TIMEOUT = 60.0 +DEFAULT_MODEL = "ensemble" +DEFAULT_TEMPERATURE = 1.0 +DEFAULT_TOP_P = 0 +DEFAULT_TOP_K = 1.0 +DEFAULT_MAX_TOKENS = 100 +DEFAULT_BEAM_WIDTH = 1 +DEFAULT_REPETITION_PENALTY = 1.0 +DEFAULT_LENGTH_PENALTY = 1.0 +DEFAULT_REUSE_CLIENT = True +DEFAULT_TRITON_LOAD_MODEL = True + + +class NvidiaTriton(LLM): + server_url: str = Field( + default=DEFAULT_SERVER_URL, + description="The URL of the Triton inference server to use.", + ) + model_name: str = Field( + default=DEFAULT_MODEL, + description="The name of the Triton hosted model this client should use", + ) + temperature: Optional[float] = Field( + default=DEFAULT_TEMPERATURE, description="Temperature to use for sampling" + ) + top_p: Optional[float] = Field( + default=DEFAULT_TOP_P, description="The top-p value to use for sampling" + ) + top_k: Optional[float] = Field( + default=DEFAULT_TOP_K, description="The top k value to use for sampling" + ) + tokens: Optional[int] = Field( + default=DEFAULT_MAX_TOKENS, + description="The maximum number of tokens to generate.", + ) + beam_width: Optional[int] = Field( + default=DEFAULT_BEAM_WIDTH, description="The beam width to use for beam search" + ) + repetition_penalty: Optional[float] = Field( + default=DEFAULT_REPETITION_PENALTY, + description="The penalty to apply to repeated tokens", + ) + length_penalty: Optional[float] = Field( + default=DEFAULT_LENGTH_PENALTY, + description="The penalty to apply to the length of the generated output", + ) + max_retries: Optional[int] = Field( + default=DEFAULT_MAX_RETRIES, + description="Maximum number of attempts to retry Triton client invocation before erroring", + ) + timeout: Optional[float] = Field( + default=DEFAULT_TIMEOUT, + description="Maximum time (seconds) allowed for a Triton client call before erroring", + ) + reuse_client: Optional[bool] = Field( + default=DEFAULT_REUSE_CLIENT, + description="True for reusing the same client instance between invocations", + ) + triton_load_model_call: Optional[bool] = Field( + default=DEFAULT_TRITON_LOAD_MODEL, + description="True if a Triton load model API call should be made before using the client", + ) + + _client: Optional[GrpcTritonClient] = PrivateAttr() + + def __init__( + self, + server_url: str = DEFAULT_SERVER_URL, + model: str = DEFAULT_MODEL, + temperature: float = DEFAULT_TEMPERATURE, + top_p: float = DEFAULT_TOP_P, + top_k: float = DEFAULT_TOP_K, + tokens: Optional[int] = DEFAULT_MAX_TOKENS, + beam_width: int = DEFAULT_BEAM_WIDTH, + repetition_penalty: float = DEFAULT_REPETITION_PENALTY, + length_penalty: float = DEFAULT_LENGTH_PENALTY, + max_retries: int = DEFAULT_MAX_RETRIES, + timeout: float = DEFAULT_TIMEOUT, + reuse_client: bool = DEFAULT_REUSE_CLIENT, + triton_load_model_call: bool = DEFAULT_TRITON_LOAD_MODEL, + callback_manager: Optional[CallbackManager] = None, + additional_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + additional_kwargs = additional_kwargs or {} + + super().__init__( + server_url=server_url, + model=model, + temperature=temperature, + 
top_p=top_p, + top_k=top_k, + tokens=tokens, + beam_width=beam_width, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + max_retries=max_retries, + timeout=timeout, + reuse_client=reuse_client, + triton_load_model_call=triton_load_model_call, + callback_manager=callback_manager, + additional_kwargs=additional_kwargs, + **kwargs, + ) + + try: + self._client = GrpcTritonClient(server_url) + except ImportError as err: + raise ImportError( + "Could not import triton client python package. " + "Please install it with `pip install tritonclient`." + ) from err + + @property + def _get_model_default_parameters(self) -> Dict[str, Any]: + return { + "tokens": self.tokens, + "top_k": self.top_k, + "top_p": self.top_p, + "temperature": self.temperature, + "repetition_penalty": self.repetition_penalty, + "length_penalty": self.length_penalty, + "beam_width": self.beam_width, + } + + @property + def _invocation_params(self, **kwargs: Any) -> Dict[str, Any]: + return {**self._get_model_default_parameters, **kwargs} + + @property + def _identifying_params(self) -> Dict[str, Any]: + """Get all the identifying parameters.""" + return { + "server_url": self.server_url, + "model_name": self.model_name, + } + + def _get_client(self) -> Any: + """Create or reuse a Triton client connection.""" + if not self.reuse_client: + return GrpcTritonClient(self.server_url) + + if self._client is None: + self._client = GrpcTritonClient(self.server_url) + return self._client + + @property + def metadata(self) -> LLMMetadata: + """Gather and return metadata about the user Triton configured LLM model.""" + return LLMMetadata( + model_name=self.model_name, + ) + + @llm_chat_callback() + def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: + chat_fn = completion_to_chat_decorator(self.complete) + return chat_fn(messages, **kwargs) + + def stream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseGen: + raise NotImplementedError + + def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: + from tritonclient.utils import InferenceServerException + + client = self._get_client() + + invocation_params = self._get_model_default_parameters + invocation_params.update(kwargs) + invocation_params["prompt"] = [[prompt]] + model_params = self._identifying_params + model_params.update(kwargs) + request_id = str(random.randint(1, 9999999)) # nosec + + if self.triton_load_model_call: + client.load_model(model_params["model_name"]) + + result_queue = client.request_streaming( + model_params["model_name"], request_id, **invocation_params + ) + + response = "" + for token in result_queue: + if isinstance(token, InferenceServerException): + client.stop_stream(model_params["model_name"], request_id) + raise token + response = response + token + + return CompletionResponse( + text=response, + ) + + def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen: + raise NotImplementedError + + async def achat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponse: + raise NotImplementedError + + async def acomplete(self, prompt: str, **kwargs: Any) -> CompletionResponse: + raise NotImplementedError + + async def astream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseAsyncGen: + raise NotImplementedError + + async def astream_complete( + self, prompt: str, **kwargs: Any + ) -> CompletionResponseAsyncGen: + raise NotImplementedError diff --git a/llama_index/llms/nvidia_triton_utils.py 
b/llama_index/llms/nvidia_triton_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1452ff5fc667eefa149f84e17607d6f99d76f98b --- /dev/null +++ b/llama_index/llms/nvidia_triton_utils.py @@ -0,0 +1,343 @@ +import abc +import json +import random +import time +from functools import partial +from queue import Queue +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Type, + Union, +) + +import numpy as np + +if TYPE_CHECKING: + import tritonclient.grpc as grpcclient + import tritonclient.http as httpclient + +STOP_WORDS = ["</s>"] +RANDOM_SEED = 0 + + +class StreamingResponseGenerator(Queue): + """A Generator that provides the inference results from an LLM.""" + + def __init__( + self, client: "GrpcTritonClient", request_id: str, force_batch: bool + ) -> None: + """Instantiate the generator class.""" + super().__init__() + self._client = client + self.request_id = request_id + self._batch = force_batch + + def __iter__(self) -> "StreamingResponseGenerator": + """Return self as a generator.""" + return self + + def __next__(self) -> str: + """Return the next retrieved token.""" + val = self.get() + if val is None or val in STOP_WORDS: + self._stop_stream() + raise StopIteration + return val + + def _stop_stream(self) -> None: + """Drain and shutdown the Triton stream.""" + self._client.stop_stream( + "tensorrt_llm", self.request_id, signal=not self._batch + ) + + +class _BaseTritonClient(abc.ABC): + """An abstraction of the connection to a triton inference server.""" + + def __init__(self, server_url: str) -> None: + """Initialize the client.""" + self._server_url = server_url + self._client = self._inference_server_client(server_url) + + @property + @abc.abstractmethod + def _inference_server_client( + self, + ) -> Union[ + Type["grpcclient.InferenceServerClient"], + Type["httpclient.InferenceServerClient"], + ]: + """Return the preferred InferenceServerClient class.""" + + @property + @abc.abstractmethod + def _infer_input( + self, + ) -> Union[Type["grpcclient.InferInput"], Type["httpclient.InferInput"]]: + """Return the preferred InferInput.""" + + @property + @abc.abstractmethod + def _infer_output( + self, + ) -> Union[ + Type["grpcclient.InferRequestedOutput"], Type["httpclient.InferRequestedOutput"] + ]: + """Return the preferred InferRequestedOutput.""" + + def load_model(self, model_name: str, timeout: int = 1000) -> None: + """Load a model into the server.""" + if self._client.is_model_ready(model_name): + return + + self._client.load_model(model_name) + t0 = time.perf_counter() + t1 = t0 + while not self._client.is_model_ready(model_name) and t1 - t0 < timeout: + t1 = time.perf_counter() + + if not self._client.is_model_ready(model_name): + raise RuntimeError(f"Failed to load {model_name} on Triton in {timeout}s") + + def get_model_list(self) -> List[str]: + """Get a list of models loaded in the triton server.""" + res = self._client.get_model_repository_index(as_json=True) + return [model["name"] for model in res["models"]] + + def get_model_concurrency(self, model_name: str, timeout: int = 1000) -> int: + """Get the model concurrency.""" + self.load_model(model_name, timeout) + instances = self._client.get_model_config(model_name, as_json=True)["config"][ + "instance_group" + ] + return sum(instance["count"] * len(instance["gpus"]) for instance in instances) + + def _generate_stop_signals( + self, + ) -> List[Union["grpcclient.InferInput", "httpclient.InferInput"]]: + """Generate the signal to stop the stream.""" + inputs = [ + 
self._infer_input("input_ids", [1, 1], "INT32"), + self._infer_input("input_lengths", [1, 1], "INT32"), + self._infer_input("request_output_len", [1, 1], "UINT32"), + self._infer_input("stop", [1, 1], "BOOL"), + ] + inputs[0].set_data_from_numpy(np.empty([1, 1], dtype=np.int32)) + inputs[1].set_data_from_numpy(np.zeros([1, 1], dtype=np.int32)) + inputs[2].set_data_from_numpy(np.array([[0]], dtype=np.uint32)) + inputs[3].set_data_from_numpy(np.array([[True]], dtype="bool")) + return inputs + + def _generate_outputs( + self, + ) -> List[ + Union["grpcclient.InferRequestedOutput", "httpclient.InferRequestedOutput"] + ]: + """Generate the expected output structure.""" + return [self._infer_output("text_output")] + + def _prepare_tensor( + self, name: str, input_data: Any + ) -> Union["grpcclient.InferInput", "httpclient.InferInput"]: + """Prepare an input data structure.""" + from tritonclient.utils import np_to_triton_dtype + + t = self._infer_input( + name, input_data.shape, np_to_triton_dtype(input_data.dtype) + ) + t.set_data_from_numpy(input_data) + return t + + def _generate_inputs( # pylint: disable=too-many-arguments,too-many-locals + self, + prompt: str, + tokens: int = 300, + temperature: float = 1.0, + top_k: float = 1, + top_p: float = 0, + beam_width: int = 1, + repetition_penalty: float = 1, + length_penalty: float = 1.0, + stream: bool = True, + ) -> List[Union["grpcclient.InferInput", "httpclient.InferInput"]]: + """Create the input for the triton inference server.""" + query = np.array(prompt).astype(object) + request_output_len = np.array([tokens]).astype(np.uint32).reshape((1, -1)) + runtime_top_k = np.array([top_k]).astype(np.uint32).reshape((1, -1)) + runtime_top_p = np.array([top_p]).astype(np.float32).reshape((1, -1)) + temperature_array = np.array([temperature]).astype(np.float32).reshape((1, -1)) + len_penalty = np.array([length_penalty]).astype(np.float32).reshape((1, -1)) + repetition_penalty_array = ( + np.array([repetition_penalty]).astype(np.float32).reshape((1, -1)) + ) + random_seed = np.array([RANDOM_SEED]).astype(np.uint64).reshape((1, -1)) + beam_width_array = np.array([beam_width]).astype(np.uint32).reshape((1, -1)) + streaming_data = np.array([[stream]], dtype=bool) + + return [ + self._prepare_tensor("text_input", query), + self._prepare_tensor("max_tokens", request_output_len), + self._prepare_tensor("top_k", runtime_top_k), + self._prepare_tensor("top_p", runtime_top_p), + self._prepare_tensor("temperature", temperature_array), + self._prepare_tensor("length_penalty", len_penalty), + self._prepare_tensor("repetition_penalty", repetition_penalty_array), + self._prepare_tensor("random_seed", random_seed), + self._prepare_tensor("beam_width", beam_width_array), + self._prepare_tensor("stream", streaming_data), + ] + + def _trim_batch_response(self, result_str: str) -> str: + """Trim the resulting response from a batch request by removing provided prompt and extra generated text.""" + # extract the generated part of the prompt + split = result_str.split("[/INST]", 1) + generated = split[-1] + end_token = generated.find("</s>") + if end_token == -1: + return generated + return generated[:end_token].strip() + + +class GrpcTritonClient(_BaseTritonClient): + """GRPC connection to a triton inference server.""" + + @property + def _inference_server_client( + self, + ) -> Type["grpcclient.InferenceServerClient"]: + """Return the preferred InferenceServerClient class.""" + import tritonclient.grpc as grpcclient + + return grpcclient.InferenceServerClient # type: 
ignore + + @property + def _infer_input(self) -> Type["grpcclient.InferInput"]: + """Return the preferred InferInput.""" + import tritonclient.grpc as grpcclient + + return grpcclient.InferInput # type: ignore + + @property + def _infer_output( + self, + ) -> Type["grpcclient.InferRequestedOutput"]: + """Return the preferred InferRequestedOutput.""" + import tritonclient.grpc as grpcclient + + return grpcclient.InferRequestedOutput # type: ignore + + def _send_stop_signals(self, model_name: str, request_id: str) -> None: + """Send the stop signal to the Triton Inference server.""" + stop_inputs = self._generate_stop_signals() + self._client.async_stream_infer( + model_name, + stop_inputs, + request_id=request_id, + parameters={"Streaming": True}, + ) + + @staticmethod + def _process_result(result: Dict[str, str]) -> str: + """Post-process the result from the server.""" + import google.protobuf.json_format + import tritonclient.grpc as grpcclient + from tritonclient.grpc.service_pb2 import ModelInferResponse + + message = ModelInferResponse() + generated_text: str = "" + google.protobuf.json_format.Parse(json.dumps(result), message) + infer_result = grpcclient.InferResult(message) + np_res = infer_result.as_numpy("text_output") + + generated_text = "" + if np_res is not None: + generated_text = "".join([token.decode() for token in np_res]) + + return generated_text + + def _stream_callback( + self, + result_queue: Queue, + force_batch: bool, + result: Any, + error: str, + ) -> None: + """Add streamed result to queue.""" + if error: + result_queue.put(error) + else: + response_raw = result.get_response(as_json=True) + if "outputs" in response_raw: + # the very last response might have no output, just the final flag + response = self._process_result(response_raw) + if force_batch: + response = self._trim_batch_response(response) + + if response in STOP_WORDS: + result_queue.put(None) + else: + result_queue.put(response) + + if response_raw["parameters"]["triton_final_response"]["bool_param"]: + # end of the generation + result_queue.put(None) + + # pylint: disable-next=too-many-arguments + def _send_prompt_streaming( + self, + model_name: str, + request_inputs: Any, + request_outputs: Optional[Any], + request_id: str, + result_queue: StreamingResponseGenerator, + force_batch: bool = False, + ) -> None: + """Send the prompt and start streaming the result.""" + self._client.start_stream( + callback=partial(self._stream_callback, result_queue, force_batch) + ) + self._client.async_stream_infer( + model_name=model_name, + inputs=request_inputs, + outputs=request_outputs, + request_id=request_id, + ) + + def request_streaming( + self, + model_name: str, + request_id: Optional[str] = None, + force_batch: bool = False, + **params: Any, + ) -> StreamingResponseGenerator: + """Request a streaming connection.""" + if not self._client.is_model_ready(model_name): + raise RuntimeError("Cannot request streaming, model is not loaded") + + if not request_id: + request_id = str(random.randint(1, 9999999)) # nosec + + result_queue = StreamingResponseGenerator(self, request_id, force_batch) + inputs = self._generate_inputs(stream=not force_batch, **params) + outputs = self._generate_outputs() + self._send_prompt_streaming( + model_name, + inputs, + outputs, + request_id, + result_queue, + force_batch, + ) + return result_queue + + def stop_stream( + self, model_name: str, request_id: str, signal: bool = True + ) -> None: + """Close the streaming connection.""" + if signal: + self._send_stop_signals(model_name, 
request_id) + self._client.stop_stream() diff --git a/poetry.lock b/poetry.lock index 54b979c328c63c7f03cce94e35efc631082c2339..d148b69822800e43e02dd8aecfd2c7f16d21ea02 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "accelerate" @@ -2572,13 +2572,13 @@ files = [ [[package]] name = "langchain" -version = "0.0.351" +version = "0.0.352" description = "Building applications with LLMs through composability" optional = true python-versions = ">=3.8.1,<4.0" files = [ - {file = "langchain-0.0.351-py3-none-any.whl", hash = "sha256:90cdaee27db2b2aeeb7b0709a79cbfe3e858fc9536b6bc3ea262135a6affc70f"}, - {file = "langchain-0.0.351.tar.gz", hash = "sha256:6bf2a8665a7a3ca2bbd4eea9889ecfd3d39ab23a505549a03860272474399b38"}, + {file = "langchain-0.0.352-py3-none-any.whl", hash = "sha256:43ab580e1223e5d7c3495b3c0cb79e2f3a0ecb52caf8126271fb10d42cede2d0"}, + {file = "langchain-0.0.352.tar.gz", hash = "sha256:8928d7b63d73af9681fe1b2a2b99b84238efef61ed537de666160fd001f41efd"}, ] [package.dependencies] @@ -2612,13 +2612,13 @@ text-helpers = ["chardet (>=5.1.0,<6.0.0)"] [[package]] name = "langchain-community" -version = "0.0.5" +version = "0.0.6" description = "Community contributed LangChain integrations." optional = true python-versions = ">=3.8.1,<4.0" files = [ - {file = "langchain_community-0.0.5-py3-none-any.whl", hash = "sha256:7579ff28b3bbaa73dd17ee5e88b84d09c785691d2af2f00f0bb98e7478072af6"}, - {file = "langchain_community-0.0.5.tar.gz", hash = "sha256:425953df8035b6d278fa724a5d1a33b95ced3787ffff5b9128b3c16f0474335e"}, + {file = "langchain_community-0.0.6-py3-none-any.whl", hash = "sha256:13b16da0f89c328df456911ff03069e4d919f647c7dd3bfc5062525cf956ed82"}, + {file = "langchain_community-0.0.6.tar.gz", hash = "sha256:b7deb63fd8205d54b51cf8b1702de15d1da77987f8465c356b158a65adff378c"}, ] [package.dependencies] @@ -2634,17 +2634,17 @@ tenacity = ">=8.1.0,<9.0.0" [package.extras] cli = ["typer (>=0.9.0,<0.10.0)"] -extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cohere (>=4,<5)", "dashvector (>=1.0.1,<2.0.0)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "gradientai (>=1.4.0,<2.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.2,<5.0.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf 
(>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)"] +extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cohere (>=4,<5)", "dashvector (>=1.0.1,<2.0.0)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "gradientai (>=1.4.0,<2.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.2,<5.0.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "oracle-ads (>=2.9.1,<3.0.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)"] [[package]] name = "langchain-core" -version = "0.1.2" +version = "0.1.3" description = "Building applications with LLMs through composability" optional = true python-versions = ">=3.8.1,<4.0" files = [ - {file = "langchain_core-0.1.2-py3-none-any.whl", hash = "sha256:823de99910081be46b127ae4fd7066acea82b8ac48742fb34e6d3c7f5d1a03ce"}, - {file = "langchain_core-0.1.2.tar.gz", hash = "sha256:6fd641ca776974d0adeb4aa390ebb173138d75e93a1c6928bb05b21f5e81cb1f"}, + {file = "langchain_core-0.1.3-py3-none-any.whl", hash = "sha256:bfbbc5dfeb06cfe3fd078e7a12db3a4cfb9d28b715b200a64f7abb7ae1976b17"}, + {file = "langchain_core-0.1.3.tar.gz", hash = "sha256:d8898254dfea1c4ab614f470db40909969604775f7524175f6d9167ea58050c9"}, ] [package.dependencies] @@ -2676,13 +2676,13 @@ data = ["language-data (>=1.1,<2.0)"] [[package]] name = "langsmith" -version = "0.0.72" +version = "0.0.73" description = "Client library to connect to the LangSmith LLM Tracing and 
Evaluation Platform." optional = true python-versions = ">=3.8.1,<4.0" files = [ - {file = "langsmith-0.0.72-py3-none-any.whl", hash = "sha256:2cddd49cd7d1477409c8785746acf42dbd6709a7d36e751247a3cab5e3eee20e"}, - {file = "langsmith-0.0.72.tar.gz", hash = "sha256:505f517e2e67836a4e831917d8223a18cc59d51bdae1e4295fc6dff2266bab5d"}, + {file = "langsmith-0.0.73-py3-none-any.whl", hash = "sha256:cb99cc10a70d882a72cc2884d75b95744bd931ee24b0466c28b522a91354c566"}, + {file = "langsmith-0.0.73.tar.gz", hash = "sha256:f21984aa8bb0a7749f355d3fc95e3f8bd663fb90f465b1692c18420d89241a0f"}, ] [package.dependencies] @@ -2802,6 +2802,7 @@ files = [ {file = "lxml-4.9.4-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e8f9f93a23634cfafbad6e46ad7d09e0f4a25a2400e4a64b1b7b7c0fbaa06d9d"}, {file = "lxml-4.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3f3f00a9061605725df1816f5713d10cd94636347ed651abdbc75828df302b20"}, {file = "lxml-4.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:953dd5481bd6252bd480d6ec431f61d7d87fdcbbb71b0d2bdcfc6ae00bb6fb10"}, + {file = "lxml-4.9.4-cp312-cp312-win32.whl", hash = "sha256:266f655d1baff9c47b52f529b5f6bec33f66042f65f7c56adde3fcf2ed62ae8b"}, {file = "lxml-4.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:f1faee2a831fe249e1bae9cbc68d3cd8a30f7e37851deee4d7962b17c410dd56"}, {file = "lxml-4.9.4-cp35-cp35m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:23d891e5bdc12e2e506e7d225d6aa929e0a0368c9916c1fddefab88166e98b20"}, {file = "lxml-4.9.4-cp35-cp35m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:e96a1788f24d03e8d61679f9881a883ecdf9c445a38f9ae3f3f193ab6c591c66"}, @@ -4612,7 +4613,6 @@ files = [ {file = "pymongo-4.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8729dbf25eb32ad0dc0b9bd5e6a0d0b7e5c2dc8ec06ad171088e1896b522a74"}, {file = "pymongo-4.6.1-cp312-cp312-win32.whl", hash = "sha256:3177f783ae7e08aaf7b2802e0df4e4b13903520e8380915e6337cdc7a6ff01d8"}, {file = "pymongo-4.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:00c199e1c593e2c8b033136d7a08f0c376452bac8a896c923fcd6f419e07bdd2"}, - {file = "pymongo-4.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6dcc95f4bb9ed793714b43f4f23a7b0c57e4ef47414162297d6f650213512c19"}, {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:13552ca505366df74e3e2f0a4f27c363928f3dff0eef9f281eb81af7f29bc3c5"}, {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:77e0df59b1a4994ad30c6d746992ae887f9756a43fc25dec2db515d94cf0222d"}, {file = "pymongo-4.6.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:3a7f02a58a0c2912734105e05dedbee4f7507e6f1bd132ebad520be0b11d46fd"},
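Below is a minimal usage sketch for the `NvidiaTriton` connector added in this diff. It assumes a Triton inference server is reachable at the supplied address (the URL and sampling values here are placeholders) and that the server hosts a TensorRT-LLM model under the connector's default name, `ensemble`. Only the synchronous `complete` and `chat` paths are implemented; the streaming and async methods raise `NotImplementedError`.

```python
from llama_index.llms import ChatMessage, NvidiaTriton

# Placeholder address; point this at your own Triton instance.
triton_url = "localhost:8001"

# server_url, temperature, and tokens map onto the fields defined on the
# NvidiaTriton class above; the model name defaults to "ensemble".
llm = NvidiaTriton(
    server_url=triton_url,
    temperature=0.5,
    tokens=128,
)

# Synchronous completion over gRPC against the remote Triton server.
completion = llm.complete("The tallest mountain in North America is ")
print(completion)

# chat() is routed through the completion path via completion_to_chat_decorator.
messages = [ChatMessage(role="user", content="Name three national parks.")]
print(llm.chat(messages))
```

Because `triton_load_model_call` defaults to `True`, the client issues a `load_model` call before the first request, so the target model must already exist in the server's model repository.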