diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5ab1b013a1f7c0a000a50afb7ee923dc8ffcc8b6..cf22ea9b75332a14d994c57a58f946dc7b53b807 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -35,7 +35,7 @@ jobs: cache: poetry - name: Install dependencies run: | - poetry install + poetry install --all-extras - name: Install nltk run: | pip install nltk diff --git a/docs/00-introduction.ipynb b/docs/00-introduction.ipynb index 95222c2a7eedfa8807ef70adb902c9d82c81c457..ae6b768d54baa59a6d6117c41fe3df736ca3584e 100644 --- a/docs/00-introduction.ipynb +++ b/docs/00-introduction.ipynb @@ -157,7 +157,7 @@ "source": [ "from semantic_router.layer import RouteLayer\n", "\n", - "dl = RouteLayer(encoder=encoder, routes=routes)" + "rl = RouteLayer(encoder=encoder, routes=routes)" ] }, { @@ -184,7 +184,7 @@ } ], "source": [ - "dl(\"don't you love politics?\")" + "rl(\"don't you love politics?\")" ] }, { @@ -204,7 +204,7 @@ } ], "source": [ - "dl(\"how's the weather today?\")" + "rl(\"how's the weather today?\")" ] }, { @@ -231,7 +231,7 @@ } ], "source": [ - "dl(\"I'm interested in learning about llama 2\")" + "rl(\"I'm interested in learning about llama 2\")" ] }, { diff --git a/docs/01-save-load-from-file.ipynb b/docs/01-save-load-from-file.ipynb index 6f084a9aedfb44e4a1c77c518584a664e60b86f4..715679ceedcb95b5b34b6812382e57c114a612ea 100644 --- a/docs/01-save-load-from-file.ipynb +++ b/docs/01-save-load-from-file.ipynb @@ -132,7 +132,7 @@ " \"Enter Cohere API Key: \"\n", ")\n", "\n", - "layer = RouteLayer(routes=routes)" + "rl = RouteLayer(routes=routes)" ] }, { @@ -156,7 +156,7 @@ } ], "source": [ - "layer.to_json(\"layer.json\")" + "rl.to_json(\"layer.json\")" ] }, { @@ -190,9 +190,9 @@ "import json\n", "\n", "with open(\"layer.json\", \"r\") as f:\n", - " router_json = json.load(f)\n", + " layer_json = json.load(f)\n", "\n", - "print(router_json)" + "print(layer_json)" ] }, { @@ -217,7 +217,7 @@ } ], "source": [ - "layer = RouteLayer.from_json(\"layer.json\")" + "rl = RouteLayer.from_json(\"layer.json\")" ] }, { @@ -244,9 +244,9 @@ ], "source": [ "print(\n", - " f\"\"\"{layer.encoder.type=}\n", - "{layer.encoder.name=}\n", - "{layer.routes=}\"\"\"\n", + " f\"\"\"{rl.encoder.type=}\n", + "{rl.encoder.name=}\n", + "{rl.routes=}\"\"\"\n", ")" ] }, diff --git a/docs/02-dynamic-routes.ipynb b/docs/02-dynamic-routes.ipynb index 2b17da17cbee79d186134144d93563bfce08da82..d8078cb203837c62b3a307b76b6378c47827c0f7 100644 --- a/docs/02-dynamic-routes.ipynb +++ b/docs/02-dynamic-routes.ipynb @@ -125,7 +125,7 @@ " \"Enter Cohere API Key: \"\n", ")\n", "\n", - "layer = RouteLayer(routes=routes)" + "rl = RouteLayer(routes=routes)" ] }, { @@ -152,7 +152,7 @@ } ], "source": [ - "layer(\"how's the weather today?\")" + "rl(\"how's the weather today?\")" ] }, { @@ -291,7 +291,7 @@ } ], "source": [ - "layer.add(time_route)" + "rl.add(time_route)" ] }, { @@ -330,7 +330,7 @@ " \"Enter OpenRouter API Key: \"\n", ")\n", "\n", - "layer(\"what is the time in new york city?\")" + "rl(\"what is the time in new york city?\")" ] }, { diff --git a/docs/03-basic-langchain-agent.ipynb b/docs/03-basic-langchain-agent.ipynb index 09294c780e9f7ad60e20ae11f960e37954c1f98a..3bfd3ba52bb97ed97084b0ea58cb5d98a5dde08d 100644 --- a/docs/03-basic-langchain-agent.ipynb +++ b/docs/03-basic-langchain-agent.ipynb @@ -223,7 +223,7 @@ "from semantic_router import RouteLayer\n", "from semantic_router.encoders import OpenAIEncoder\n", "\n", - "layer = RouteLayer(encoder=OpenAIEncoder(), routes=routes)" + "rl = RouteLayer(encoder=OpenAIEncoder(), routes=routes)" ] }, { @@ -258,7 +258,7 @@ } ], "source": [ - "layer(\"should I buy ON whey or MP?\")" + "rl(\"should I buy ON whey or MP?\")" ] }, { @@ -284,7 +284,7 @@ } ], "source": [ - "layer(\"how's the weather today?\")" + "rl(\"how's the weather today?\")" ] }, { @@ -310,7 +310,7 @@ } ], "source": [ - "layer(\"how do I get big arms?\")" + "rl(\"how do I get big arms?\")" ] }, { @@ -382,7 +382,7 @@ "outputs": [], "source": [ "def semantic_layer(query: str):\n", - " route = layer(query)\n", + " route = rl(query)\n", " if route.name == \"get_time\":\n", " query += f\" (SYSTEM NOTE: {get_time()})\"\n", " elif route.name == \"supplement_brand\":\n", diff --git a/poetry.lock b/poetry.lock index effd3033950ff6543f67f1d9ff5517e811f28d79..815226ea246d9037bd0e0a3f6460b09ca49959a7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -472,7 +472,7 @@ files = [ name = "coloredlogs" version = "15.0.1" description = "Colored terminal output for Python's logging module" -optional = false +optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ {file = "coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934"}, @@ -727,7 +727,7 @@ zstandard = ["zstandard"] name = "fastembed" version = "0.1.3" description = "Fast, light, accurate library built for retrieval embedding generation" -optional = false +optional = true python-versions = ">=3.8.0,<3.12" files = [ {file = "fastembed-0.1.3-py3-none-any.whl", hash = "sha256:98b6c6d9effec8c96d97048e59cdd53627b16a70fcdbfa7c663772de66e11b3a"}, @@ -746,7 +746,7 @@ tqdm = ">=4.65,<5.0" name = "filelock" version = "3.13.1" description = "A platform independent file lock." -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"}, @@ -762,7 +762,7 @@ typing = ["typing-extensions (>=4.8)"] name = "flatbuffers" version = "23.5.26" description = "The FlatBuffers serialization format for Python" -optional = false +optional = true python-versions = "*" files = [ {file = "flatbuffers-23.5.26-py2.py3-none-any.whl", hash = "sha256:c0ff356da363087b915fde4b8b45bdda73432fc17cddb3c8157472eab1422ad1"}, @@ -859,7 +859,7 @@ files = [ name = "fsspec" version = "2023.12.2" description = "File-system specification" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "fsspec-2023.12.2-py3-none-any.whl", hash = "sha256:d800d87f72189a745fa3d6b033b9dc4a34ad069f60ca60b943a63599f5501960"}, @@ -950,7 +950,7 @@ socks = ["socksio (==1.*)"] name = "huggingface-hub" version = "0.19.4" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" -optional = false +optional = true python-versions = ">=3.8.0" files = [ {file = "huggingface_hub-0.19.4-py3-none-any.whl", hash = "sha256:dba013f779da16f14b606492828f3760600a1e1801432d09fe1c33e50b825bb5"}, @@ -983,7 +983,7 @@ typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "t name = "humanfriendly" version = "10.0" description = "Human friendly output for text interfaces using Python" -optional = false +optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ {file = "humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477"}, @@ -1127,7 +1127,7 @@ testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] name = "joblib" version = "1.3.2" description = "Lightweight pipelining with Python functions" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "joblib-1.3.2-py3-none-any.whl", hash = "sha256:ef4331c65f239985f3f2220ecc87db222f08fd22097a3dd5698f693875f8cbb9"}, @@ -1195,7 +1195,7 @@ traitlets = "*" name = "mmh3" version = "3.1.0" description = "Python wrapper for MurmurHash (MurmurHash3), a set of fast and robust hash functions." -optional = false +optional = true python-versions = "*" files = [ {file = "mmh3-3.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:16ee043b1bac040b4324b8baee39df9fdca480a560a6d74f2eef66a5009a234e"}, @@ -1239,7 +1239,7 @@ files = [ name = "mpmath" version = "1.3.0" description = "Python library for arbitrary-precision floating-point arithmetic" -optional = false +optional = true python-versions = "*" files = [ {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, @@ -1408,7 +1408,7 @@ files = [ name = "nltk" version = "3.8.1" description = "Natural Language Toolkit" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "nltk-3.8.1-py3-none-any.whl", hash = "sha256:fd5c9109f976fa86bcadba8f91e47f5e9293bd034474752e92a520f81c93dda5"}, @@ -1467,7 +1467,7 @@ files = [ name = "onnx" version = "1.15.0" description = "Open Neural Network Exchange" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "onnx-1.15.0-cp310-cp310-macosx_10_12_universal2.whl", hash = "sha256:51cacb6aafba308aaf462252ced562111f6991cdc7bc57a6c554c3519453a8ff"}, @@ -1508,7 +1508,7 @@ reference = ["Pillow", "google-re2"] name = "onnxruntime" version = "1.16.3" description = "ONNX Runtime is a runtime accelerator for Machine Learning models" -optional = false +optional = true python-versions = "*" files = [ {file = "onnxruntime-1.16.3-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:3bc41f323ac77acfed190be8ffdc47a6a75e4beeb3473fbf55eeb075ccca8df2"}, @@ -1623,7 +1623,7 @@ ptyprocess = ">=0.5" name = "pinecone-text" version = "0.7.1" description = "Text utilities library by Pinecone.io" -optional = false +optional = true python-versions = ">=3.8,<4.0" files = [ {file = "pinecone_text-0.7.1-py3-none-any.whl", hash = "sha256:b806b5d66190d09888ed2d3bcdef49534aa9200b9da521371a062e6ccc79bb2c"}, @@ -1690,7 +1690,7 @@ wcwidth = "*" name = "protobuf" version = "4.25.1" description = "" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "protobuf-4.25.1-cp310-abi3-win32.whl", hash = "sha256:193f50a6ab78a970c9b4f148e7c750cfde64f59815e86f686c22e26b4fe01ce7"}, @@ -1841,7 +1841,7 @@ windows-terminal = ["colorama (>=0.4.6)"] name = "pyreadline3" version = "3.4.1" description = "A python implementation of GNU readline." -optional = false +optional = true python-versions = "*" files = [ {file = "pyreadline3-3.4.1-py3-none-any.whl", hash = "sha256:b0efb6516fd4fb07b45949053826a62fa4cb353db5be2bbb4a7aa1fdd1e345fb"}, @@ -1870,24 +1870,6 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] -[[package]] -name = "pytest-asyncio" -version = "0.23.3" -description = "Pytest support for asyncio" -optional = false -python-versions = ">=3.8" -files = [ - {file = "pytest-asyncio-0.23.3.tar.gz", hash = "sha256:af313ce900a62fbe2b1aed18e37ad757f1ef9940c6b6a88e2954de38d6b1fb9f"}, - {file = "pytest_asyncio-0.23.3-py3-none-any.whl", hash = "sha256:37a9d912e8338ee7b4a3e917381d1c95bfc8682048cb0fbc35baba316ec1faba"}, -] - -[package.dependencies] -pytest = ">=7.0.0" - -[package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] -testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] - [[package]] name = "pytest-cov" version = "4.1.0" @@ -1992,6 +1974,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -1999,8 +1982,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -2017,6 +2007,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -2024,6 +2015,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -2138,7 +2130,7 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""} name = "regex" version = "2023.12.25" description = "Alternative regular expression module, to replace re." -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "regex-2023.12.25-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0694219a1d54336fd0445ea382d49d36882415c0134ee1e8332afd1529f0baa5"}, @@ -2328,7 +2320,7 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] name = "sympy" version = "1.12" description = "Computer algebra system (CAS) in Python" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"}, @@ -2353,7 +2345,7 @@ files = [ name = "tokenizers" version = "0.15.0" description = "" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "tokenizers-0.15.0-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:cd3cd0299aaa312cd2988957598f80becd04d5a07338741eca076057a2b37d6e"}, @@ -2570,20 +2562,20 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "wcwidth" -version = "0.2.12" +version = "0.2.13" description = "Measures the displayed width of unicode strings in a terminal" optional = false python-versions = "*" files = [ - {file = "wcwidth-0.2.12-py2.py3-none-any.whl", hash = "sha256:f26ec43d96c8cbfed76a5075dac87680124fa84e0855195a6184da9c187f133c"}, - {file = "wcwidth-0.2.12.tar.gz", hash = "sha256:f01c104efdf57971bcb756f054dd58ddec5204dd15fa31d6503ea57947d97c02"}, + {file = "wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859"}, + {file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"}, ] [[package]] name = "wget" version = "3.2" description = "pure python download utility" -optional = false +optional = true python-versions = "*" files = [ {file = "wget-3.2.zip", hash = "sha256:35e630eca2aa50ce998b9b1a127bb26b30dfee573702782aa982f875e3f16061"}, @@ -2707,7 +2699,11 @@ files = [ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"] +[extras] +fastembed = ["fastembed"] +hybrid = ["pinecone-text"] + [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "887cf3e564b33d43b6bdcf5d089d9eb12931312e8eeadd3f8488cb5fbe384fab" +content-hash = "64f0fef330108fe47110c203bf96403e8d986f8b751f6eed1abfec3ce57539a6" diff --git a/pyproject.toml b/pyproject.toml index 9ce7f26d2d47c44904a5710604a385c5a4549936..d3561c644e7a7acab7a08ef699132d8cea615200 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,12 +18,14 @@ pydantic = "^1.8.2" openai = "^1.3.9" cohere = "^4.32" numpy = "^1.25.2" -pinecone-text = "^0.7.0" +pinecone-text = {version = "^0.7.0", optional = true} colorlog = "^6.8.0" pyyaml = "^6.0.1" -pytest-asyncio = "^0.23.2" -fastembed = "^0.1.3" +fastembed = {version = "^0.1.3", optional = true} +[tool.poetry.extras] +hybrid = ["pinecone-text"] +fastembed = ["fastembed"] [tool.poetry.group.dev.dependencies] ipykernel = "^6.26.0" diff --git a/semantic_router/encoders/bm25.py b/semantic_router/encoders/bm25.py index f43e1780cace53529bc7a0b5b5f9eb15a98fd9da..e597b4fee78ecbc92d2726e88e83536347d36c8e 100644 --- a/semantic_router/encoders/bm25.py +++ b/semantic_router/encoders/bm25.py @@ -1,7 +1,5 @@ from typing import Any -from pinecone_text.sparse import BM25Encoder as encoder - from semantic_router.encoders import BaseEncoder @@ -12,6 +10,13 @@ class BM25Encoder(BaseEncoder): def __init__(self, name: str = "bm25"): super().__init__(name=name) + try: + from pinecone_text.sparse import BM25Encoder as encoder + except ImportError: + raise ImportError( + "Please install pinecone-text to use BM25Encoder. " + "You can install it with: `pip install semantic-router[hybrid]`" + ) self.model = encoder.default() params = self.model.get_params() diff --git a/semantic_router/encoders/fastembed.py b/semantic_router/encoders/fastembed.py index 4bb46b85836a8b3beb9988068b9264c2d2c70d04..fb845ce7278470f790b765d19cba0c318e04c6fb 100644 --- a/semantic_router/encoders/fastembed.py +++ b/semantic_router/encoders/fastembed.py @@ -1,12 +1,14 @@ from typing import Any, List, Optional import numpy as np -from pydantic import BaseModel, PrivateAttr +from pydantic import PrivateAttr +from semantic_router.encoders import BaseEncoder -class FastEmbedEncoder(BaseModel): + +class FastEmbedEncoder(BaseEncoder): type: str = "fastembed" - model_name: str = "BAAI/bge-small-en-v1.5" + name: str = "BAAI/bge-small-en-v1.5" max_length: int = 512 cache_dir: Optional[str] = None threads: Optional[int] = None @@ -21,12 +23,13 @@ class FastEmbedEncoder(BaseModel): from fastembed.embedding import FlagEmbedding as Embedding except ImportError: raise ImportError( - "Please install fastembed to use FastEmbedEncoder" - "You can install it with: `pip install fastembed`" + "Please install fastembed to use FastEmbedEncoder. " + "You can install it with: " + "`pip install semantic-router[fastembed]`" ) embedding_args = { - "model_name": self.model_name, + "model_name": self.name, "max_length": self.max_length, "cache_dir": self.cache_dir, "threads": self.threads, diff --git a/semantic_router/layer.py b/semantic_router/layer.py index e2c8286be75d7ed6752d5e341ceff95660583822..6d85508cf8ca2b43306ee745e6501b844d9074eb 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -8,6 +8,7 @@ from semantic_router.encoders import ( BaseEncoder, CohereEncoder, OpenAIEncoder, + FastEmbedEncoder, ) from semantic_router.linear import similarity_matrix, top_scores from semantic_router.route import Route @@ -60,10 +61,14 @@ class LayerConfig: self.encoder_type = encoder_type if encoder_name is None: # if encoder_name is not provided, use the default encoder for type + # TODO base these values on default values in encoders themselves.. + # TODO without initializing them (as this is just config) if encoder_type == EncoderType.OPENAI: encoder_name = "text-embedding-ada-002" elif encoder_type == EncoderType.COHERE: encoder_name = "embed-english-v3.0" + elif encoder_type == EncoderType.FASTEMBED: + encoder_name = "BAAI/bge-small-en-v1.5" elif encoder_type == EncoderType.HUGGINGFACE: raise NotImplementedError logger.info(f"Using default {encoder_type} encoder: {encoder_name}") @@ -159,10 +164,14 @@ class RouteLayer: self.encoder = encoder if encoder is not None else CohereEncoder() self.routes: list[Route] = routes if routes is not None else [] # decide on default threshold based on encoder + # TODO move defaults to the encoder objects and extract from there if isinstance(encoder, OpenAIEncoder): self.score_threshold = 0.82 elif isinstance(encoder, CohereEncoder): self.score_threshold = 0.3 + elif isinstance(encoder, FastEmbedEncoder): + # TODO default not thoroughly tested, should optimize + self.score_threshold = 0.5 else: self.score_threshold = 0.82 # if routes list has been passed, we initialize index now diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 360442f65b3dc75f0c96953aeaa22803ad28f962..644803556176fbd5f7c70ee4040b686c9207c38e 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -7,12 +7,14 @@ from semantic_router.encoders import ( BaseEncoder, CohereEncoder, OpenAIEncoder, + FastEmbedEncoder, ) from semantic_router.utils.splitters import semantic_splitter class EncoderType(Enum): HUGGINGFACE = "huggingface" + FASTEMBED = "fastembed" OPENAI = "openai" COHERE = "cohere" @@ -33,10 +35,12 @@ class Encoder: self.name = name if self.type == EncoderType.HUGGINGFACE: raise NotImplementedError + elif self.type == EncoderType.FASTEMBED: + self.model = FastEmbedEncoder(name=name) elif self.type == EncoderType.OPENAI: - self.model = OpenAIEncoder(name) + self.model = OpenAIEncoder(name=name) elif self.type == EncoderType.COHERE: - self.model = CohereEncoder(name) + self.model = CohereEncoder(name=name) else: raise ValueError