diff --git a/coverage.xml b/coverage.xml index dc2d164687e8d31f642f1823936a6ebe9ae33247..5850def2dfb8bf3ee91ce0e20a28f5312ec8c0b2 100644 --- a/coverage.xml +++ b/coverage.xml @@ -1,12 +1,12 @@ <?xml version="1.0" ?> -<coverage version="7.4.0" timestamp="1704877042238" lines-valid="824" lines-covered="712" line-rate="0.8641" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0"> +<coverage version="7.4.0" timestamp="1704881517474" lines-valid="893" lines-covered="778" line-rate="0.8712" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0"> <!-- Generated by coverage.py: https://coverage.readthedocs.io/en/7.4.0 --> <!-- Based on https://raw.githubusercontent.com/cobertura/web/master/htdocs/xml/coverage-04.dtd --> <sources> <source>/Users/jakit/customers/aurelio/semantic-router/semantic_router</source> </sources> <packages> - <package name="." line-rate="0.9281" branch-rate="0" complexity="0"> + <package name="." line-rate="0.9263" branch-rate="0" complexity="0"> <classes> <class name="__init__.py" filename="__init__.py" complexity="0" line-rate="1" branch-rate="0"> <methods/> @@ -17,7 +17,7 @@ <line number="5" hits="1"/> </lines> </class> - <class name="hybrid_layer.py" filename="hybrid_layer.py" complexity="0" line-rate="1" branch-rate="0"> + <class name="hybrid_layer.py" filename="hybrid_layer.py" complexity="0" line-rate="0.9796" branch-rate="0"> <methods/> <lines> <line number="1" hits="1"/> @@ -32,92 +32,95 @@ <line number="16" hits="1"/> <line number="17" hits="1"/> <line number="19" hits="1"/> - <line number="22" hits="1"/> - <line number="23" hits="1"/> - <line number="24" hits="1"/> - <line number="25" hits="1"/> + <line number="26" hits="1"/> <line number="27" hits="1"/> - <line number="31" hits="1"/> + <line number="29" hits="1"/> + <line number="30" hits="0"/> + <line number="31" hits="0"/> <line number="33" hits="1"/> - <line number="34" hits="1"/> <line number="35" hits="1"/> - <line number="36" hits="1"/> <line 
number="37" hits="1"/> - <line number="38" hits="1"/> - <line number="40" hits="1"/> - <line number="42" hits="1"/> + <line number="41" hits="1"/> <line number="43" hits="1"/> + <line number="44" hits="1"/> <line number="45" hits="1"/> + <line number="46" hits="1"/> <line number="47" hits="1"/> <line number="48" hits="1"/> + <line number="50" hits="1"/> + <line number="52" hits="1"/> <line number="53" hits="1"/> - <line number="54" hits="1"/> <line number="55" hits="1"/> <line number="57" hits="1"/> <line number="58" hits="1"/> - <line number="59" hits="1"/> <line number="63" hits="1"/> <line number="64" hits="1"/> - <line number="66" hits="1"/> + <line number="65" hits="1"/> + <line number="67" hits="1"/> <line number="68" hits="1"/> <line number="69" hits="1"/> - <line number="71" hits="1"/> <line number="73" hits="1"/> - <line number="75" hits="1"/> + <line number="74" hits="1"/> <line number="76" hits="1"/> + <line number="78" hits="1"/> <line number="79" hits="1"/> - <line number="80" hits="1"/> + <line number="81" hits="1"/> <line number="83" hits="1"/> - <line number="84" hits="1"/> <line number="85" hits="1"/> - <line number="92" hits="1"/> - <line number="99" hits="1"/> - <line number="105" hits="1"/> - <line number="110" hits="1"/> - <line number="111" hits="1"/> - <line number="113" hits="1"/> - <line number="114" hits="1"/> - <line number="116" hits="1"/> - <line number="118" hits="1"/> + <line number="86" hits="1"/> + <line number="89" hits="1"/> + <line number="90" hits="1"/> + <line number="93" hits="1"/> + <line number="94" hits="1"/> + <line number="95" hits="1"/> + <line number="102" hits="1"/> + <line number="109" hits="1"/> + <line number="115" hits="1"/> <line number="120" hits="1"/> <line number="121" hits="1"/> - <line number="122" hits="1"/> + <line number="123" hits="1"/> <line number="124" hits="1"/> - <line number="125" hits="1"/> <line number="126" hits="1"/> - <line number="127" hits="1"/> - <line number="129" hits="1"/> + <line 
number="128" hits="1"/> <line number="130" hits="1"/> <line number="131" hits="1"/> - <line number="133" hits="1"/> + <line number="132" hits="1"/> <line number="134" hits="1"/> + <line number="135" hits="1"/> <line number="136" hits="1"/> <line number="137" hits="1"/> <line number="139" hits="1"/> + <line number="140" hits="1"/> <line number="141" hits="1"/> - <line number="142" hits="1"/> <line number="143" hits="1"/> - <line number="145" hits="1"/> + <line number="144" hits="1"/> <line number="146" hits="1"/> <line number="147" hits="1"/> - <line number="148" hits="1"/> <line number="149" hits="1"/> - <line number="150" hits="1"/> <line number="151" hits="1"/> + <line number="152" hits="1"/> <line number="153" hits="1"/> + <line number="155" hits="1"/> <line number="156" hits="1"/> <line number="157" hits="1"/> + <line number="158" hits="1"/> + <line number="159" hits="1"/> <line number="160" hits="1"/> <line number="161" hits="1"/> <line number="163" hits="1"/> - <line number="164" hits="1"/> <line number="166" hits="1"/> <line number="167" hits="1"/> - <line number="168" hits="1"/> <line number="170" hits="1"/> + <line number="171" hits="1"/> + <line number="173" hits="1"/> + <line number="174" hits="1"/> + <line number="176" hits="1"/> + <line number="177" hits="1"/> + <line number="178" hits="1"/> + <line number="180" hits="1"/> </lines> </class> - <class name="layer.py" filename="layer.py" complexity="0" line-rate="0.8776" branch-rate="0"> + <class name="layer.py" filename="layer.py" complexity="0" line-rate="0.8827" branch-rate="0"> <methods/> <lines> <line number="1" hits="1"/> @@ -182,7 +185,7 @@ <line number="93" hits="1"/> <line number="94" hits="1"/> <line number="95" hits="1"/> - <line number="99" hits="0"/> + <line number="99" hits="1"/> <line number="101" hits="1"/> <line number="102" hits="1"/> <line number="108" hits="1"/> @@ -471,7 +474,7 @@ </class> </classes> </package> - <package name="encoders" line-rate="0.9231" branch-rate="0" 
complexity="0"> + <package name="encoders" line-rate="0.9369" branch-rate="0" complexity="0"> <classes> <class name="__init__.py" filename="encoders/__init__.py" complexity="0" line-rate="1" branch-rate="0"> <methods/> @@ -481,7 +484,8 @@ <line number="3" hits="1"/> <line number="4" hits="1"/> <line number="5" hits="1"/> - <line number="7" hits="1"/> + <line number="6" hits="1"/> + <line number="8" hits="1"/> </lines> </class> <class name="base.py" filename="encoders/base.py" complexity="0" line-rate="1" branch-rate="0"> @@ -498,7 +502,7 @@ <line number="13" hits="1"/> </lines> </class> - <class name="bm25.py" filename="encoders/bm25.py" complexity="0" line-rate="0.9524" branch-rate="0"> + <class name="bm25.py" filename="encoders/bm25.py" complexity="0" line-rate="0.9574" branch-rate="0"> <methods/> <lines> <line number="1" hits="1"/> @@ -509,28 +513,23 @@ <line number="9" hits="1"/> <line number="10" hits="1"/> <line number="12" hits="1"/> - <line number="13" hits="1"/> - <line number="14" hits="1"/> - <line number="15" hits="1"/> - <line number="16" hits="0"/> - <line number="17" hits="0"/> - <line number="21" hits="1"/> - <line number="22" hits="1"/> - <line number="24" hits="1"/> - <line number="25" hits="1"/> - <line number="26" hits="1"/> + <line number="18" hits="1"/> + <line number="19" hits="1"/> + <line number="20" hits="1"/> + <line number="21" hits="0"/> + <line number="22" hits="0"/> <line number="27" hits="1"/> - <line number="28" hits="1"/> + <line number="29" hits="1"/> <line number="30" hits="1"/> + <line number="31" hits="1"/> <line number="32" hits="1"/> - <line number="33" hits="1"/> <line number="34" hits="1"/> <line number="35" hits="1"/> <line number="36" hits="1"/> <line number="37" hits="1"/> <line number="38" hits="1"/> - <line number="40" hits="1"/> - <line number="42" hits="1"/> + <line number="39" hits="1"/> + <line number="41" hits="1"/> <line number="43" hits="1"/> <line number="44" hits="1"/> <line number="45" hits="1"/> @@ -538,11 
+537,21 @@ <line number="47" hits="1"/> <line number="48" hits="1"/> <line number="49" hits="1"/> - <line number="50" hits="1"/> - <line number="52" hits="1"/> + <line number="51" hits="1"/> <line number="53" hits="1"/> <line number="54" hits="1"/> <line number="55" hits="1"/> + <line number="56" hits="1"/> + <line number="57" hits="1"/> + <line number="58" hits="1"/> + <line number="59" hits="1"/> + <line number="60" hits="1"/> + <line number="61" hits="1"/> + <line number="63" hits="1"/> + <line number="64" hits="1"/> + <line number="65" hits="1"/> + <line number="66" hits="1"/> + <line number="67" hits="1"/> </lines> </class> <class name="cohere.py" filename="encoders/cohere.py" complexity="0" line-rate="1" branch-rate="0"> @@ -611,6 +620,71 @@ <line number="51" hits="0"/> </lines> </class> + <class name="huggingface.py" filename="encoders/huggingface.py" complexity="0" line-rate="0.9667" branch-rate="0"> + <methods/> + <lines> + <line number="1" hits="1"/> + <line number="2" hits="1"/> + <line number="3" hits="1"/> + <line number="6" hits="1"/> + <line number="7" hits="1"/> + <line number="8" hits="1"/> + <line number="9" hits="1"/> + <line number="10" hits="1"/> + <line number="11" hits="1"/> + <line number="12" hits="1"/> + <line number="13" hits="1"/> + <line number="14" hits="1"/> + <line number="15" hits="1"/> + <line number="17" hits="1"/> + <line number="18" hits="1"/> + <line number="19" hits="1"/> + <line number="21" hits="1"/> + <line number="22" hits="1"/> + <line number="23" hits="1"/> + <line number="24" hits="1"/> + <line number="25" hits="1"/> + <line number="31" hits="1"/> + <line number="32" hits="1"/> + <line number="33" hits="1"/> + <line number="34" hits="1"/> + <line number="40" hits="1"/> + <line number="42" hits="1"/> + <line number="47" hits="1"/> + <line number="49" hits="1"/> + <line number="50" hits="0"/> + <line number="53" hits="1"/> + <line number="54" hits="1"/> + <line number="55" hits="1"/> + <line number="57" hits="1"/> + <line 
number="59" hits="1"/> + <line number="66" hits="1"/> + <line number="67" hits="1"/> + <line number="68" hits="1"/> + <line number="70" hits="1"/> + <line number="74" hits="1"/> + <line number="75" hits="1"/> + <line number="77" hits="1"/> + <line number="78" hits="1"/> + <line number="81" hits="1"/> + <line number="82" hits="1"/> + <line number="86" hits="0"/> + <line number="90" hits="1"/> + <line number="91" hits="1"/> + <line number="93" hits="1"/> + <line number="94" hits="1"/> + <line number="95" hits="1"/> + <line number="97" hits="1"/> + <line number="98" hits="1"/> + <line number="99" hits="1"/> + <line number="102" hits="1"/> + <line number="106" hits="1"/> + <line number="107" hits="1"/> + <line number="108" hits="1"/> + <line number="111" hits="1"/> + <line number="112" hits="1"/> + </lines> + </class> <class name="openai.py" filename="encoders/openai.py" complexity="0" line-rate="0.9767" branch-rate="0"> <methods/> <lines> diff --git a/docs/encoders/fastembed.ipynb b/docs/encoders/fastembed.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..68b1f69cd7228a25ec2ae97e6ffa4f13ca39a116 --- /dev/null +++ b/docs/encoders/fastembed.ipynb @@ -0,0 +1,262 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/docs/encoders/fastembed.ipynb) [](https://nbviewer.org/github/aurelio-labs/semantic-router/blob/main/docs/encoders/fastembed.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Using FastEmbedEncoder" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "FastEmbed is a _lightweight and fast_ embedding library built for generating embeddings. 
It can be run locally and supports many open source encoders.\n", + "\n", + "Beyond being a local, open source library, there are two key reasons we might want to run this library over other open source alternatives:\n", + "\n", + "* **Lightweight and Fast**: The library uses an ONNX runtime so there is no heavy PyTorch dependency, supports quantized model weights (smaller memory footprint), is developed for running on CPU, and uses data-parallelism for encoding large datasets.\n", + "\n", + "* **Open-weight models**: FastEmbed supports many open source and open-weight models, including some that outperform popular encoders like OpenAI's Ada-002." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Getting Started" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We start by installing semantic-router with the `[fastembed]` flag to include all necessary dependencies for `FastEmbedEncoder`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -qU \"semantic-router[fastembed]==0.0.15\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We start by defining a dictionary mapping routes to example phrases that should trigger those routes."
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from semantic_router import Route\n", + "\n", + "politics = Route(\n", + " name=\"politics\",\n", + " utterances=[\n", + " \"isn't politics the best thing ever\",\n", + " \"why don't you tell me about your political opinions\",\n", + " \"don't you just love the president\",\n", + " \"don't you just hate the president\",\n", + " \"they're going to destroy this country!\",\n", + " \"they will save the country!\",\n", + " ],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's define another for good measure:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "chitchat = Route(\n", + " name=\"chitchat\",\n", + " utterances=[\n", + " \"how's the weather today?\",\n", + " \"how are things going?\",\n", + " \"lovely weather today\",\n", + " \"the weather is horrendous\",\n", + " \"let's go to the chippy\",\n", + " ],\n", + ")\n", + "\n", + "routes = [politics, chitchat]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we initialize our embedding model, you can find a list of [all available embedding models here](https://qdrant.github.io/fastembed/examples/Supported_Models/):" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from semantic_router.encoders import FastEmbedEncoder\n", + "\n", + "encoder = FastEmbedEncoder(name=\"BAAI/bge-small-en-v1.5\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "_**⚠️ If you see an ImportError, you must install the FastEmbed library. You can do so by installing Semantic Router using `pip install -qU \"semantic-router[fastembed]\"`.**_" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we define the `RouteLayer`.
When called, the route layer will consume text (a query) and output the category (`Route`) it belongs to — to initialize a `RouteLayer` we need our `encoder` model and a list of `routes`." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-06 16:53:16 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n" + ] + } + ], + "source": [ + "from semantic_router.layer import RouteLayer\n", + "\n", + "rl = RouteLayer(encoder=encoder, routes=routes)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can test it:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteChoice(name='politics', function_call=None)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rl(\"don't you love politics?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteChoice(name='chitchat', function_call=None)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rl(\"how's the weather today?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Both are classified accurately, what if we send a query that is unrelated to our existing `Route` objects?" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteChoice(name=None, function_call=None)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rl(\"I'm interested in learning about llama 2\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this case, we return `None` because no matches were identified."
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "decision-layer", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/encoders/huggingface.ipynb b/docs/encoders/huggingface.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..4e9c28cdd9d1cc4c65b84a30de2d896b77ef2049 --- /dev/null +++ b/docs/encoders/huggingface.ipynb @@ -0,0 +1,271 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/docs/encoders/huggingface.ipynb) [](https://nbviewer.org/github/aurelio-labs/semantic-router/blob/main/docs/encoders/huggingface.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Using HuggingFaceEncoder" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "HuggingFace is a huge ecosystem of open source models. It can be run locally and supports the largest library of encoders." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Getting Started" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We start by installing semantic-router with the `[local]` flag to include all necessary dependencies for `HuggingFaceEncoder`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -qU \"semantic-router[local]==0.0.16\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We start by defining a dictionary mapping routes to example phrases that should trigger those routes." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from semantic_router import Route\n", + "\n", + "politics = Route(\n", + " name=\"politics\",\n", + " utterances=[\n", + " \"isn't politics the best thing ever\",\n", + " \"why don't you tell me about your political opinions\",\n", + " \"don't you just love the president\",\n", + " \"don't you just hate the president\",\n", + " \"they're going to destroy this country!\",\n", + " \"they will save the country!\",\n", + " ],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "_**⚠️ If you see an ImportError, you must install local dependencies. You can do so by installing Semantic Router using `pip install -qU \"semantic-router[local]\"`.**_" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's define another for good measure:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "chitchat = Route(\n", + " name=\"chitchat\",\n", + " utterances=[\n", + " \"how's the weather today?\",\n", + " \"how are things going?\",\n", + " \"lovely weather today\",\n", + " \"the weather is horrendous\",\n", + " \"let's go to the chippy\",\n", + " ],\n", + ")\n", + "\n", + "routes = [politics, chitchat]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we initialize our embedding model." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/jamesbriggs/opt/anaconda3/envs/decision-layer/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets.
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "tokenizer_config.json: 100%|██████████| 350/350 [00:00<00:00, 1.06MB/s]\n", + "vocab.txt: 100%|██████████| 232k/232k [00:00<00:00, 1.05MB/s]\n", + "tokenizer.json: 100%|██████████| 466k/466k [00:00<00:00, 1.43MB/s]\n", + "special_tokens_map.json: 100%|██████████| 112/112 [00:00<00:00, 386kB/s]\n", + "config.json: 100%|██████████| 612/612 [00:00<00:00, 2.90MB/s]\n", + "pytorch_model.bin: 100%|██████████| 90.9M/90.9M [00:01<00:00, 63.2MB/s]\n" + ] + } + ], + "source": [ + "from semantic_router.encoders import HuggingFaceEncoder\n", + "\n", + "encoder = HuggingFaceEncoder()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we define the `RouteLayer`. When called, the route layer will consume text (a query) and output the category (`Route`) it belongs to — to initialize a `RouteLayer` we need our `encoder` model and a list of `routes`."
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-09 00:22:35 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n" + ] + } + ], + "source": [ + "from semantic_router.layer import RouteLayer\n", + "\n", + "rl = RouteLayer(encoder=encoder, routes=routes)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can test it:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteChoice(name='politics', function_call=None)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rl(\"don't you love politics?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteChoice(name='chitchat', function_call=None)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rl(\"how's the weather today?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Both are classified accurately, what if we send a query that is unrelated to our existing `Route` objects?" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteChoice(name=None, function_call=None)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rl(\"I'm interested in learning about llama 2\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this case, we return `None` because no matches were identified." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "decision-layer", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/poetry.lock b/poetry.lock index 3131c0d8060074de0d248f94aae8747ce493f7f4..26e392884612b9b043d9234a858b83554444b278 100644 --- a/poetry.lock +++ b/poetry.lock @@ -250,6 +250,17 @@ d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] +[[package]] +name = "cachetools" +version = "5.3.2" +description = "Extensible memoizing collections and decorators" +optional = false +python-versions = ">=3.7" +files = [ + {file = "cachetools-5.3.2-py3-none-any.whl", hash = "sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1"}, + {file = "cachetools-5.3.2.tar.gz", hash = "sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2"}, +] + [[package]] name = "certifi" version = "2023.11.17" @@ -325,6 +336,17 @@ files = [ [package.dependencies] pycparser = "*" +[[package]] +name = "chardet" +version = "5.2.0" +description = "Universal encoding detector for Python 3" +optional = false +python-versions = ">=3.7" +files = [ + {file = "chardet-5.2.0-py3-none-any.whl", hash = "sha256:e1cf59446890a00105fe7b7912492ea04b6e6f06d4b742b2c788469e34c82970"}, + {file = "chardet-5.2.0.tar.gz", hash = "sha256:1b3b6ff479a8c414bc3fa2c0852995695c4a026dcd6d0633b2dd092ca39c1cf7"}, +] + [[package]] name = "charset-normalizer" version = "3.3.2" @@ -624,6 +646,17 @@ files = [ {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, ] +[[package]] +name = "distlib" +version = "0.3.8" 
+description = "Distribution utilities" +optional = false +python-versions = "*" +files = [ + {file = "distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784"}, + {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"}, +] + [[package]] name = "distro" version = "1.9.0" @@ -746,7 +779,7 @@ tqdm = ">=4.65,<5.0" name = "filelock" version = "3.13.1" description = "A platform independent file lock." -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"}, @@ -1123,6 +1156,23 @@ docs = ["Jinja2 (==2.11.3)", "MarkupSafe (==1.1.1)", "Pygments (==2.8.1)", "alab qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"] testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] +[[package]] +name = "jinja2" +version = "3.1.2" +description = "A very fast and expressive template engine." +optional = true +python-versions = ">=3.7" +files = [ + {file = "Jinja2-3.1.2-py3-none-any.whl", hash = "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61"}, + {file = "Jinja2-3.1.2.tar.gz", hash = "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852"}, +] + +[package.dependencies] +MarkupSafe = ">=2.0" + +[package.extras] +i18n = ["Babel (>=2.7)"] + [[package]] name = "joblib" version = "1.3.2" @@ -1177,6 +1227,65 @@ traitlets = ">=5.3" docs = ["myst-parser", "pydata-sphinx-theme", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"] test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] +[[package]] +name = "markupsafe" +version = "2.1.3" +description = "Safely add untrusted strings to HTML/XML markup." 
+optional = true +python-versions = ">=3.7" +files = [ + {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65c1a9bcdadc6c28eecee2c119465aebff8f7a584dd719facdd9e825ec61ab52"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:525808b8019e36eb524b8c68acdd63a37e75714eac50e988180b169d64480a00"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:962f82a3086483f5e5f64dbad880d31038b698494799b097bc59c2edf392fce6"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:aa7bd130efab1c280bed0f45501b7c8795f9fdbeb02e965371bbef3523627779"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c9c804664ebe8f83a211cace637506669e7890fec1b4195b505c214e50dd4eb7"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-win32.whl", hash = "sha256:10bbfe99883db80bdbaff2dcf681dfc6533a614f700da1287707e8a5d78a8431"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-win_amd64.whl", hash = "sha256:1577735524cdad32f9f694208aa75e422adba74f1baee7551620e43a3141f559"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ad9e82fb8f09ade1c3e1b996a6337afac2b8b9e365f926f5a61aacc71adc5b3c"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3c0fae6c3be832a0a0473ac912810b2877c8cb9d76ca48de1ed31e1c68386575"}, + {file = 
"MarkupSafe-2.1.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b076b6226fb84157e3f7c971a47ff3a679d837cf338547532ab866c57930dbee"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:338ae27d6b8745585f87218a3f23f1512dbf52c26c28e322dbe54bcede54ccb9"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca379055a47383d02a5400cb0d110cef0a776fc644cda797db0c5696cfd7e18e"}, + {file = 
"MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:b7ff0f54cb4ff66dd38bebd335a38e2c22c41a8ee45aa608efc890ac3e3931bc"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c011a4149cfbcf9f03994ec2edffcb8b1dc2d2aede7ca243746df97a5d41ce48"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:56d9f2ecac662ca1611d183feb03a3fa4406469dafe241673d521dd5ae92a155"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-win32.whl", hash = "sha256:8758846a7e80910096950b67071243da3e5a20ed2546e6392603c096778d48e0"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-win_amd64.whl", hash = "sha256:787003c0ddb00500e49a10f2844fac87aa6ce977b90b0feaaf9de23c22508b24"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:2ef12179d3a291be237280175b542c07a36e7f60718296278d8593d21ca937d4"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2c1b19b3aaacc6e57b7e25710ff571c24d6c3613a45e905b1fde04d691b98ee0"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8afafd99945ead6e075b973fefa56379c5b5c53fd8937dad92c662da5d8fd5ee"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c41976a29d078bb235fea9b2ecd3da465df42a562910f9022f1a03107bd02be"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d080e0a5eb2529460b30190fcfcc4199bd7f827663f858a226a81bc27beaa97e"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:69c0f17e9f5a7afdf2cc9fb2d1ce6aabdb3bafb7f38017c0b77862bcec2bbad8"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:504b320cd4b7eff6f968eddf81127112db685e81f7e36e75f9f84f0df46041c3"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:42de32b22b6b804f42c5d98be4f7e5e977ecdd9ee9b660fda1a3edf03b11792d"}, + {file = 
"MarkupSafe-2.1.3-cp38-cp38-win32.whl", hash = "sha256:ceb01949af7121f9fc39f7d27f91be8546f3fb112c608bc4029aef0bab86a2a5"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-win_amd64.whl", hash = "sha256:1b40069d487e7edb2676d3fbdb2b0829ffa2cd63a2ec26c4938b2d34391b4ecc"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:8023faf4e01efadfa183e863fefde0046de576c6f14659e8782065bcece22198"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6b2b56950d93e41f33b4223ead100ea0fe11f8e6ee5f641eb753ce4b77a7042b"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9dcdfd0eaf283af041973bff14a2e143b8bd64e069f4c383416ecd79a81aab58"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:282c2cb35b5b673bbcadb33a585408104df04f14b2d9b01d4c345a3b92861c2c"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ab4a0df41e7c16a1392727727e7998a467472d0ad65f3ad5e6e765015df08636"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7ef3cb2ebbf91e330e3bb937efada0edd9003683db6b57bb108c4001f37a02ea"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-win32.whl", hash = "sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-win_amd64.whl", hash = "sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba"}, + {file = "MarkupSafe-2.1.3.tar.gz", hash = "sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad"}, +] + [[package]] name = "matplotlib-inline" version = "0.1.6" @@ 
-1404,6 +1513,24 @@ files = [ {file = "nest_asyncio-1.5.8.tar.gz", hash = "sha256:25aa2ca0d2a5b5531956b9e273b45cf664cae2b145101d73b86b199978d48fdb"}, ] +[[package]] +name = "networkx" +version = "3.2.1" +description = "Python package for creating and manipulating graphs and networks" +optional = true +python-versions = ">=3.9" +files = [ + {file = "networkx-3.2.1-py3-none-any.whl", hash = "sha256:f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2"}, + {file = "networkx-3.2.1.tar.gz", hash = "sha256:9f1bb5cf3409bf324e0a722c20bdb4c20ee39bf1c30ce8ae499c8502b0b5e0c6"}, +] + +[package.extras] +default = ["matplotlib (>=3.5)", "numpy (>=1.22)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"] +developer = ["changelist (==0.4)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"] +doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"] +extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"] +test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] + [[package]] name = "nltk" version = "3.8.1" @@ -1463,6 +1590,147 @@ files = [ {file = "numpy-1.25.2.tar.gz", hash = "sha256:fd608e19c8d7c55021dffd43bfe5492fab8cc105cc8986f813f8c3c048b38760"}, ] +[[package]] +name = "nvidia-cublas-cu12" +version = "12.1.3.1" +description = "CUBLAS native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"}, + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"}, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.1.105" +description = "CUDA profiling tools runtime libs." 
+optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"}, + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"}, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.1.105" +description = "NVRTC native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"}, + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"}, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.1.105" +description = "CUDA Runtime native Libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"}, + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"}, +] + +[[package]] +name = "nvidia-cudnn-cu12" +version = "8.9.2.26" +description = "cuDNN runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"}, +] + +[package.dependencies] +nvidia-cublas-cu12 = "*" + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.0.2.54" +description = "CUFFT native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = 
"sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"}, + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"}, +] + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.2.106" +description = "CURAND native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"}, + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"}, +] + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.4.5.107" +description = "CUDA solver native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"}, + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"}, +] + +[package.dependencies] +nvidia-cublas-cu12 = "*" +nvidia-cusparse-cu12 = "*" +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.1.0.106" +description = "CUSPARSE native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"}, + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"}, +] + +[package.dependencies] +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-nccl-cu12" +version = "2.18.1" +description = "NVIDIA Collective Communication Library (NCCL) Runtime" +optional = true 
+python-versions = ">=3" +files = [ + {file = "nvidia_nccl_cu12-2.18.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:1a6c4acefcbebfa6de320f412bf7866de856e786e0462326ba1bac40de0b5e71"}, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.3.101" +description = "Nvidia JIT LTO Library" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"}, + {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"}, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.1.105" +description = "NVIDIA Tools Extension" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"}, + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, +] + [[package]] name = "onnx" version = "1.15.0" @@ -1837,6 +2105,25 @@ files = [ plugins = ["importlib-metadata"] windows-terminal = ["colorama (>=0.4.6)"] +[[package]] +name = "pyproject-api" +version = "1.6.1" +description = "API to interact with the python pyproject.toml based projects" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyproject_api-1.6.1-py3-none-any.whl", hash = "sha256:4c0116d60476b0786c88692cf4e325a9814965e2469c5998b830bba16b183675"}, + {file = "pyproject_api-1.6.1.tar.gz", hash = "sha256:1817dc018adc0d1ff9ca1ed8c60e1623d5aaca40814b953af14a9cf9a5cae538"}, +] + +[package.dependencies] +packaging = ">=23.1" +tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""} + +[package.extras] +docs = ["furo (>=2023.8.19)", "sphinx (<7.2)", "sphinx-autodoc-typehints (>=1.24)"] +testing = ["covdefaults (>=2.3)", "pytest 
(>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)", "setuptools (>=68.1.2)", "wheel (>=0.41.2)"] + [[package]] name = "pyreadline3" version = "3.4.1" @@ -2265,6 +2552,125 @@ files = [ {file = "ruff-0.1.11.tar.gz", hash = "sha256:f9d4d88cb6eeb4dfe20f9f0519bd2eaba8119bde87c3d5065c541dbae2b5a2cb"}, ] +[[package]] +name = "safetensors" +version = "0.4.1" +description = "" +optional = true +python-versions = ">=3.7" +files = [ + {file = "safetensors-0.4.1-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:cba01c6b76e01ec453933b3b3c0157c59b52881c83eaa0f7666244e71aa75fd1"}, + {file = "safetensors-0.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7a8f6f679d97ea0135c7935c202feefbd042c149aa70ee759855e890c01c7814"}, + {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bbc2ce1f5ae5143a7fb72b71fa71db6a42b4f6cf912aa3acdc6b914084778e68"}, + {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2d87d993eaefe6611a9c241a8bd364a5f1ffed5771c74840363a6c4ed8d868f6"}, + {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:097e9af2efa8778cd2f0cba451784253e62fa7cc9fc73c0744d27212f7294e25"}, + {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d10a9f7bae608ccfdc009351f01dc3d8535ff57f9488a58a4c38e45bf954fe93"}, + {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:270b99885ec14abfd56c1d7f28ada81740a9220b4bae960c3de1c6fe84af9e4d"}, + {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:285b52a481e7ba93e29ad4ec5841ef2c4479ef0a6c633c4e2629e0508453577b"}, + {file = "safetensors-0.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c3c9f0ca510e0de95abd6424789dcbc879942a3a4e29b0dfa99d9427bf1da75c"}, + {file = "safetensors-0.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = 
"sha256:88b4653059c903015284a9722f9a46838c654257173b279c8f6f46dbe80b612d"}, + {file = "safetensors-0.4.1-cp310-none-win32.whl", hash = "sha256:2fe6926110e3d425c4b684a4379b7796fdc26ad7d16922ea1696c8e6ea7e920f"}, + {file = "safetensors-0.4.1-cp310-none-win_amd64.whl", hash = "sha256:a79e16222106b2f5edbca1b8185661477d8971b659a3c814cc6f15181a9b34c8"}, + {file = "safetensors-0.4.1-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:d93321eea0dd7e81b283e47a1d20dee6069165cc158286316d0d06d340de8fe8"}, + {file = "safetensors-0.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8ff8e41c8037db17de0ea2a23bc684f43eaf623be7d34906fe1ac10985b8365e"}, + {file = "safetensors-0.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39d36f1d88468a87c437a1bc27c502e71b6ca44c385a9117a9f9ba03a75cc9c6"}, + {file = "safetensors-0.4.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7ef010e9afcb4057fb6be3d0a0cfa07aac04fe97ef73fe4a23138d8522ba7c17"}, + {file = "safetensors-0.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b287304f2b2220d51ccb51fd857761e78bcffbeabe7b0238f8dc36f2edfd9542"}, + {file = "safetensors-0.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e09000b2599e1836314430f81a3884c66a5cbabdff5d9f175b5d560d4de38d78"}, + {file = "safetensors-0.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9c80ce0001efa16066358d2dd77993adc25f5a6c61850e4ad096a2232930bce"}, + {file = "safetensors-0.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:413e1f6ac248f7d1b755199a06635e70c3515493d3b41ba46063dec33aa2ebb7"}, + {file = "safetensors-0.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d3ac139377cfe71ba04573f1cda66e663b7c3e95be850e9e6c2dd4b5984bd513"}, + {file = "safetensors-0.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:04157d008385bea66d12fe90844a80d4a76dc25ec5230b5bd9a630496d1b7c03"}, + {file = 
"safetensors-0.4.1-cp311-none-win32.whl", hash = "sha256:5f25297148ec665f0deb8bd67e9564634d8d6841041ab5393ccfe203379ea88b"}, + {file = "safetensors-0.4.1-cp311-none-win_amd64.whl", hash = "sha256:b2f8877990a72ff595507b80f4b69036a9a1986a641f8681adf3425d97d3d2a5"}, + {file = "safetensors-0.4.1-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:eb2c1da1cc39509d1a55620a5f4d14f8911c47a89c926a96e6f4876e864375a3"}, + {file = "safetensors-0.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:303d2c0415cf15a28f8d7f17379ea3c34c2b466119118a34edd9965983a1a8a6"}, + {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb4cb3e37a9b961ddd68e873b29fe9ab4a081e3703412e34aedd2b7a8e9cafd9"}, + {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ae5497adc68669db2fed7cb2dad81e6a6106e79c9a132da3efdb6af1db1014fa"}, + {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3b30abd0cddfe959d1daedf92edcd1b445521ebf7ddefc20860ed01486b33c90"}, + {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d784a98c492c751f228a4a894c3b8a092ff08b24e73b5568938c28b8c0e8f8df"}, + {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e57a5ab08b0ec7a7caf30d2ac79bb30c89168431aca4f8854464bb9461686925"}, + {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:edcf3121890b5f0616aa5a54683b1a5d2332037b970e507d6bb7841a3a596556"}, + {file = "safetensors-0.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:fdb58dee173ef33634c3016c459d671ca12d11e6acf9db008261cbe58107e579"}, + {file = "safetensors-0.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:780dc21eb3fd32ddd0e8c904bdb0290f2454f4ac21ae71e94f9ce72db1900a5a"}, + {file = "safetensors-0.4.1-cp37-cp37m-macosx_10_7_x86_64.whl", hash = 
"sha256:48901bd540f8a3c1791314bc5c8a170927bf7f6acddb75bf0a263d081a3637d4"}, + {file = "safetensors-0.4.1-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:3b0b7b2d5976fbed8a05e2bbdce5816a59e6902e9e7c7e07dc723637ed539787"}, + {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f69903ff49cb30b9227fb5d029bea276ea20d04b06803877a420c5b1b74c689"}, + {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0ddd050e01f3e843aa8c1c27bf68675b8a08e385d0045487af4d70418c3cb356"}, + {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9a82bc2bd7a9a0e08239bdd6d7774d64121f136add93dfa344a2f1a6d7ef35fa"}, + {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6ace9e66a40f98a216ad661245782483cf79cf56eb2b112650bb904b0baa9db5"}, + {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:82cbb8f4d022f2e94498cbefca900698b8ded3d4f85212f47da614001ff06652"}, + {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:791edc10a3c359a2f5f52d5cddab0df8a45107d91027d86c3d44e57162e5d934"}, + {file = "safetensors-0.4.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:83c2cfbe8c6304f0891e7bb378d56f66d2148972eeb5f747cd8a2246886f0d8c"}, + {file = "safetensors-0.4.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:04dd14f53f5500eb4c4149674216ba1000670efbcf4b1b5c2643eb244e7882ea"}, + {file = "safetensors-0.4.1-cp37-none-win32.whl", hash = "sha256:d5b3defa74f3723a388bfde2f5d488742bc4879682bd93267c09a3bcdf8f869b"}, + {file = "safetensors-0.4.1-cp37-none-win_amd64.whl", hash = "sha256:25a043cbb59d4f75e9dd87fdf5c009dd8830105a2c57ace49b72167dd9808111"}, + {file = "safetensors-0.4.1-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:3f6a520af7f2717c5ecba112041f2c8af1ca6480b97bf957aba81ed9642e654c"}, + {file = 
"safetensors-0.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c3807ac3b16288dffebb3474b555b56fe466baa677dfc16290dcd02dca1ab228"}, + {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b58ba13a9e82b4bc3fc221914f6ef237fe6c2adb13cede3ace64d1aacf49610"}, + {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:dac4bb42f8679aadc59bd91a4c5a1784a758ad49d0912995945cd674089f628e"}, + {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:911b48dc09e321a194def3a7431662ff4f03646832f3a8915bbf0f449b8a5fcb"}, + {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:82571d20288c975c1b30b08deb9b1c3550f36b31191e1e81fae87669a92217d0"}, + {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da52ee0dc8ba03348ffceab767bd8230842fdf78f8a996e2a16445747143a778"}, + {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2536b11ce665834201072e9397404170f93f3be10cca9995b909f023a04501ee"}, + {file = "safetensors-0.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:998fbac99ca956c3a09fe07cc0b35fac26a521fa8865a690686d889f0ff4e4a6"}, + {file = "safetensors-0.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:845be0aafabf2a60c2d482d4e93023fecffe5e5443d801d7a7741bae9de41233"}, + {file = "safetensors-0.4.1-cp38-none-win32.whl", hash = "sha256:ce7a28bc8af685a69d7e869d09d3e180a275e3281e29cf5f1c7319e231932cc7"}, + {file = "safetensors-0.4.1-cp38-none-win_amd64.whl", hash = "sha256:e056fb9e22d118cc546107f97dc28b449d88274207dd28872bd668c86216e4f6"}, + {file = "safetensors-0.4.1-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:bdc0d039e44a727824639824090bd8869535f729878fa248addd3dc01db30eae"}, + {file = "safetensors-0.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = 
"sha256:3c1b1d510c7aba71504ece87bf393ea82638df56303e371e5e2cf09d18977dd7"}, + {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0bd0afd95c1e497f520e680ea01e0397c0868a3a3030e128438cf6e9e3fcd671"}, + {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f603bdd8deac6726d39f41688ed353c532dd53935234405d79e9eb53f152fbfb"}, + {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d8a85e3e47e0d4eebfaf9a58b40aa94f977a56050cb5598ad5396a9ee7c087c6"}, + {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0ccb5aa0f3be2727117e5631200fbb3a5b3a2b3757545a92647d6dd8be6658f"}, + {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d784938534e255473155e4d9f276ee69eb85455b6af1292172c731409bf9adee"}, + {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a257de175c254d39ccd6a21341cd62eb7373b05c1e618a78096a56a857e0c316"}, + {file = "safetensors-0.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:6fd80f7794554091836d4d613d33a7d006e2b8d6ba014d06f97cebdfda744f64"}, + {file = "safetensors-0.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:35803201d980efcf964b75a0a2aee97fe5e9ecc5f3ad676b38fafdfe98e0620d"}, + {file = "safetensors-0.4.1-cp39-none-win32.whl", hash = "sha256:7ff8a36e0396776d3ed9a106fc9a9d7c55d4439ca9a056a24bf66d343041d3e6"}, + {file = "safetensors-0.4.1-cp39-none-win_amd64.whl", hash = "sha256:bfa2e20342b81921b98edba52f8deb68843fa9c95250739a56b52ceda5ea5c61"}, + {file = "safetensors-0.4.1-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:ae2d5a31cfb8a973a318f7c4d2cffe0bd1fe753cdf7bb41a1939d45a0a06f964"}, + {file = "safetensors-0.4.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1a45dbf03e8334d3a5dc93687d98b6dc422f5d04c7d519dac09b84a3c87dd7c6"}, + {file = 
"safetensors-0.4.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2297b359d91126c0f9d4fd17bae3cfa2fe3a048a6971b8db07db746ad92f850c"}, + {file = "safetensors-0.4.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bda3d98e2bcece388232cfc551ebf063b55bdb98f65ab54df397da30efc7dcc5"}, + {file = "safetensors-0.4.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f8934bdfd202ebd0697040a3dff40dd77bc4c5bbf3527ede0532f5e7fb4d970f"}, + {file = "safetensors-0.4.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:42c3710cec7e5c764c7999697516370bee39067de0aa089b7e2cfb97ac8c6b20"}, + {file = "safetensors-0.4.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:53134226053e56bd56e73f7db42596e7908ed79f3c9a1016e4c1dade593ac8e5"}, + {file = "safetensors-0.4.1-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:257d59e40a1b367cb544122e7451243d65b33c3f34d822a347f4eea6fdf97fdf"}, + {file = "safetensors-0.4.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d54c2f1826e790d1eb2d2512bfd0ee443f0206b423d6f27095057c7f18a0687"}, + {file = "safetensors-0.4.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:645b3f1138fce6e818e79d4128afa28f0657430764cc045419c1d069ff93f732"}, + {file = "safetensors-0.4.1-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e9a7ffb1e551c6df51d267f5a751f042b183df22690f6feceac8d27364fd51d7"}, + {file = "safetensors-0.4.1-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:44e230fbbe120de564b64f63ef3a8e6ff02840fa02849d9c443d56252a1646d4"}, + {file = "safetensors-0.4.1-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:9d16b3b2fcc6fca012c74bd01b5619c655194d3e3c13e4d4d0e446eefa39a463"}, + {file = "safetensors-0.4.1-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:5d95ea4d8b32233910734a904123bdd3979c137c461b905a5ed32511defc075f"}, + {file = 
"safetensors-0.4.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:dab431699b5d45e0ca043bc580651ce9583dda594e62e245b7497adb32e99809"}, + {file = "safetensors-0.4.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16d8bbb7344e39cb9d4762e85c21df94ebeb03edac923dd94bb9ed8c10eac070"}, + {file = "safetensors-0.4.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1faf5111c66a6ba91f85dff2e36edaaf36e6966172703159daeef330de4ddc7b"}, + {file = "safetensors-0.4.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:660ca1d8bff6c7bc7c6b30b9b32df74ef3ab668f5df42cefd7588f0d40feadcb"}, + {file = "safetensors-0.4.1-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:ae2f67f04ed0bb2e56fd380a8bd3eef03f609df53f88b6f5c7e89c08e52aae00"}, + {file = "safetensors-0.4.1-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:c8ed5d2c04cdc1afc6b3c28d59580448ac07732c50d94c15e14670f9c473a2ce"}, + {file = "safetensors-0.4.1-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:2b6a2814278b6660261aa9a9aae524616de9f1ec364e3716d219b6ed8f91801f"}, + {file = "safetensors-0.4.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:3cfd1ca35eacc635f0eaa894e5c5ed83ffebd0f95cac298fd430014fa7323631"}, + {file = "safetensors-0.4.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4177b456c6b0c722d82429127b5beebdaf07149d265748e97e0a34ff0b3694c8"}, + {file = "safetensors-0.4.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:313e8472197bde54e3ec54a62df184c414582979da8f3916981b6a7954910a1b"}, + {file = "safetensors-0.4.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:fdb4adb76e21bad318210310590de61c9f4adcef77ee49b4a234f9dc48867869"}, + {file = "safetensors-0.4.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:1d568628e9c43ca15eb96c217da73737c9ccb07520fafd8a1eba3f2750614105"}, + {file = 
"safetensors-0.4.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:573b6023a55a2f28085fc0a84e196c779b6cbef4d9e73acea14c8094fee7686f"}, + {file = "safetensors-0.4.1.tar.gz", hash = "sha256:2304658e6ada81a5223225b4efe84748e760c46079bffedf7e321763cafb36c9"}, +] + +[package.extras] +all = ["safetensors[jax]", "safetensors[numpy]", "safetensors[paddlepaddle]", "safetensors[pinned-tf]", "safetensors[quality]", "safetensors[testing]", "safetensors[torch]"] +dev = ["safetensors[all]"] +jax = ["flax (>=0.6.3)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "safetensors[numpy]"] +numpy = ["numpy (>=1.21.6)"] +paddlepaddle = ["paddlepaddle (>=2.4.1)", "safetensors[numpy]"] +pinned-tf = ["safetensors[numpy]", "tensorflow (==2.11.0)"] +quality = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "isort (>=5.5.4)"] +tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"] +testing = ["h5py (>=3.7.0)", "huggingface_hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools_rust (>=1.5.2)"] +torch = ["safetensors[numpy]", "torch (>=1.10)"] + [[package]] name = "six" version = "1.16.0" @@ -2457,6 +2863,59 @@ files = [ {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, ] +[[package]] +name = "torch" +version = "2.1.2" +description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" +optional = true +python-versions = ">=3.8.0" +files = [ + {file = "torch-2.1.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:3a871edd6c02dae77ad810335c0833391c1a4ce49af21ea8cf0f6a5d2096eea8"}, + {file = "torch-2.1.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:bef6996c27d8f6e92ea4e13a772d89611da0e103b48790de78131e308cf73076"}, + {file = "torch-2.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:0e13034fd5fb323cbbc29e56d0637a3791e50dd589616f40c79adfa36a5a35a1"}, + {file = "torch-2.1.2-cp310-none-macosx_10_9_x86_64.whl", hash 
= "sha256:d9b535cad0df3d13997dbe8bd68ac33e0e3ae5377639c9881948e40794a61403"}, + {file = "torch-2.1.2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:f9a55d55af02826ebfbadf4e9b682f0f27766bc33df8236b48d28d705587868f"}, + {file = "torch-2.1.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:a6ebbe517097ef289cc7952783588c72de071d4b15ce0f8b285093f0916b1162"}, + {file = "torch-2.1.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:8f32ce591616a30304f37a7d5ea80b69ca9e1b94bba7f308184bf616fdaea155"}, + {file = "torch-2.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:e0ee6cf90c8970e05760f898d58f9ac65821c37ffe8b04269ec787aa70962b69"}, + {file = "torch-2.1.2-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:76d37967c31c99548ad2c4d3f2cf191db48476f2e69b35a0937137116da356a1"}, + {file = "torch-2.1.2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:e2d83f07b4aac983453ea5bf8f9aa9dacf2278a8d31247f5d9037f37befc60e4"}, + {file = "torch-2.1.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:f41fe0c7ecbf903a568c73486139a75cfab287a0f6c17ed0698fdea7a1e8641d"}, + {file = "torch-2.1.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e3225f47d50bb66f756fe9196a768055d1c26b02154eb1f770ce47a2578d3aa7"}, + {file = "torch-2.1.2-cp38-cp38-win_amd64.whl", hash = "sha256:33d59cd03cb60106857f6c26b36457793637512998666ee3ce17311f217afe2b"}, + {file = "torch-2.1.2-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:8e221deccd0def6c2badff6be403e0c53491805ed9915e2c029adbcdb87ab6b5"}, + {file = "torch-2.1.2-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:05b18594f60a911a0c4f023f38a8bda77131fba5fd741bda626e97dcf5a3dd0a"}, + {file = "torch-2.1.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:9ca96253b761e9aaf8e06fb30a66ee301aecbf15bb5a303097de1969077620b6"}, + {file = "torch-2.1.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:d93ba70f67b08c2ae5598ee711cbc546a1bc8102cef938904b8c85c2089a51a0"}, + {file = "torch-2.1.2-cp39-cp39-win_amd64.whl", hash = 
"sha256:255b50bc0608db177e6a3cc118961d77de7e5105f07816585fa6f191f33a9ff3"}, + {file = "torch-2.1.2-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6984cd5057c0c977b3c9757254e989d3f1124f4ce9d07caa6cb637783c71d42a"}, + {file = "torch-2.1.2-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:bc195d7927feabc0eb7c110e457c955ed2ab616f3c7c28439dd4188cf589699f"}, +] + +[package.dependencies] +filelock = "*" +fsspec = "*" +jinja2 = "*" +networkx = "*" +nvidia-cublas-cu12 = {version = "12.1.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-cupti-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-nvrtc-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-runtime-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cudnn-cu12 = {version = "8.9.2.26", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nccl-cu12 = {version = "2.18.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +sympy = "*" +triton = {version = "2.1.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +typing-extensions = "*" + +[package.extras] +dynamo = ["jinja2"] 
+opt-einsum = ["opt-einsum (>=3.3)"] + [[package]] name = "tornado" version = "6.4" @@ -2477,6 +2936,33 @@ files = [ {file = "tornado-6.4.tar.gz", hash = "sha256:72291fa6e6bc84e626589f1c29d90a5a6d593ef5ae68052ee2ef000dfd273dee"}, ] +[[package]] +name = "tox" +version = "4.11.4" +description = "tox is a generic virtualenv management and test command line tool" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tox-4.11.4-py3-none-any.whl", hash = "sha256:2adb83d68f27116812b69aa36676a8d6a52249cb0d173649de0e7d0c2e3e7229"}, + {file = "tox-4.11.4.tar.gz", hash = "sha256:73a7240778fabf305aeb05ab8ea26e575e042ab5a18d71d0ed13e343a51d6ce1"}, +] + +[package.dependencies] +cachetools = ">=5.3.1" +chardet = ">=5.2" +colorama = ">=0.4.6" +filelock = ">=3.12.3" +packaging = ">=23.1" +platformdirs = ">=3.10" +pluggy = ">=1.3" +pyproject-api = ">=1.6.1" +tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""} +virtualenv = ">=20.24.3" + +[package.extras] +docs = ["furo (>=2023.8.19)", "sphinx (>=7.2.4)", "sphinx-argparse-cli (>=1.11.1)", "sphinx-autodoc-typehints (>=1.24)", "sphinx-copybutton (>=0.5.2)", "sphinx-inline-tabs (>=2023.4.21)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] +testing = ["build[virtualenv] (>=0.10)", "covdefaults (>=2.3)", "detect-test-pollution (>=1.1.1)", "devpi-process (>=1)", "diff-cover (>=7.7)", "distlib (>=0.3.7)", "flaky (>=3.7)", "hatch-vcs (>=0.3)", "hatchling (>=1.18)", "psutil (>=5.9.5)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)", "pytest-xdist (>=3.3.1)", "re-assert (>=1.1)", "time-machine (>=2.12)", "wheel (>=0.41.2)"] + [[package]] name = "tqdm" version = "4.66.1" @@ -2512,6 +2998,99 @@ files = [ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<7.5)", "pytest-mock", "pytest-mypy-testing"] +[[package]] +name = "transformers" +version = "4.36.2" +description = "State-of-the-art Machine 
Learning for JAX, PyTorch and TensorFlow" +optional = true +python-versions = ">=3.8.0" +files = [ + {file = "transformers-4.36.2-py3-none-any.whl", hash = "sha256:462066c4f74ee52516f12890dcc9ec71d1a5e97998db621668455117a54330f6"}, + {file = "transformers-4.36.2.tar.gz", hash = "sha256:d8068e897e47793281501e547d2bbdfc5b8556409c2cb6c3d9e2ca77d4c0b4ec"}, +] + +[package.dependencies] +filelock = "*" +huggingface-hub = ">=0.19.3,<1.0" +numpy = ">=1.17" +packaging = ">=20.0" +pyyaml = ">=5.1" +regex = "!=2019.12.17" +requests = "*" +safetensors = ">=0.3.1" +tokenizers = ">=0.14,<0.19" +tqdm = ">=4.27" + +[package.extras] +accelerate = ["accelerate (>=0.21.0)"] +agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.10,!=1.12.0)"] +all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] +audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +codecarbon = ["codecarbon (==1.2.0)"] +deepspeed = ["accelerate (>=0.21.0)", "deepspeed (>=0.9.3)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score 
(!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score 
(!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.14,<0.19)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +docs = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] +docs-specific = 
["hf-doc-builder"] +flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)"] +flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +ftfy = ["ftfy"] +integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"] +ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] +modelcreation = ["cookiecutter (==1.7.3)"] +natten = ["natten (>=0.14.6)"] +onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] +onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] +optuna = ["optuna"] +quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<2.0.0)"] +ray = ["ray[tune] (>=2.7.0)"] +retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] +sagemaker = ["sagemaker (>=2.31.0)"] +sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"] +serving = ["fastapi", "pydantic (<2)", "starlette", "uvicorn"] +sigopt = ["sigopt"] +sklearn = ["scikit-learn"] +speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "tensorboard", "timeout-decorator"] +tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-speech = ["kenlm", "librosa", 
"phonemizer", "pyctcdecode (>=0.4.0)"] +timm = ["timm"] +tokenizers = ["tokenizers (>=0.14,<0.19)"] +torch = ["accelerate (>=0.21.0)", "torch (>=1.10,!=1.12.0)"] +torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] +torchhub = ["filelock", "huggingface-hub (>=0.19.3,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "tqdm (>=4.27)"] +video = ["av (==9.2.0)", "decord (==0.6.0)"] +vision = ["Pillow (>=10.0.1,<=15.0)"] + +[[package]] +name = "triton" +version = "2.1.0" +description = "A language and compiler for custom Deep Learning operations" +optional = true +python-versions = "*" +files = [ + {file = "triton-2.1.0-0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:66439923a30d5d48399b08a9eae10370f6c261a5ec864a64983bae63152d39d7"}, + {file = "triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:919b06453f0033ea52c13eaf7833de0e57db3178d23d4e04f9fc71c4f2c32bf8"}, + {file = "triton-2.1.0-0-cp37-cp37m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ae4bb8a91de790e1866405211c4d618379781188f40d5c4c399766914e84cd94"}, + {file = "triton-2.1.0-0-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:39f6fb6bdccb3e98f3152e3fbea724f1aeae7d749412bbb1fa9c441d474eba26"}, + {file = "triton-2.1.0-0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:21544e522c02005a626c8ad63d39bdff2f31d41069592919ef281e964ed26446"}, + {file = "triton-2.1.0-0-pp37-pypy37_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:143582ca31dd89cd982bd3bf53666bab1c7527d41e185f9e3d8a3051ce1b663b"}, + {file = "triton-2.1.0-0-pp38-pypy38_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = 
"sha256:82fc5aeeedf6e36be4e4530cbdcba81a09d65c18e02f52dc298696d45721f3bd"}, + {file = "triton-2.1.0-0-pp39-pypy39_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:81a96d110a738ff63339fc892ded095b31bd0d205e3aace262af8400d40b6fa8"}, +] + +[package.dependencies] +filelock = "*" + +[package.extras] +build = ["cmake (>=3.18)", "lit"] +tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)"] +tutorials = ["matplotlib", "pandas", "tabulate"] + [[package]] name = "types-pyyaml" version = "6.0.12.12" @@ -2550,6 +3129,26 @@ brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "virtualenv" +version = "20.25.0" +description = "Virtual Python Environment builder" +optional = false +python-versions = ">=3.7" +files = [ + {file = "virtualenv-20.25.0-py3-none-any.whl", hash = "sha256:4238949c5ffe6876362d9c0180fc6c3a824a7b12b80604eeb8085f2ed7460de3"}, + {file = "virtualenv-20.25.0.tar.gz", hash = "sha256:bf51c0d9c7dd63ea8e44086fa1e4fb1093a31e963b86959257378aef020e1f1b"}, +] + +[package.dependencies] +distlib = ">=0.3.7,<1" +filelock = ">=3.12.2,<4" +platformdirs = ">=3.9.1,<5" + +[package.extras] +docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] + [[package]] name = "wcwidth" version = "0.2.13" @@ -2692,8 +3291,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [extras] fastembed = ["fastembed"] hybrid = ["pinecone-text"] +local = ["torch", "transformers"] [metadata] lock-version = 
"2.0" python-versions = "^3.9" -content-hash = "3e3b13e2493e7bef6ef1d9487d4618f834f3387a55379edf63d00f76fe4def0a" +content-hash = "4ceb0344e006fc7657c66548321ff51d34a297ee6f9ab069fdc53e1256024a12" diff --git a/pyproject.toml b/pyproject.toml index bbe14c563a182133b69bcbbf27c1f7922f1fbb76..45f105fd9ac63138f7b5d4d82b6300d1052f2ccb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,12 +21,15 @@ cohere = "^4.32" numpy = "^1.25.2" colorlog = "^6.8.0" pyyaml = "^6.0.1" -pinecone-text = {version = "^0.7.1", optional = true, python = "<3.12"} +pinecone-text = {version = "^0.7.1", optional = true} fastembed = {version = "^0.1.3", optional = true, python = "<3.12"} +torch = {version = "^2.1.2", optional = true} +transformers = {version = "^4.36.2", optional = true} [tool.poetry.extras] hybrid = ["pinecone-text"] fastembed = ["fastembed"] +local = ["torch", "transformers"] [tool.poetry.group.dev.dependencies] ipykernel = "^6.25.0" diff --git a/replace.py b/replace.py new file mode 100644 index 0000000000000000000000000000000000000000..508c3f33f32c42600dc9aa25c21f62bbaf119dd3 --- /dev/null +++ b/replace.py @@ -0,0 +1,28 @@ +import os +import re + + +def replace_type_hints(file_path): + with open(file_path, "rb") as file: + file_data = file.read() + + # Decode the file data with error handling + file_data = file_data.decode("utf-8", errors="ignore") + + # Regular expression pattern to find 'Optional[dict[int, int]]' and replace with 'Optional[dict[int, int]]' + file_data = re.sub( + r"dict\[(\w+), (\w+)\]\s*\|\s*None", r"Optional[dict[\1, \2]]", file_data + ) + + with open(file_path, "w") as file: + file.write(file_data) + + +# Directory path +dir_path = "/Users/jakit/customers/aurelio/semantic-router" + +# Traverse the directory +for root, dirs, files in os.walk(dir_path): + for file in files: + if file.endswith(".py"): + replace_type_hints(os.path.join(root, file)) diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py index 
9d3a027ee2927452e27e78f4bc1b11b0b1152702..cc6fc7f99b47c6ca7f08481e9ce49c324c30e8b0 100644 --- a/semantic_router/encoders/__init__.py +++ b/semantic_router/encoders/__init__.py @@ -2,6 +2,7 @@ from semantic_router.encoders.base import BaseEncoder from semantic_router.encoders.bm25 import BM25Encoder from semantic_router.encoders.cohere import CohereEncoder from semantic_router.encoders.fastembed import FastEmbedEncoder +from semantic_router.encoders.huggingface import HuggingFaceEncoder from semantic_router.encoders.openai import OpenAIEncoder __all__ = [ @@ -10,4 +11,5 @@ __all__ = [ "OpenAIEncoder", "BM25Encoder", "FastEmbedEncoder", + "HuggingFaceEncoder", ] diff --git a/semantic_router/encoders/bm25.py b/semantic_router/encoders/bm25.py index 11a964b6cd0f430a9f4d9aeec801542ac7d3e225..451273cdcc78899861c7b64baea4eb4e1cc6b33b 100644 --- a/semantic_router/encoders/bm25.py +++ b/semantic_router/encoders/bm25.py @@ -6,10 +6,15 @@ from semantic_router.utils.logger import logger class BM25Encoder(BaseEncoder): model: Optional[Any] = None - idx_mapping: dict[int, int] | None = None + idx_mapping: Optional[dict[int, int]] = None type: str = "sparse" - def __init__(self, name: str = "bm25", score_threshold: float = 0.82): + def __init__( + self, + name: str = "bm25", + score_threshold: float = 0.82, + use_default_params: bool = True, + ): super().__init__(name=name, score_threshold=score_threshold) try: from pinecone_text.sparse import BM25Encoder as encoder @@ -18,9 +23,15 @@ class BM25Encoder(BaseEncoder): "Please install pinecone-text to use BM25Encoder. 
" "You can install it with: `pip install 'semantic-router[hybrid]'`" ) - logger.info("Downloading and initializing BM25 model parameters.") - self.model = encoder.default() + self.model = encoder() + + if use_default_params: + logger.info("Downloading and initializing default sBM25 model parameters.") + self.model = encoder.default() + self._set_idx_mapping() + + def _set_idx_mapping(self): params = self.model.get_params() doc_freq = params["doc_freq"] if isinstance(doc_freq, dict): @@ -53,3 +64,4 @@ class BM25Encoder(BaseEncoder): if self.model is None: raise ValueError("Model is not initialized.") self.model.fit(docs) + self._set_idx_mapping() diff --git a/semantic_router/encoders/huggingface.py b/semantic_router/encoders/huggingface.py new file mode 100644 index 0000000000000000000000000000000000000000..ace189213b76aed940dd8b4280ce1505339f656f --- /dev/null +++ b/semantic_router/encoders/huggingface.py @@ -0,0 +1,114 @@ +from typing import Any, Optional + +from pydantic import PrivateAttr + +from semantic_router.encoders import BaseEncoder + + +class HuggingFaceEncoder(BaseEncoder): + name: str = "sentence-transformers/all-MiniLM-L6-v2" + type: str = "huggingface" + score_threshold: float = 0.5 + tokenizer_kwargs: dict = {} + model_kwargs: dict = {} + device: Optional[str] = None + _tokenizer: Any = PrivateAttr() + _model: Any = PrivateAttr() + _torch: Any = PrivateAttr() + + def __init__(self, **data): + super().__init__(**data) + self._tokenizer, self._model = self._initialize_hf_model() + + def _initialize_hf_model(self): + try: + from transformers import AutoModel, AutoTokenizer + except ImportError: + raise ImportError( + "Please install transformers to use HuggingFaceEncoder. " + "You can install it with: " + "`pip install semantic-router[local]`" + ) + + try: + import torch + except ImportError: + raise ImportError( + "Please install Pytorch to use HuggingFaceEncoder. 
" + "You can install it with: " + "`pip install semantic-router[local]`" + ) + + self._torch = torch + + tokenizer = AutoTokenizer.from_pretrained( + self.name, + **self.tokenizer_kwargs, + ) + + model = AutoModel.from_pretrained(self.name, **self.model_kwargs) + + if self.device: + model.to(self.device) + + else: + device = "cuda" if self._torch.cuda.is_available() else "cpu" + model.to(device) + self.device = device + + return tokenizer, model + + def __call__( + self, + docs: list[str], + batch_size: int = 32, + normalize_embeddings: bool = True, + pooling_strategy: str = "mean", + ) -> list[list[float]]: + all_embeddings = [] + for i in range(0, len(docs), batch_size): + batch_docs = docs[i : i + batch_size] + + encoded_input = self._tokenizer( + batch_docs, padding=True, truncation=True, return_tensors="pt" + ).to(self.device) + + with self._torch.no_grad(): + model_output = self._model(**encoded_input) + + if pooling_strategy == "mean": + embeddings = self._mean_pooling( + model_output, encoded_input["attention_mask"] + ) + elif pooling_strategy == "max": + embeddings = self._max_pooling( + model_output, encoded_input["attention_mask"] + ) + else: + raise ValueError( + "Invalid pooling_strategy. Please use 'mean' or 'max'." 
+ ) + + if normalize_embeddings: + embeddings = self._torch.nn.functional.normalize(embeddings, p=2, dim=1) + + embeddings = embeddings.tolist() + all_embeddings.extend(embeddings) + return all_embeddings + + def _mean_pooling(self, model_output, attention_mask): + token_embeddings = model_output[0] + input_mask_expanded = ( + attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + ) + return self._torch.sum( + token_embeddings * input_mask_expanded, 1 + ) / self._torch.clamp(input_mask_expanded.sum(1), min=1e-9) + + def _max_pooling(self, model_output, attention_mask): + token_embeddings = model_output[0] + input_mask_expanded = ( + attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + ) + token_embeddings[input_mask_expanded == 0] = -1e9 + return self._torch.max(token_embeddings, 1)[0] diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py index 4504fefec56bc8a590b7a9028807759e57575136..169761afa8f726a72439a534a69bac3ebf73de29 100644 --- a/semantic_router/encoders/openai.py +++ b/semantic_router/encoders/openai.py @@ -1,5 +1,6 @@ import os from time import sleep +from typing import Optional import openai from openai import OpenAIError @@ -7,7 +8,6 @@ from openai.types import CreateEmbeddingResponse from semantic_router.encoders import BaseEncoder from semantic_router.utils.logger import logger -from typing import Optional class OpenAIEncoder(BaseEncoder): diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index 06862a63dd6646913704614b7ca03a139004477a..aa65d7d74304db940b921f6168b180636df5baa3 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -1,3 +1,5 @@ +from typing import Optional + import numpy as np from numpy.linalg import norm @@ -7,7 +9,6 @@ from semantic_router.encoders import ( ) from semantic_router.route import Route from semantic_router.utils.logger import logger -from typing import Optional class HybridRouteLayer: @@ -17,11 
+18,21 @@ class HybridRouteLayer: score_threshold: float def __init__( - self, encoder: BaseEncoder, routes: list[Route] = [], alpha: float = 0.3 + self, + encoder: BaseEncoder, + sparse_encoder: Optional[BaseEncoder] = None, + routes: list[Route] = [], + alpha: float = 0.3, ): self.encoder = encoder self.score_threshold = self.encoder.score_threshold - self.sparse_encoder = BM25Encoder() + + if sparse_encoder is None: + logger.warning("No sparse_encoder provided. Using default BM25Encoder.") + self.sparse_encoder = BM25Encoder() + else: + self.sparse_encoder = sparse_encoder + self.alpha = alpha # if routes list has been passed, we initialize index now if routes: diff --git a/semantic_router/layer.py b/semantic_router/layer.py index e6a214f901885a50cb17e1876c4daa71a517a8d8..082e24b642dffa48a81668001ecddf5b1261f08c 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -157,7 +157,7 @@ class RouteLayer: self, encoder: Optional[BaseEncoder] = None, llm: Optional[BaseLLM] = None, - routes: list[Route] | None = None, + routes: Optional[list[Route]] = None, ): logger.info("Initializing RouteLayer") self.index = None diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py index d3b215bffe5eda3b7c476aebefa978024a9ff83d..8b3442c742c2f268773bf551fb49ce0cd24645af 100644 --- a/semantic_router/llms/openai.py +++ b/semantic_router/llms/openai.py @@ -1,11 +1,11 @@ import os +from typing import Optional import openai from semantic_router.llms import BaseLLM from semantic_router.schema import Message from semantic_router.utils.logger import logger -from typing import Optional class OpenAILLM(BaseLLM): diff --git a/semantic_router/llms/openrouter.py b/semantic_router/llms/openrouter.py index 6130e0a7d48e3245df7863a4e9f6e15c20cbea15..4cc15d6bfedbfa67fb5957129d1ce901544dcb38 100644 --- a/semantic_router/llms/openrouter.py +++ b/semantic_router/llms/openrouter.py @@ -1,11 +1,11 @@ import os +from typing import Optional import openai from 
semantic_router.llms import BaseLLM from semantic_router.schema import Message from semantic_router.utils.logger import logger -from typing import Optional class OpenRouterLLM(BaseLLM): diff --git a/semantic_router/route.py b/semantic_router/route.py index c2b9b3dcdb127c343ffd1ecb05a18023aec942e7..6cca7eaf7aae943ab7adf97c1369e7b943f1a655 100644 --- a/semantic_router/route.py +++ b/semantic_router/route.py @@ -1,6 +1,6 @@ import json import re -from typing import Any, Callable, Union +from typing import Any, Callable, Optional, Union from pydantic import BaseModel @@ -8,7 +8,6 @@ from semantic_router.llms import BaseLLM from semantic_router.schema import Message, RouteChoice from semantic_router.utils import function_call from semantic_router.utils.logger import logger -from typing import Optional def is_valid(route_config: str) -> bool: @@ -43,7 +42,7 @@ class Route(BaseModel): name: str utterances: list[str] description: Optional[str] = None - function_schema: dict[str, Any] | None = None + function_schema: Optional[dict[str, Any]] = None llm: Optional[BaseLLM] = None def __call__(self, query: str) -> RouteChoice: diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 9505df243adb8ed46f3aa0578565703358c11ba6..88dd753cc65ec3771db3635463b1688ecc89160a 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -1,4 +1,5 @@ from enum import Enum +from typing import Optional from pydantic import BaseModel from pydantic.dataclasses import dataclass @@ -10,7 +11,6 @@ from semantic_router.encoders import ( OpenAIEncoder, ) from semantic_router.utils.splitters import semantic_splitter -from typing import Optional class EncoderType(Enum): diff --git a/semantic_router/utils/llm.py b/semantic_router/utils/llm.py index f0db13c893eb764da3814b71bf807a733700b85c..4f89566f79df700ec3ae8103b5b8cdfd84700627 100644 --- a/semantic_router/utils/llm.py +++ b/semantic_router/utils/llm.py @@ -1,9 +1,9 @@ import os +from typing import Optional import openai 
from semantic_router.utils.logger import logger -from typing import Optional def llm(prompt: str) -> Optional[str]: diff --git a/tests/unit/encoders/test_bm25.py b/tests/unit/encoders/test_bm25.py index e654d7bbc98070b6db16249c3164d515852e28e5..174453d254370c609c83a7f2a533b4d03ca264dc 100644 --- a/tests/unit/encoders/test_bm25.py +++ b/tests/unit/encoders/test_bm25.py @@ -5,7 +5,11 @@ from semantic_router.encoders import BM25Encoder @pytest.fixture def bm25_encoder(): - return BM25Encoder() + sparse_encoder = BM25Encoder(use_default_params=False) + sparse_encoder.fit( + ["The quick brown fox", "jumps over the lazy dog", "Hello, world!"] + ) + return sparse_encoder class TestBM25Encoder: diff --git a/tests/unit/encoders/test_huggingface.py b/tests/unit/encoders/test_huggingface.py new file mode 100644 index 0000000000000000000000000000000000000000..1e916f3eca2956b32c6147db6cf5c42695720f14 --- /dev/null +++ b/tests/unit/encoders/test_huggingface.py @@ -0,0 +1,63 @@ +from unittest.mock import patch + +import numpy as np +import pytest + +from semantic_router.encoders.huggingface import HuggingFaceEncoder + +encoder = HuggingFaceEncoder() + + +class TestHuggingFaceEncoder: + def test_huggingface_encoder_import_errors_transformers(self): + with patch.dict("sys.modules", {"transformers": None}): + with pytest.raises(ImportError) as error: + HuggingFaceEncoder() + + assert "Please install transformers to use HuggingFaceEncoder" in str( + error.value + ) + + def test_huggingface_encoder_import_errors_torch(self): + with patch.dict("sys.modules", {"torch": None}): + with pytest.raises(ImportError) as error: + HuggingFaceEncoder() + + assert "Please install Pytorch to use HuggingFaceEncoder" in str(error.value) + + def test_huggingface_encoder_mean_pooling(self): + test_docs = ["This is a test", "This is another test"] + embeddings = encoder(test_docs, pooling_strategy="mean") + assert isinstance(embeddings, list) + assert len(embeddings) == len(test_docs) + assert 
all(isinstance(embedding, list) for embedding in embeddings) + assert all(len(embedding) > 0 for embedding in embeddings) + + def test_huggingface_encoder_max_pooling(self): + test_docs = ["This is a test", "This is another test"] + embeddings = encoder(test_docs, pooling_strategy="max") + assert isinstance(embeddings, list) + assert len(embeddings) == len(test_docs) + assert all(isinstance(embedding, list) for embedding in embeddings) + assert all(len(embedding) > 0 for embedding in embeddings) + + def test_huggingface_encoder_normalized_embeddings(self): + docs = ["This is a test document.", "Another test document."] + unnormalized_embeddings = encoder(docs, normalize_embeddings=False) + normalized_embeddings = encoder(docs, normalize_embeddings=True) + assert len(unnormalized_embeddings) == len(normalized_embeddings) + + for unnormalized, normalized in zip( + unnormalized_embeddings, normalized_embeddings + ): + norm_unnormalized = np.linalg.norm(unnormalized, ord=2) + norm_normalized = np.linalg.norm(normalized, ord=2) + # Ensure the norm of the normalized embeddings is approximately 1 + assert np.isclose(norm_normalized, 1.0) + # Ensure the normalized embeddings are actually normalized versions of unnormalized embeddings + np.testing.assert_allclose( + normalized, + np.divide(unnormalized, norm_unnormalized), + rtol=1e-5, + atol=1e-5, # Adjust tolerance levels + ) diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py index 6896c4de1cb1e13196d209455f2bd39e8e14915d..df530149d72fe44d836765b7654a2cbcdf71c694 100644 --- a/tests/unit/test_hybrid_layer.py +++ b/tests/unit/test_hybrid_layer.py @@ -1,6 +1,11 @@ import pytest -from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder +from semantic_router.encoders import ( + BaseEncoder, + BM25Encoder, + CohereEncoder, + OpenAIEncoder, +) from semantic_router.hybrid_layer import HybridRouteLayer from semantic_router.route import Route @@ -42,9 +47,15 @@ def routes(): ] 
+sparse_encoder = BM25Encoder(use_default_params=False) +sparse_encoder.fit(["The quick brown fox", "jumps over the lazy dog", "Hello, world!"]) + + class TestHybridRouteLayer: def test_initialization(self, openai_encoder, routes): - route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes) + route_layer = HybridRouteLayer( + encoder=openai_encoder, sparse_encoder=sparse_encoder, routes=routes + ) assert route_layer.index is not None and route_layer.categories is not None assert openai_encoder.score_threshold == 0.82 assert route_layer.score_threshold == 0.82 @@ -52,14 +63,20 @@ class TestHybridRouteLayer: assert len(set(route_layer.categories)) == 2 def test_initialization_different_encoders(self, cohere_encoder, openai_encoder): - route_layer_cohere = HybridRouteLayer(encoder=cohere_encoder) + route_layer_cohere = HybridRouteLayer( + encoder=cohere_encoder, sparse_encoder=sparse_encoder + ) assert route_layer_cohere.score_threshold == 0.3 - route_layer_openai = HybridRouteLayer(encoder=openai_encoder) + route_layer_openai = HybridRouteLayer( + encoder=openai_encoder, sparse_encoder=sparse_encoder + ) assert route_layer_openai.score_threshold == 0.82 def test_add_route(self, openai_encoder): - route_layer = HybridRouteLayer(encoder=openai_encoder) + route_layer = HybridRouteLayer( + encoder=openai_encoder, sparse_encoder=sparse_encoder + ) route = Route(name="Route 3", utterances=["Yes", "No"]) route_layer._add_routes([route]) assert route_layer.index is not None and route_layer.categories is not None @@ -67,7 +84,9 @@ class TestHybridRouteLayer: assert len(set(route_layer.categories)) == 1 def test_add_multiple_routes(self, openai_encoder, routes): - route_layer = HybridRouteLayer(encoder=openai_encoder) + route_layer = HybridRouteLayer( + encoder=openai_encoder, sparse_encoder=sparse_encoder + ) for route in routes: route_layer.add(route) assert route_layer.index is not None and route_layer.categories is not None @@ -75,16 +94,22 @@ class 
TestHybridRouteLayer: assert len(set(route_layer.categories)) == 2 def test_query_and_classification(self, openai_encoder, routes): - route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes) + route_layer = HybridRouteLayer( + encoder=openai_encoder, sparse_encoder=sparse_encoder, routes=routes + ) query_result = route_layer("Hello") assert query_result in ["Route 1", "Route 2"] def test_query_with_no_index(self, openai_encoder): - route_layer = HybridRouteLayer(encoder=openai_encoder) + route_layer = HybridRouteLayer( + encoder=openai_encoder, sparse_encoder=sparse_encoder + ) assert route_layer("Anything") is None def test_semantic_classify(self, openai_encoder, routes): - route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes) + route_layer = HybridRouteLayer( + encoder=openai_encoder, sparse_encoder=sparse_encoder, routes=routes + ) classification, score = route_layer._semantic_classify( [ {"route": "Route 1", "score": 0.9}, @@ -95,7 +120,9 @@ class TestHybridRouteLayer: assert score == [0.9] def test_semantic_classify_multiple_routes(self, openai_encoder, routes): - route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes) + route_layer = HybridRouteLayer( + encoder=openai_encoder, sparse_encoder=sparse_encoder, routes=routes + ) classification, score = route_layer._semantic_classify( [ {"route": "Route 1", "score": 0.9}, @@ -107,12 +134,16 @@ class TestHybridRouteLayer: assert score == [0.9, 0.8] def test_pass_threshold(self, openai_encoder): - route_layer = HybridRouteLayer(encoder=openai_encoder) + route_layer = HybridRouteLayer( + encoder=openai_encoder, sparse_encoder=sparse_encoder + ) assert not route_layer._pass_threshold([], 0.5) assert route_layer._pass_threshold([0.6, 0.7], 0.5) def test_failover_score_threshold(self, base_encoder): - route_layer = HybridRouteLayer(encoder=base_encoder) + route_layer = HybridRouteLayer( + encoder=base_encoder, sparse_encoder=sparse_encoder + ) assert 
base_encoder.score_threshold == 0.50 assert route_layer.score_threshold == 0.50