diff --git a/.env.example b/.env.example index 85742347a89803823c44a3e44aadfbce934263c9..9f97b08c533e1306aaa7e14fcfacffec3de38330 100644 --- a/.env.example +++ b/.env.example @@ -1 +1 @@ -COHERE_API_KEY= \ No newline at end of file +COHERE_API_KEY= diff --git a/LICENSE b/LICENSE index a477bd488ad05487e6f225e16ddd10692289026e..3fa741bbaa7903b52deccb2358ad3b0ac6b24759 100644 --- a/LICENSE +++ b/LICENSE @@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. \ No newline at end of file +SOFTWARE. diff --git a/docs/examples/hybrid-layer.ipynb b/docs/examples/hybrid-layer.ipynb index 5d0cb4525a5b4b50d6ae564827044a9899a5172f..0dd149b9456453114888105c9830761e354748ba 100644 --- a/docs/examples/hybrid-layer.ipynb +++ b/docs/examples/hybrid-layer.ipynb @@ -4,53 +4,53 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Semantic Router: Hybrid Layer" + "# Semantic Router: Hybrid Layer\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The Hybrid Layer in the Semantic Router library can improve making performance particularly for niche use-cases that contain specific terminology, such as finance or medical. It helps us provide more importance to making based on the keywords contained in our utterances and user queries." + "The Hybrid Layer in the Semantic Router library can improve making performance particularly for niche use-cases that contain specific terminology, such as finance or medical. It helps us provide more importance to making based on the keywords contained in our utterances and user queries.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Getting Started" + "## Getting Started\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We start by installing the library:" + "We start by installing the library:\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ - "!pip install -qU semantic-router==0.0.11" + "#!pip install -qU semantic-router==0.0.11" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We start by defining a dictionary mapping s to example phrases that should trigger those s." 
+ "We start by defining a dictionary mapping s to example phrases that should trigger those s.\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ - "from semantic_router.schema import Route\n", + "from semantic_router.route import Route\n", "\n", "politics = Route(\n", " name=\"politics\",\n", @@ -69,7 +69,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Let's define another for good measure:" + "Let's define another for good measure:\n" ] }, { @@ -81,7 +81,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -114,58 +114,92 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now we initialize our embedding model:" + "Now we initialize our embedding model:\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import os\n", - "from semantic_router.encoders import CohereEncoder\n", + "from semantic_router.encoders import CohereEncoder, BM25Encoder, TfidfEncoder\n", "from getpass import getpass\n", "\n", "os.environ[\"COHERE_API_KEY\"] = os.environ[\"COHERE_API_KEY\"] or getpass(\n", " \"Enter Cohere API Key: \"\n", ")\n", "\n", - "encoder = CohereEncoder()" + "dense_encoder = CohereEncoder()\n", + "# sparse_encoder = BM25Encoder()\n", + "sparse_encoder = TfidfEncoder()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Now we define the `RouteLayer`. When called, the route layer will consume text (a query) and output the category (`Route`) it belongs to — to initialize a `RouteLayer` we need our `encoder` model and a list of `routes`." + "Now we define the `RouteLayer`. When called, the route layer will consume text (a query) and output the category (`Route`) it belongs to — to initialize a `RouteLayer` we need our `encoder` model and a list of `routes`.\n" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-08 16:50:29 INFO semantic_router.utils.logger Creating embeddings for all routes...\u001b[0m\n" + ] + } + ], "source": [ "from semantic_router.hybrid_layer import HybridRouteLayer\n", "\n", - "dl = HybridRouteLayer(encoder=encoder, routes=routes)" + "dl = HybridRouteLayer(\n", + " dense_encoder=dense_encoder, sparse_encoder=sparse_encoder, routes=routes\n", + ")" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'politics'" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "dl(\"don't you love politics?\")" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'chitchat'" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "dl(\"how's the weather today?\")" ] @@ -174,7 +208,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "---" + "---\n" ] } ], diff --git a/poetry.lock b/poetry.lock index 26e392884612b9b043d9234a858b83554444b278..96a21bbc54879c58239b7afab825ba958c34d213 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 
and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -250,17 +250,6 @@ d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] -[[package]] -name = "cachetools" -version = "5.3.2" -description = "Extensible memoizing collections and decorators" -optional = false -python-versions = ">=3.7" -files = [ - {file = "cachetools-5.3.2-py3-none-any.whl", hash = "sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1"}, - {file = "cachetools-5.3.2.tar.gz", hash = "sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2"}, -] - [[package]] name = "certifi" version = "2023.11.17" @@ -336,17 +325,6 @@ files = [ [package.dependencies] pycparser = "*" -[[package]] -name = "chardet" -version = "5.2.0" -description = "Universal encoding detector for Python 3" -optional = false -python-versions = ">=3.7" -files = [ - {file = "chardet-5.2.0-py3-none-any.whl", hash = "sha256:e1cf59446890a00105fe7b7912492ea04b6e6f06d4b742b2c788469e34c82970"}, - {file = "chardet-5.2.0.tar.gz", hash = "sha256:1b3b6ff479a8c414bc3fa2c0852995695c4a026dcd6d0633b2dd092ca39c1cf7"}, -] - [[package]] name = "charset-normalizer" version = "3.3.2" @@ -646,17 +624,6 @@ files = [ {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, ] -[[package]] -name = "distlib" -version = "0.3.8" -description = "Distribution utilities" -optional = false -python-versions = "*" -files = [ - {file = "distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784"}, - {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"}, -] - [[package]] name = "distro" version = "1.9.0" @@ -779,7 +746,7 @@ tqdm = ">=4.65,<5.0" name = "filelock" version = "3.13.1" description = "A platform independent file lock." 
-optional = false +optional = true python-versions = ">=3.8" files = [ {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"}, @@ -1254,6 +1221,16 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -2105,25 +2082,6 @@ files = [ plugins = ["importlib-metadata"] windows-terminal = ["colorama (>=0.4.6)"] -[[package]] -name = "pyproject-api" -version = "1.6.1" -description = "API to interact with the python pyproject.toml based projects" -optional = false -python-versions = ">=3.8" -files = [ - {file = "pyproject_api-1.6.1-py3-none-any.whl", hash = "sha256:4c0116d60476b0786c88692cf4e325a9814965e2469c5998b830bba16b183675"}, - {file = "pyproject_api-1.6.1.tar.gz", hash = "sha256:1817dc018adc0d1ff9ca1ed8c60e1623d5aaca40814b953af14a9cf9a5cae538"}, -] - -[package.dependencies] -packaging = ">=23.1" -tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""} - -[package.extras] -docs = ["furo (>=2023.8.19)", "sphinx (<7.2)", "sphinx-autodoc-typehints (>=1.24)"] -testing = ["covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov 
(>=4.1)", "pytest-mock (>=3.11.1)", "setuptools (>=68.1.2)", "wheel (>=0.41.2)"] - [[package]] name = "pyreadline3" version = "3.4.1" @@ -2261,6 +2219,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -2268,8 +2227,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -2286,6 +2252,7 @@ files = [ {file = 
"PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -2293,6 +2260,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -2936,33 +2904,6 @@ files = [ {file = "tornado-6.4.tar.gz", hash = "sha256:72291fa6e6bc84e626589f1c29d90a5a6d593ef5ae68052ee2ef000dfd273dee"}, ] -[[package]] -name = "tox" -version = "4.11.4" -description = "tox is a generic virtualenv management and test command line tool" -optional = false -python-versions = ">=3.8" -files = [ - {file = "tox-4.11.4-py3-none-any.whl", hash = "sha256:2adb83d68f27116812b69aa36676a8d6a52249cb0d173649de0e7d0c2e3e7229"}, - {file = "tox-4.11.4.tar.gz", hash = "sha256:73a7240778fabf305aeb05ab8ea26e575e042ab5a18d71d0ed13e343a51d6ce1"}, -] - -[package.dependencies] -cachetools = ">=5.3.1" -chardet = ">=5.2" -colorama = ">=0.4.6" -filelock = ">=3.12.3" -packaging = ">=23.1" -platformdirs = ">=3.10" -pluggy = ">=1.3" -pyproject-api = ">=1.6.1" -tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""} -virtualenv = ">=20.24.3" - -[package.extras] -docs = ["furo (>=2023.8.19)", "sphinx (>=7.2.4)", "sphinx-argparse-cli (>=1.11.1)", "sphinx-autodoc-typehints (>=1.24)", "sphinx-copybutton (>=0.5.2)", "sphinx-inline-tabs (>=2023.4.21)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] -testing = ["build[virtualenv] (>=0.10)", "covdefaults (>=2.3)", "detect-test-pollution (>=1.1.1)", "devpi-process (>=1)", "diff-cover (>=7.7)", "distlib (>=0.3.7)", "flaky (>=3.7)", "hatch-vcs (>=0.3)", "hatchling (>=1.18)", "psutil (>=5.9.5)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)", "pytest-xdist (>=3.3.1)", 
"re-assert (>=1.1)", "time-machine (>=2.12)", "wheel (>=0.41.2)"] - [[package]] name = "tqdm" version = "4.66.1" @@ -3129,26 +3070,6 @@ brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] -[[package]] -name = "virtualenv" -version = "20.25.0" -description = "Virtual Python Environment builder" -optional = false -python-versions = ">=3.7" -files = [ - {file = "virtualenv-20.25.0-py3-none-any.whl", hash = "sha256:4238949c5ffe6876362d9c0180fc6c3a824a7b12b80604eeb8085f2ed7460de3"}, - {file = "virtualenv-20.25.0.tar.gz", hash = "sha256:bf51c0d9c7dd63ea8e44086fa1e4fb1093a31e963b86959257378aef020e1f1b"}, -] - -[package.dependencies] -distlib = ">=0.3.7,<1" -filelock = ">=3.12.2,<4" -platformdirs = ">=3.9.1,<5" - -[package.extras] -docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] -test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] - [[package]] name = "wcwidth" version = "0.2.13" @@ -3296,4 +3217,4 @@ local = ["torch", "transformers"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "4ceb0344e006fc7657c66548321ff51d34a297ee6f9ab069fdc53e1256024a12" +content-hash = "5b459c6820bcf5c2b73daf0ecfcbbac95019311c74d88634bd7188650e48b749" diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py index 695e508a25410c6583aa814688f6b64b78ee09b7..57cebca8d426ab585478a15b19449f27ec31d8d2 100644 --- a/semantic_router/encoders/__init__.py +++ b/semantic_router/encoders/__init__.py @@ -5,6 +5,7 @@ from semantic_router.encoders.fastembed import FastEmbedEncoder from semantic_router.encoders.openai import OpenAIEncoder from semantic_router.encoders.zure import AzureOpenAIEncoder from semantic_router.encoders.huggingface import HuggingFaceEncoder +from semantic_router.encoders.tfidf import TfidfEncoder __all__ = [ "BaseEncoder", @@ -12,6 +13,7 @@ __all__ = [ "CohereEncoder", "OpenAIEncoder", "BM25Encoder", + "TfidfEncoder", "FastEmbedEncoder", "HuggingFaceEncoder", ] diff --git a/semantic_router/encoders/base.py b/semantic_router/encoders/base.py index f5968578ead0d01a269876f948e259910a6116fb..edc98641147668705150a0ee1242e77eeeebb431 100644 --- a/semantic_router/encoders/base.py +++ b/semantic_router/encoders/base.py @@ -1,3 +1,4 @@ +from typing import List from pydantic import BaseModel, Field @@ -9,5 +10,5 @@ class BaseEncoder(BaseModel): class Config: arbitrary_types_allowed = True - def __call__(self, docs: list[str]) -> list[list[float]]: + def __call__(self, docs: List[str]) -> List[List[float]]: raise NotImplementedError("Subclasses must implement this method") diff --git a/semantic_router/encoders/bm25.py b/semantic_router/encoders/bm25.py index 451273cdcc78899861c7b64baea4eb4e1cc6b33b..83cbccc06fe453203cd729e6ab2f56c4237a0f74 100644 --- a/semantic_router/encoders/bm25.py +++ b/semantic_router/encoders/bm25.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Optional, List, Dict from semantic_router.encoders import BaseEncoder from semantic_router.utils.logger import logger @@ -6,7 +6,7 @@ from semantic_router.utils.logger import logger class BM25Encoder(BaseEncoder): model: 
Optional[Any] = None - idx_mapping: Optional[dict[int, int]] = None + idx_mapping: Optional[Dict[int, int]] = None type: str = "sparse" def __init__( @@ -40,7 +40,7 @@ class BM25Encoder(BaseEncoder): else: raise TypeError("Expected a dictionary for 'doc_freq'") - def __call__(self, docs: list[str]) -> list[list[float]]: + def __call__(self, docs: List[str]) -> List[List[float]]: if self.model is None or self.idx_mapping is None: raise ValueError("Model or index mapping is not initialized.") if len(docs) == 1: @@ -60,7 +60,7 @@ class BM25Encoder(BaseEncoder): embeds[i][position] = val return embeds - def fit(self, docs: list[str]): + def fit(self, docs: List[str]): if self.model is None: raise ValueError("Model is not initialized.") self.model.fit(docs) diff --git a/semantic_router/encoders/cohere.py b/semantic_router/encoders/cohere.py index ec8ee0f8fcebd39444f3689d7bdffc2b7e98c812..803fe779f82b54460040d5ba57b82aff1bcb1f13 100644 --- a/semantic_router/encoders/cohere.py +++ b/semantic_router/encoders/cohere.py @@ -1,5 +1,5 @@ import os -from typing import Optional +from typing import Optional, List import cohere @@ -27,7 +27,7 @@ class CohereEncoder(BaseEncoder): except Exception as e: raise ValueError(f"Cohere API client failed to initialize. Error: {e}") - def __call__(self, docs: list[str]) -> list[list[float]]: + def __call__(self, docs: List[str]) -> List[List[float]]: if self.client is None: raise ValueError("Cohere client is not initialized.") try: diff --git a/semantic_router/encoders/fastembed.py b/semantic_router/encoders/fastembed.py index 98cfc6cc529ff4c06e0a3e7f0d0af779df7f71dd..ec356317671fc93848e0f3977985cec1a221d827 100644 --- a/semantic_router/encoders/fastembed.py +++ b/semantic_router/encoders/fastembed.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional +from typing import Any, Optional, List import numpy as np from pydantic import PrivateAttr @@ -42,7 +42,7 @@ class FastEmbedEncoder(BaseEncoder): embedding = Embedding(**embedding_args) return embedding - def __call__(self, docs: list[str]) -> list[list[float]]: + def __call__(self, docs: List[str]) -> List[List[float]]: try: embeds: List[np.ndarray] = list(self._client.embed(docs)) embeddings: List[List[float]] = [e.tolist() for e in embeds] diff --git a/semantic_router/encoders/huggingface.py b/semantic_router/encoders/huggingface.py index ace189213b76aed940dd8b4280ce1505339f656f..2166ea13f68cb263d76fabb96b310501d58169fb 100644 --- a/semantic_router/encoders/huggingface.py +++ b/semantic_router/encoders/huggingface.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Optional, List from pydantic import PrivateAttr @@ -60,11 +60,11 @@ class HuggingFaceEncoder(BaseEncoder): def __call__( self, - docs: list[str], + docs: List[str], batch_size: int = 32, normalize_embeddings: bool = True, pooling_strategy: str = "mean", - ) -> list[list[float]]: + ) -> List[List[float]]: all_embeddings = [] for i in range(0, len(docs), batch_size): batch_docs = docs[i : i + batch_size] diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py index 169761afa8f726a72439a534a69bac3ebf73de29..3b06d33de2a4ad01da3ad950feddf15731d332c8 100644 --- a/semantic_router/encoders/openai.py +++ b/semantic_router/encoders/openai.py @@ -1,6 +1,6 @@ import os from time import sleep -from typing import Optional +from typing import Optional, List import openai from openai import OpenAIError @@ -31,7 +31,7 @@ class OpenAIEncoder(BaseEncoder): except Exception as e: raise ValueError(f"OpenAI API 
client failed to initialize. Error: {e}") - def __call__(self, docs: list[str]) -> list[list[float]]: + def __call__(self, docs: List[str]) -> List[List[float]]: if self.client is None: raise ValueError("OpenAI client is not initialized.") embeds = None diff --git a/semantic_router/encoders/tfidf.py b/semantic_router/encoders/tfidf.py new file mode 100644 index 0000000000000000000000000000000000000000..0809b5ad467fabac693aae234eb28cb51b0991a0 --- /dev/null +++ b/semantic_router/encoders/tfidf.py @@ -0,0 +1,80 @@ +import string +from collections import Counter +from typing import Dict + +import numpy as np +from numpy import ndarray +from numpy.linalg import norm + +from semantic_router.encoders import BaseEncoder +from semantic_router.route import Route + + +class TfidfEncoder(BaseEncoder): + idf: ndarray = np.array([]) + word_index: Dict = {} + + def __init__(self, name: str = "tfidf", score_threshold: float = 0.82): + # TODO default score_threshold not thoroughly tested, should optimize + super().__init__(name=name, score_threshold=score_threshold) + self.word_index = {} + self.idf = np.array([]) + + def __call__(self, docs: list[str]) -> list[list[float]]: + if len(self.word_index) == 0 or self.idf.size == 0: + raise ValueError("Vectorizer is not initialized.") + if len(docs) == 0: + raise ValueError("No documents to encode.") + + docs = [self._preprocess(doc) for doc in docs] + tf = self._compute_tf(docs) + tfidf = tf * self.idf + return tfidf.tolist() + + def fit(self, routes: list[Route]): + docs = [] + for route in routes: + for doc in route.utterances: + docs.append(self._preprocess(doc)) + self.word_index = self._build_word_index(docs) + self.idf = self._compute_idf(docs) + + def _build_word_index(self, docs: list[str]) -> dict: + words = set() + for doc in docs: + for word in doc.split(): + words.add(word) + word_index = {word: i for i, word in enumerate(words)} + return word_index + + def _compute_tf(self, docs: list[str]) -> np.ndarray: + if len(self.word_index) == 0: + raise ValueError("Word index is not initialized.") + tf = np.zeros((len(docs), len(self.word_index))) + for i, doc in enumerate(docs): + word_counts = Counter(doc.split()) + for word, count in word_counts.items(): + if word in self.word_index: + tf[i, self.word_index[word]] = count + # L2 normalization + tf = tf / norm(tf, axis=1, keepdims=True) + return tf + + def _compute_idf(self, docs: list[str]) -> np.ndarray: + if len(self.word_index) == 0: + raise ValueError("Word index is not initialized.") + idf = np.zeros(len(self.word_index)) + for doc in docs: + words = set(doc.split()) + for word in words: + if word in self.word_index: + idf[self.word_index[word]] += 1 + idf = np.log(len(docs) / (idf + 1)) + return idf + + def _preprocess(self, doc: str) -> str: + lowercased_doc = doc.lower() + no_punctuation_doc = lowercased_doc.translate( + str.maketrans("", "", string.punctuation) + ) + return no_punctuation_doc diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index d4c81b13c87749f070a2177896ed6c702484fde0..f3eb3e6427b0b7e7f55261cceffb8c5082db3f63 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, List, Dict, Tuple import numpy as np from numpy.linalg import norm @@ -6,6 +6,7 @@ from numpy.linalg import norm from semantic_router.encoders import ( BaseEncoder, BM25Encoder, + TfidfEncoder, ) from semantic_router.route import Route from semantic_router.utils.logger import logger @@ 
-21,7 +22,7 @@ class HybridRouteLayer: self, encoder: BaseEncoder, sparse_encoder: Optional[BM25Encoder] = None, - routes: list[Route] = [], + routes: List[Route] = [], alpha: float = 0.3, ): self.encoder = encoder @@ -34,6 +35,11 @@ class HybridRouteLayer: self.sparse_encoder = sparse_encoder self.alpha = alpha + self.routes = routes + if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr( + self.sparse_encoder, "fit" + ): + self.sparse_encoder.fit(routes) # if routes list has been passed, we initialize index now if routes: # initialize index now @@ -54,41 +60,39 @@ class HybridRouteLayer: self._add_route(route=route) def _add_route(self, route: Route): - # create embeddings - dense_embeds = np.array(self.encoder(route.utterances)) # * self.alpha - sparse_embeds = np.array( - self.sparse_encoder(route.utterances) - ) # * (1 - self.alpha) + self.routes += [route] + + self.update_dense_embeddings_index(route.utterances) + + if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr( + self.sparse_encoder, "fit" + ): + self.sparse_encoder.fit(self.routes) + # re-build index + self.sparse_index = None + all_utterances = [ + utterance for route in self.routes for utterance in route.utterances + ] + self.update_sparse_embeddings_index(all_utterances) + else: + self.update_sparse_embeddings_index(route.utterances) # create route array if self.categories is None: self.categories = np.array([route.name] * len(route.utterances)) - self.utterances = np.array(route.utterances) else: str_arr = np.array([route.name] * len(route.utterances)) self.categories = np.concatenate([self.categories, str_arr]) - self.utterances = np.concatenate( - [self.utterances, np.array(route.utterances)] - ) - # create utterance array (the dense index) - if self.index is None: - self.index = dense_embeds - else: - self.index = np.concatenate([self.index, dense_embeds]) - # create sparse utterance array - if self.sparse_index is None: - self.sparse_index = sparse_embeds - else: - self.sparse_index = np.concatenate([self.sparse_index, sparse_embeds]) + self.routes.append(route) - def _add_routes(self, routes: list[Route]): + def _add_routes(self, routes: List[Route]): # create embeddings for all routes logger.info("Creating embeddings for all routes...") all_utterances = [ utterance for route in routes for utterance in route.utterances ] - dense_embeds = np.array(self.encoder(all_utterances)) - sparse_embeds = np.array(self.sparse_encoder(all_utterances)) + self.update_dense_embeddings_index(all_utterances) + self.update_sparse_embeddings_index(all_utterances) # create route array route_names = [route.name for route in routes for _ in route.utterances] @@ -99,6 +103,8 @@ class HybridRouteLayer: else route_array ) + def update_dense_embeddings_index(self, utterances: list): + dense_embeds = np.array(self.encoder(utterances)) # create utterance array (the dense index) self.index = ( np.concatenate([self.index, dense_embeds]) @@ -106,6 +112,8 @@ class HybridRouteLayer: else dense_embeds ) + def update_sparse_embeddings_index(self, utterances: list): + sparse_embeds = np.array(self.sparse_encoder(utterances)) # create sparse utterance array self.sparse_index = ( np.concatenate([self.sparse_index, sparse_embeds]) @@ -153,8 +161,8 @@ class HybridRouteLayer: sparse = np.array(sparse) * (1 - self.alpha) return dense, sparse - def _semantic_classify(self, query_results: list[dict]) -> tuple[str, list[float]]: - scores_by_class: dict[str, list[float]] = {} + def _semantic_classify(self, query_results: List[Dict]) -> Tuple[str, 
List[float]]: + scores_by_class: Dict[str, List[float]] = {} for result in query_results: score = result["score"] route = result["route"] @@ -174,7 +182,7 @@ class HybridRouteLayer: logger.warning("No classification found for semantic classifier.") return "", [] - def _pass_threshold(self, scores: list[float], threshold: float) -> bool: + def _pass_threshold(self, scores: List[float], threshold: float) -> bool: if scores: return max(scores) > threshold else: diff --git a/semantic_router/layer.py b/semantic_router/layer.py index cf546bfc1d791c6843bd62ce5ad4f178b9f0254b..bce160ba7853e5f70d30ea9a223ee3d44630c40f 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -1,6 +1,6 @@ import json import os -from typing import Optional +from typing import Optional, Any, List, Dict, Tuple import numpy as np import yaml @@ -14,6 +14,7 @@ from semantic_router.utils.logger import logger def is_valid(layer_config: str) -> bool: + """Make sure the given string is json format and contains the 3 keys: ["encoder_name", "encoder_type", "routes"]""" try: output_json = json.loads(layer_config) required_keys = ["encoder_name", "encoder_type", "routes"] @@ -47,11 +48,11 @@ class LayerConfig: RouteLayer. """ - routes: list[Route] = [] + routes: List[Route] = [] def __init__( self, - routes: list[Route] = [], + routes: List[Route] = [], encoder_type: str = "openai", encoder_name: Optional[str] = None, ): @@ -73,7 +74,7 @@ class LayerConfig: self.routes = routes @classmethod - def from_file(cls, path: str): + def from_file(cls, path: str) -> "LayerConfig": """Load the routes from a file in JSON or YAML format""" logger.info(f"Loading route config from {path}") _, ext = os.path.splitext(path) @@ -98,7 +99,7 @@ class LayerConfig: else: raise Exception("Invalid config JSON or YAML") - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "encoder_type": self.encoder_type, "encoder_name": self.encoder_name, @@ -157,7 +158,7 @@ class RouteLayer: self, encoder: Optional[BaseEncoder] = None, llm: Optional[BaseLLM] = None, - routes: Optional[list[Route]] = None, + routes: Optional[List[Route]] = None, top_k_routes: int = 3, ): logger.info("Initializing RouteLayer") @@ -246,7 +247,7 @@ class RouteLayer: # add route to routes list self.routes.append(route) - def _add_routes(self, routes: list[Route]): + def _add_routes(self, routes: List[Route]): # create embeddings for all routes all_utterances = [ utterance for route in routes for utterance in route.utterances @@ -289,8 +290,8 @@ class RouteLayer: logger.warning("No index found for route layer.") return [] - def _semantic_classify(self, query_results: list[dict]) -> tuple[str, list[float]]: - scores_by_class: dict[str, list[float]] = {} + def _semantic_classify(self, query_results: List[dict]) -> Tuple[str, List[float]]: + scores_by_class: Dict[str, List[float]] = {} for result in query_results: score = result["score"] route = result["route"] @@ -310,7 +311,7 @@ class RouteLayer: logger.warning("No classification found for semantic classifier.") return "", [] - def _pass_threshold(self, scores: list[float], threshold: float) -> bool: + def _pass_threshold(self, scores: List[float], threshold: float) -> bool: if scores: return max(scores) > threshold else: diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py index bf5f29b6005daaa76abc4674971dc8f775f4af80..12d89f2d31e1cd181346322daf01d0b206222a20 100644 --- a/semantic_router/llms/base.py +++ b/semantic_router/llms/base.py @@ -1,4 +1,4 @@ -from typing import Optional +from 
typing import Optional, List from pydantic import BaseModel @@ -11,5 +11,5 @@ class BaseLLM(BaseModel): class Config: arbitrary_types_allowed = True - def __call__(self, messages: list[Message]) -> Optional[str]: + def __call__(self, messages: List[Message]) -> Optional[str]: raise NotImplementedError("Subclasses must implement this method") diff --git a/semantic_router/llms/cohere.py b/semantic_router/llms/cohere.py index 0ec21f354c090f0d1d00da7c20b89c5b233c3a89..0eebbe6d6e8385e66ed9df42b941a915fa144e22 100644 --- a/semantic_router/llms/cohere.py +++ b/semantic_router/llms/cohere.py @@ -1,5 +1,5 @@ import os -from typing import Optional +from typing import Optional, List import cohere @@ -26,7 +26,7 @@ class CohereLLM(BaseLLM): except Exception as e: raise ValueError(f"Cohere API client failed to initialize. Error: {e}") - def __call__(self, messages: list[Message]) -> str: + def __call__(self, messages: List[Message]) -> str: if self.client is None: raise ValueError("Cohere client is not initialized.") try: diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py index 8b3442c742c2f268773bf551fb49ce0cd24645af..06d6865ca1ec095d04453aaf8deb7c8e8d5ef54e 100644 --- a/semantic_router/llms/openai.py +++ b/semantic_router/llms/openai.py @@ -1,5 +1,5 @@ import os -from typing import Optional +from typing import Optional, List import openai @@ -33,7 +33,7 @@ class OpenAILLM(BaseLLM): self.temperature = temperature self.max_tokens = max_tokens - def __call__(self, messages: list[Message]) -> str: + def __call__(self, messages: List[Message]) -> str: if self.client is None: raise ValueError("OpenAI client is not initialized.") try: diff --git a/semantic_router/llms/openrouter.py b/semantic_router/llms/openrouter.py index 4cc15d6bfedbfa67fb5957129d1ce901544dcb38..8c3efb8d1f67fc246f62116555368eafa1f36288 100644 --- a/semantic_router/llms/openrouter.py +++ b/semantic_router/llms/openrouter.py @@ -1,5 +1,5 @@ import os -from typing import Optional +from typing import Optional, List import openai @@ -38,7 +38,7 @@ class OpenRouterLLM(BaseLLM): self.temperature = temperature self.max_tokens = max_tokens - def __call__(self, messages: list[Message]) -> str: + def __call__(self, messages: List[Message]) -> str: if self.client is None: raise ValueError("OpenRouter client is not initialized.") try: diff --git a/semantic_router/route.py b/semantic_router/route.py index 2289825071a39652aeaa467395103d71dc623140..bf24b14c13ca2b43d087d6574af1e9fc2fe14326 100644 --- a/semantic_router/route.py +++ b/semantic_router/route.py @@ -1,6 +1,6 @@ import json import re -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Union, List, Dict from pydantic import BaseModel @@ -40,9 +40,9 @@ def is_valid(route_config: str) -> bool: class Route(BaseModel): name: str - utterances: list[str] + utterances: List[str] description: Optional[str] = None - function_schema: Optional[dict[str, Any]] = None + function_schema: Optional[Dict[str, Any]] = None llm: Optional[BaseLLM] = None def __call__(self, query: str) -> RouteChoice: @@ -62,11 +62,11 @@ class Route(BaseModel): func_call = None return RouteChoice(name=self.name, function_call=func_call) - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return self.dict() @classmethod - def from_dict(cls, data: dict[str, Any]): + def from_dict(cls, data: Dict[str, Any]): return cls(**data) @classmethod @@ -92,7 +92,7 @@ class Route(BaseModel): raise ValueError("No <config></config> tags found in the output.") 
@classmethod - def _generate_dynamic_route(cls, llm: BaseLLM, function_schema: dict[str, Any]): + def _generate_dynamic_route(cls, llm: BaseLLM, function_schema: Dict[str, Any]): logger.info("Generating dynamic route...") prompt = f""" diff --git a/semantic_router/schema.py b/semantic_router/schema.py index bb1a4c6ac26819ef7b13bec1f293fe6e51a66a7a..7dcb7fde1252088ab7736510c7f04fa32c3a6f6d 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Optional +from typing import Optional, Literal, List, Dict from pydantic import BaseModel from pydantic.dataclasses import dataclass @@ -47,7 +47,7 @@ class Encoder: else: raise ValueError - def __call__(self, texts: list[str]) -> list[list[float]]: + def __call__(self, texts: List[str]) -> List[List[float]]: return self.model(texts) @@ -65,14 +65,16 @@ class Message(BaseModel): class Conversation(BaseModel): - messages: list[Message] + messages: List[Message] def split_by_topic( self, encoder: BaseEncoder, threshold: float = 0.5, - split_method: str = "consecutive_similarity_drop", - ): + split_method: Literal[ + "consecutive_similarity_drop", "cumulative_similarity_drop" + ] = "consecutive_similarity_drop", + ) -> Dict[str, List[str]]: docs = [f"{m.role}: {m.content}" for m in self.messages] return semantic_splitter( encoder=encoder, docs=docs, threshold=threshold, split_method=split_method diff --git a/semantic_router/utils/function_call.py b/semantic_router/utils/function_call.py index 3c8b3277b3b4cdd8d3c2ebc1849ff3da4cbd1ca7..1b42a6133a3faeb0a001f646b45d2a842b4e7d4e 100644 --- a/semantic_router/utils/function_call.py +++ b/semantic_router/utils/function_call.py @@ -1,6 +1,6 @@ import inspect import json -from typing import Any, Callable, Union +from typing import Any, Callable, Union, Dict, List from pydantic import BaseModel @@ -9,7 +9,7 @@ from semantic_router.schema import Message, RouteChoice from semantic_router.utils.logger import logger -def get_schema(item: Union[BaseModel, Callable]) -> dict[str, Any]: +def get_schema(item: Union[BaseModel, Callable]) -> Dict[str, Any]: if isinstance(item, BaseModel): signature_parts = [] for field_name, field_model in item.__annotations__.items(): @@ -42,8 +42,8 @@ def get_schema(item: Union[BaseModel, Callable]) -> dict[str, Any]: def extract_function_inputs( - query: str, llm: BaseLLM, function_schema: dict[str, Any] -) -> dict: + query: str, llm: BaseLLM, function_schema: Dict[str, Any] +) -> Dict[str, Any]: logger.info("Extracting function input...") prompt = f""" @@ -90,7 +90,7 @@ Result: return function_inputs -def is_valid_inputs(inputs: dict[str, Any], function_schema: dict[str, Any]) -> bool: +def is_valid_inputs(inputs: Dict[str, Any], function_schema: Dict[str, Any]) -> bool: """Validate the extracted inputs against the function schema""" try: # Extract parameter names and types from the signature string @@ -113,7 +113,7 @@ def is_valid_inputs(inputs: dict[str, Any], function_schema: dict[str, Any]) -> # TODO: Add route layer object to the input, solve circular import issue async def route_and_execute( - query: str, llm: BaseLLM, functions: list[Callable], layer + query: str, llm: BaseLLM, functions: List[Callable], layer ) -> Any: route_choice: RouteChoice = layer(query) diff --git a/semantic_router/utils/logger.py b/semantic_router/utils/logger.py index 00c83693435487016f819c4716900fc09f8b8b92..607f09d512a08b9d52afeaf8e9ebe73883870f35 100644 --- a/semantic_router/utils/logger.py +++ b/semantic_router/utils/logger.py 
@@ -40,4 +40,4 @@ def setup_custom_logger(name): return logger -logger = setup_custom_logger(__name__) +logger: logging.Logger = setup_custom_logger(__name__) diff --git a/semantic_router/utils/splitters.py b/semantic_router/utils/splitters.py index 746015204d702690a0eff289eaa1537c42658f23..83a32839c5efc3b528f9a14643c3f3db3571f3e3 100644 --- a/semantic_router/utils/splitters.py +++ b/semantic_router/utils/splitters.py @@ -1,14 +1,17 @@ import numpy as np +from typing import List, Dict, Literal from semantic_router.encoders import BaseEncoder def semantic_splitter( encoder: BaseEncoder, - docs: list[str], + docs: List[str], threshold: float, - split_method: str = "consecutive_similarity_drop", -) -> dict[str, list[str]]: + split_method: Literal[ + "consecutive_similarity_drop", "cumulative_similarity_drop" + ] = "consecutive_similarity_drop", +) -> Dict[str, List[str]]: """ Splits a list of documents base on semantic similarity changes. @@ -20,13 +23,13 @@ def semantic_splitter( Args: encoder (BaseEncoder): Encoder for document embeddings. - docs (list[str]): Documents to split. + docs (List[str]): Documents to split. threshold (float): The similarity drop value that will trigger a new document split. split_method (str): The method to use for splitting. Returns: - Dict[str, list[str]]: Splits with corresponding documents. + Dict[str, List[str]]: Splits with corresponding documents. """ total_docs = len(docs) splits = {} diff --git a/tests/unit/encoders/test_bm25.py b/tests/unit/encoders/test_bm25.py index 174453d254370c609c83a7f2a533b4d03ca264dc..73e52d5585a3391488d7198d64543a0b2f0dcc5c 100644 --- a/tests/unit/encoders/test_bm25.py +++ b/tests/unit/encoders/test_bm25.py @@ -27,7 +27,7 @@ class TestBM25Encoder: isinstance(sublist, list) for sublist in result ), "Each item in result should be a list" - def test_call_method_no_docs(self, bm25_encoder): + def test_call_method_no_docs_bm25_encoder(self, bm25_encoder): with pytest.raises(ValueError): bm25_encoder([]) diff --git a/tests/unit/encoders/test_tfidf.py b/tests/unit/encoders/test_tfidf.py new file mode 100644 index 0000000000000000000000000000000000000000..7664433d070d74d06c00b7e0fb0336fa6fb7bc6e --- /dev/null +++ b/tests/unit/encoders/test_tfidf.py @@ -0,0 +1,82 @@ +import numpy as np +import pytest + +from semantic_router.encoders import TfidfEncoder +from semantic_router.route import Route + + +@pytest.fixture +def tfidf_encoder(): + return TfidfEncoder() + + +class TestTfidfEncoder: + def test_initialization(self, tfidf_encoder): + assert tfidf_encoder.word_index == {} + assert (tfidf_encoder.idf == np.array([])).all() + + def test_fit(self, tfidf_encoder): + routes = [ + Route( + name="test_route", + utterances=["some docs", "and more docs", "and even more docs"], + ) + ] + tfidf_encoder.fit(routes) + assert tfidf_encoder.word_index != {} + assert not np.array_equal(tfidf_encoder.idf, np.array([])) + + def test_call_method(self, tfidf_encoder): + routes = [ + Route( + name="test_route", + utterances=["some docs", "and more docs", "and even more docs"], + ) + ] + tfidf_encoder.fit(routes) + result = tfidf_encoder(["test"]) + assert isinstance(result, list), "Result should be a list" + assert all( + isinstance(sublist, list) for sublist in result + ), "Each item in result should be a list" + + def test_call_method_no_docs_tfidf_encoder(self, tfidf_encoder): + with pytest.raises(ValueError): + tfidf_encoder([]) + + def test_call_method_no_word(self, tfidf_encoder): + routes = [ + Route( + name="test_route", + utterances=["some docs", 
"and more docs", "and even more docs"], + ) + ] + tfidf_encoder.fit(routes) + result = tfidf_encoder(["doc with fake word gta5jabcxyz"]) + assert isinstance(result, list), "Result should be a list" + assert all( + isinstance(sublist, list) for sublist in result + ), "Each item in result should be a list" + + def test_call_method_with_uninitialized_model(self, tfidf_encoder): + with pytest.raises(ValueError): + tfidf_encoder(["test"]) + + def test_compute_tf_no_word_index(self, tfidf_encoder): + with pytest.raises(ValueError, match="Word index is not initialized."): + tfidf_encoder._compute_tf(["some docs"]) + + def test_compute_tf_with_word_in_word_index(self, tfidf_encoder): + routes = [ + Route( + name="test_route", + utterances=["some docs", "and more docs", "and even more docs"], + ) + ] + tfidf_encoder.fit(routes) + tf = tfidf_encoder._compute_tf(["some docs"]) + assert tf.shape == (1, len(tfidf_encoder.word_index)) + + def test_compute_idf_no_word_index(self, tfidf_encoder): + with pytest.raises(ValueError, match="Word index is not initialized."): + tfidf_encoder._compute_idf(["some docs"]) diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py index 50cae41505533b12727b7164b2088b3271bfc37d..aa7a33c0f631458f2583483669b6fd7041f9d845 100644 --- a/tests/unit/test_hybrid_layer.py +++ b/tests/unit/test_hybrid_layer.py @@ -6,6 +6,7 @@ from semantic_router.encoders import ( BM25Encoder, CohereEncoder, OpenAIEncoder, + TfidfEncoder, ) from semantic_router.hybrid_layer import HybridRouteLayer from semantic_router.route import Route @@ -24,8 +25,10 @@ def mock_encoder_call(utterances): @pytest.fixture -def base_encoder(): - return BaseEncoder(name="test-encoder", score_threshold=0.5) +def base_encoder(mocker): + mock_base_encoder = BaseEncoder(name="test-encoder", score_threshold=0.5) + mocker.patch.object(BaseEncoder, "__call__", return_value=[[0.1, 0.2, 0.3]]) + return mock_base_encoder @pytest.fixture @@ -41,6 +44,7 @@ def openai_encoder(mocker): @pytest.fixture + def azure_encoder(mocker): mocker.patch.object(AzureOpenAIEncoder, "__call__", side_effect=mock_encoder_call) return AzureOpenAIEncoder( @@ -51,6 +55,16 @@ def azure_encoder(mocker): model="test_model", ) +def bm25_encoder(mocker): + mocker.patch.object(BM25Encoder, "__call__", side_effect=mock_encoder_call) + return BM25Encoder(name="test-bm25-encoder") + +@pytest.fixture +def tfidf_encoder(mocker): + mocker.patch.object(TfidfEncoder, "__call__", side_effect=mock_encoder_call) + return TfidfEncoder(name="test-tfidf-encoder") + + @pytest.fixture def routes(): @@ -160,5 +174,18 @@ class TestHybridRouteLayer: assert base_encoder.score_threshold == 0.50 assert route_layer.score_threshold == 0.50 + def test_add_route_tfidf(self, cohere_encoder, tfidf_encoder, routes): + hybrid_route_layer = HybridRouteLayer( + encoder=cohere_encoder, + sparse_encoder=tfidf_encoder, + routes=routes[:-1], + ) + hybrid_route_layer.add(routes[-1]) + all_utterances = [ + utterance for route in routes for utterance in route.utterances + ] + assert hybrid_route_layer.sparse_index is not None + assert len(hybrid_route_layer.sparse_index) == len(all_utterances) + # Add more tests for edge cases and error handling as needed.