diff --git a/coverage.xml b/coverage.xml index 9af9ebee27365dd1289c5962a87b8451a3feef7c..9c68d3d96892e722a958cbb3e9bca6fa3d888203 100644 --- a/coverage.xml +++ b/coverage.xml @@ -1,12 +1,12 @@ <?xml version="1.0" ?> -<coverage version="7.3.3" timestamp="1702894511196" lines-valid="345" lines-covered="345" line-rate="1" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0"> - <!-- Generated by coverage.py: https://coverage.readthedocs.io/en/7.3.3 --> +<coverage version="7.3.2" timestamp="1702996740775" lines-valid="383" lines-covered="369" line-rate="0.9634" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0"> + <!-- Generated by coverage.py: https://coverage.readthedocs.io/en/7.3.2 --> <!-- Based on https://raw.githubusercontent.com/cobertura/web/master/htdocs/xml/coverage-04.dtd --> <sources> - <source>/Users/jakit/customers/aurelio/semantic-router/semantic_router</source> + <source>/Users/danielgriffiths/Coding_files/Aurelio_local/semantic-router/semantic_router</source> </sources> <packages> - <package name="." line-rate="1" branch-rate="0" complexity="0"> + <package name="." line-rate="0.9823" branch-rate="0" complexity="0"> <classes> <class name="__init__.py" filename="__init__.py" complexity="0" line-rate="1" branch-rate="0"> <methods/> @@ -16,7 +16,7 @@ <line number="4" hits="1"/> </lines> </class> - <class name="hybrid_layer.py" filename="hybrid_layer.py" complexity="0" line-rate="1" branch-rate="0"> + <class name="hybrid_layer.py" filename="hybrid_layer.py" complexity="0" line-rate="0.9604" branch-rate="0"> <methods/> <lines> <line number="1" hits="1"/> @@ -31,84 +31,95 @@ <line number="18" hits="1"/> <line number="19" hits="1"/> <line number="21" hits="1"/> - <line number="24" hits="1"/> - <line number="25" hits="1"/> - <line number="26" hits="1"/> <line number="28" hits="1"/> <line number="29" hits="1"/> <line number="30" hits="1"/> <line number="31" hits="1"/> <line number="33" hits="1"/> + <line number="34" hits="1"/> <line number="35" hits="1"/> - <line number="37" hits="1"/> + <line number="36" hits="1"/> <line number="38" hits="1"/> <line number="40" hits="1"/> <line number="41" hits="1"/> <line number="42" hits="1"/> - <line number="43" hits="1"/> <line number="44" hits="1"/> <line number="45" hits="1"/> <line number="47" hits="1"/> + <line number="48" hits="1"/> <line number="49" hits="1"/> <line number="50" hits="1"/> + <line number="51" hits="1"/> <line number="52" hits="1"/> <line number="54" hits="1"/> - <line number="55" hits="1"/> - <line number="60" hits="1"/> - <line number="61" hits="1"/> + <line number="56" hits="1"/> + <line number="57" hits="1"/> + <line number="58" hits="0"/> + <line number="59" hits="0"/> + <line number="60" hits="0"/> + <line number="61" hits="0"/> <line number="62" hits="1"/> - <line number="64" hits="1"/> + <line number="63" hits="1"/> <line number="65" hits="1"/> - <line number="66" hits="1"/> + <line number="67" hits="1"/> + <line number="68" hits="1"/> <line number="70" hits="1"/> <line number="71" hits="1"/> - <line number="73" hits="1"/> + <line number="72" hits="1"/> + <line number="74" hits="1"/> <line number="75" hits="1"/> <line number="76" hits="1"/> - <line number="78" hits="1"/> <line number="80" hits="1"/> + <line number="81" hits="1"/> + <line number="83" hits="1"/> <line number="85" hits="1"/> <line number="86" hits="1"/> - <line number="88" hits="1"/> - <line number="89" hits="1"/> + <line number="90" hits="1"/> <line number="91" hits="1"/> <line number="93" hits="1"/> <line number="95" hits="1"/> - <line number="96" hits="1"/> - <line number="97" hits="1"/> - <line number="99" hits="1"/> <line number="100" hits="1"/> <line number="101" hits="1"/> - <line number="102" hits="1"/> + <line number="103" hits="1"/> <line number="104" hits="1"/> - <line number="105" hits="1"/> <line number="106" hits="1"/> <line number="108" hits="1"/> - <line number="109" hits="1"/> + <line number="110" hits="1"/> <line number="111" hits="1"/> <line number="112" hits="1"/> <line number="114" hits="1"/> + <line number="115" hits="1"/> <line number="116" hits="1"/> <line number="117" hits="1"/> - <line number="118" hits="1"/> + <line number="119" hits="1"/> <line number="120" hits="1"/> <line number="121" hits="1"/> - <line number="122" hits="1"/> <line number="123" hits="1"/> <line number="124" hits="1"/> - <line number="125" hits="1"/> <line number="126" hits="1"/> - <line number="128" hits="1"/> + <line number="127" hits="1"/> + <line number="129" hits="1"/> <line number="131" hits="1"/> <line number="132" hits="1"/> + <line number="133" hits="1"/> <line number="135" hits="1"/> <line number="136" hits="1"/> + <line number="137" hits="1"/> <line number="138" hits="1"/> <line number="139" hits="1"/> + <line number="140" hits="1"/> <line number="141" hits="1"/> - <line number="142" hits="1"/> <line number="143" hits="1"/> - <line number="145" hits="1"/> + <line number="146" hits="1"/> + <line number="147" hits="1"/> + <line number="150" hits="1"/> + <line number="151" hits="1"/> + <line number="153" hits="1"/> + <line number="154" hits="1"/> + <line number="156" hits="1"/> + <line number="157" hits="1"/> + <line number="158" hits="1"/> + <line number="160" hits="1"/> </lines> </class> <class name="layer.py" filename="layer.py" complexity="0" line-rate="1" branch-rate="0"> @@ -250,7 +261,7 @@ </class> </classes> </package> - <package name="encoders" line-rate="1" branch-rate="0" complexity="0"> + <package name="encoders" line-rate="0.927" branch-rate="0" complexity="0"> <classes> <class name="__init__.py" filename="encoders/__init__.py" complexity="0" line-rate="1" branch-rate="0"> <methods/> @@ -259,7 +270,8 @@ <line number="2" hits="1"/> <line number="3" hits="1"/> <line number="4" hits="1"/> - <line number="6" hits="1"/> + <line number="5" hits="1"/> + <line number="7" hits="1"/> </lines> </class> <class name="base.py" filename="encoders/base.py" complexity="0" line-rate="1" branch-rate="0"> @@ -274,7 +286,7 @@ <line number="11" hits="1"/> </lines> </class> - <class name="bm25.py" filename="encoders/bm25.py" complexity="0" line-rate="1" branch-rate="0"> + <class name="bm25.py" filename="encoders/bm25.py" complexity="0" line-rate="0.9722" branch-rate="0"> <methods/> <lines> <line number="1" hits="1"/> @@ -298,7 +310,7 @@ <line number="27" hits="1"/> <line number="28" hits="1"/> <line number="29" hits="1"/> - <line number="30" hits="1"/> + <line number="30" hits="0"/> <line number="32" hits="1"/> <line number="34" hits="1"/> <line number="35" hits="1"/> @@ -387,6 +399,37 @@ <line number="58" hits="1"/> </lines> </class> + <class name="tfidf.py" filename="encoders/tfidf.py" complexity="0" line-rate="0.6538" branch-rate="0"> + <methods/> + <lines> + <line number="1" hits="1"/> + <line number="2" hits="1"/> + <line number="3" hits="1"/> + <line number="6" hits="1"/> + <line number="7" hits="1"/> + <line number="9" hits="1"/> + <line number="10" hits="1"/> + <line number="11" hits="1"/> + <line number="13" hits="1"/> + <line number="14" hits="0"/> + <line number="15" hits="0"/> + <line number="16" hits="0"/> + <line number="17" hits="0"/> + <line number="19" hits="0"/> + <line number="20" hits="0"/> + <line number="22" hits="1"/> + <line number="23" hits="1"/> + <line number="24" hits="0"/> + <line number="25" hits="1"/> + <line number="26" hits="1"/> + <line number="28" hits="1"/> + <line number="29" hits="1"/> + <line number="30" hits="1"/> + <line number="31" hits="0"/> + <line number="32" hits="0"/> + <line number="33" hits="1"/> + </lines> + </class> </classes> </package> <package name="utils" line-rate="1" branch-rate="0" complexity="0"> diff --git a/docs/examples/hybrid-layer.ipynb b/docs/examples/hybrid-layer.ipynb index 1257e0a18bbd6db47f1cbfd7b678eccaef183367..9c0a02fc16245578508dfd0dfb6e116cdbdb74c0 100644 --- a/docs/examples/hybrid-layer.ipynb +++ b/docs/examples/hybrid-layer.ipynb @@ -143,7 +143,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 2/2 [00:00<00:00, 4.22it/s]\n" + "100%|██████████| 2/2 [00:00<00:00, 3.41it/s]\n" ] } ], diff --git a/poetry.lock b/poetry.lock index 5d9fc23d66653568ccc0e06cfd4f6efb70dff080..0a9be4e117e6f4b6b8cc805b44fd7624d109bd3d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -1935,95 +1935,6 @@ files = [ {file = "ruff-0.1.8.tar.gz", hash = "sha256:f7ee467677467526cfe135eab86a40a0e8db43117936ac4f9b469ce9cdb3fb62"}, ] -[[package]] -name = "scikit-learn" -version = "1.3.2" -description = "A set of python modules for machine learning and data mining" -optional = false -python-versions = ">=3.8" -files = [ - {file = "scikit-learn-1.3.2.tar.gz", hash = "sha256:a2f54c76accc15a34bfb9066e6c7a56c1e7235dda5762b990792330b52ccfb05"}, - {file = "scikit_learn-1.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e326c0eb5cf4d6ba40f93776a20e9a7a69524c4db0757e7ce24ba222471ee8a1"}, - {file = "scikit_learn-1.3.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:535805c2a01ccb40ca4ab7d081d771aea67e535153e35a1fd99418fcedd1648a"}, - {file = "scikit_learn-1.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1215e5e58e9880b554b01187b8c9390bf4dc4692eedeaf542d3273f4785e342c"}, - {file = "scikit_learn-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ee107923a623b9f517754ea2f69ea3b62fc898a3641766cb7deb2f2ce450161"}, - {file = "scikit_learn-1.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:35a22e8015048c628ad099da9df5ab3004cdbf81edc75b396fd0cff8699ac58c"}, - {file = "scikit_learn-1.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6fb6bc98f234fda43163ddbe36df8bcde1d13ee176c6dc9b92bb7d3fc842eb66"}, - {file = "scikit_learn-1.3.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:18424efee518a1cde7b0b53a422cde2f6625197de6af36da0b57ec502f126157"}, - {file = "scikit_learn-1.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3271552a5eb16f208a6f7f617b8cc6d1f137b52c8a1ef8edf547db0259b2c9fb"}, - {file = "scikit_learn-1.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc4144a5004a676d5022b798d9e573b05139e77f271253a4703eed295bde0433"}, - {file = "scikit_learn-1.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:67f37d708f042a9b8d59551cf94d30431e01374e00dc2645fa186059c6c5d78b"}, - {file = "scikit_learn-1.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8db94cd8a2e038b37a80a04df8783e09caac77cbe052146432e67800e430c028"}, - {file = "scikit_learn-1.3.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:61a6efd384258789aa89415a410dcdb39a50e19d3d8410bd29be365bcdd512d5"}, - {file = "scikit_learn-1.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb06f8dce3f5ddc5dee1715a9b9f19f20d295bed8e3cd4fa51e1d050347de525"}, - {file = "scikit_learn-1.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5b2de18d86f630d68fe1f87af690d451388bb186480afc719e5f770590c2ef6c"}, - {file = "scikit_learn-1.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:0402638c9a7c219ee52c94cbebc8fcb5eb9fe9c773717965c1f4185588ad3107"}, - {file = "scikit_learn-1.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a19f90f95ba93c1a7f7924906d0576a84da7f3b2282ac3bfb7a08a32801add93"}, - {file = "scikit_learn-1.3.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:b8692e395a03a60cd927125eef3a8e3424d86dde9b2370d544f0ea35f78a8073"}, - {file = "scikit_learn-1.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15e1e94cc23d04d39da797ee34236ce2375ddea158b10bee3c343647d615581d"}, - {file = "scikit_learn-1.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:785a2213086b7b1abf037aeadbbd6d67159feb3e30263434139c98425e3dcfcf"}, - {file = "scikit_learn-1.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:64381066f8aa63c2710e6b56edc9f0894cc7bf59bd71b8ce5613a4559b6145e0"}, - {file = "scikit_learn-1.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6c43290337f7a4b969d207e620658372ba3c1ffb611f8bc2b6f031dc5c6d1d03"}, - {file = "scikit_learn-1.3.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:dc9002fc200bed597d5d34e90c752b74df516d592db162f756cc52836b38fe0e"}, - {file = "scikit_learn-1.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d08ada33e955c54355d909b9c06a4789a729977f165b8bae6f225ff0a60ec4a"}, - {file = "scikit_learn-1.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:763f0ae4b79b0ff9cca0bf3716bcc9915bdacff3cebea15ec79652d1cc4fa5c9"}, - {file = "scikit_learn-1.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:ed932ea780517b00dae7431e031faae6b49b20eb6950918eb83bd043237950e0"}, -] - -[package.dependencies] -joblib = ">=1.1.1" -numpy = ">=1.17.3,<2.0" -scipy = ">=1.5.0" -threadpoolctl = ">=2.0.0" - -[package.extras] -benchmark = ["matplotlib (>=3.1.3)", "memory-profiler (>=0.57.0)", "pandas (>=1.0.5)"] -docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.1.3)", "memory-profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.0.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.16.2)", "seaborn (>=0.9.0)", "sphinx (>=6.0.0)", "sphinx-copybutton (>=0.5.2)", "sphinx-gallery (>=0.10.1)", "sphinx-prompt (>=1.3.0)", "sphinxext-opengraph (>=0.4.2)"] -examples = ["matplotlib (>=3.1.3)", "pandas (>=1.0.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.16.2)", "seaborn (>=0.9.0)"] -tests = ["black (>=23.3.0)", "matplotlib (>=3.1.3)", "mypy (>=1.3)", "numpydoc (>=1.2.0)", "pandas (>=1.0.5)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.0.272)", "scikit-image (>=0.16.2)"] - -[[package]] -name = "scipy" -version = "1.11.4" -description = "Fundamental algorithms for scientific computing in Python" -optional = false -python-versions = ">=3.9" -files = [ - {file = "scipy-1.11.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bc9a714581f561af0848e6b69947fda0614915f072dfd14142ed1bfe1b806710"}, - {file = "scipy-1.11.4-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:cf00bd2b1b0211888d4dc75656c0412213a8b25e80d73898083f402b50f47e41"}, - {file = "scipy-1.11.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9999c008ccf00e8fbcce1236f85ade5c569d13144f77a1946bef8863e8f6eb4"}, - {file = "scipy-1.11.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:933baf588daa8dc9a92c20a0be32f56d43faf3d1a60ab11b3f08c356430f6e56"}, - {file = "scipy-1.11.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8fce70f39076a5aa62e92e69a7f62349f9574d8405c0a5de6ed3ef72de07f446"}, - {file = "scipy-1.11.4-cp310-cp310-win_amd64.whl", hash = "sha256:6550466fbeec7453d7465e74d4f4b19f905642c89a7525571ee91dd7adabb5a3"}, - {file = "scipy-1.11.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f313b39a7e94f296025e3cffc2c567618174c0b1dde173960cf23808f9fae4be"}, - {file = "scipy-1.11.4-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1b7c3dca977f30a739e0409fb001056484661cb2541a01aba0bb0029f7b68db8"}, - {file = "scipy-1.11.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00150c5eae7b610c32589dda259eacc7c4f1665aedf25d921907f4d08a951b1c"}, - {file = "scipy-1.11.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:530f9ad26440e85766509dbf78edcfe13ffd0ab7fec2560ee5c36ff74d6269ff"}, - {file = "scipy-1.11.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5e347b14fe01003d3b78e196e84bd3f48ffe4c8a7b8a1afbcb8f5505cb710993"}, - {file = "scipy-1.11.4-cp311-cp311-win_amd64.whl", hash = "sha256:acf8ed278cc03f5aff035e69cb511741e0418681d25fbbb86ca65429c4f4d9cd"}, - {file = "scipy-1.11.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:028eccd22e654b3ea01ee63705681ee79933652b2d8f873e7949898dda6d11b6"}, - {file = "scipy-1.11.4-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:2c6ff6ef9cc27f9b3db93a6f8b38f97387e6e0591600369a297a50a8e96e835d"}, - {file = "scipy-1.11.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b030c6674b9230d37c5c60ab456e2cf12f6784596d15ce8da9365e70896effc4"}, - {file = "scipy-1.11.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad669df80528aeca5f557712102538f4f37e503f0c5b9541655016dd0932ca79"}, - {file = "scipy-1.11.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ce7fff2e23ab2cc81ff452a9444c215c28e6305f396b2ba88343a567feec9660"}, - {file = "scipy-1.11.4-cp312-cp312-win_amd64.whl", hash = "sha256:36750b7733d960d7994888f0d148d31ea3017ac15eef664194b4ef68d36a4a97"}, - {file = "scipy-1.11.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6e619aba2df228a9b34718efb023966da781e89dd3d21637b27f2e54db0410d7"}, - {file = "scipy-1.11.4-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:f3cd9e7b3c2c1ec26364856f9fbe78695fe631150f94cd1c22228456404cf1ec"}, - {file = "scipy-1.11.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d10e45a6c50211fe256da61a11c34927c68f277e03138777bdebedd933712fea"}, - {file = "scipy-1.11.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91af76a68eeae0064887a48e25c4e616fa519fa0d38602eda7e0f97d65d57937"}, - {file = "scipy-1.11.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6df1468153a31cf55ed5ed39647279beb9cfb5d3f84369453b49e4b8502394fd"}, - {file = "scipy-1.11.4-cp39-cp39-win_amd64.whl", hash = "sha256:ee410e6de8f88fd5cf6eadd73c135020bfbbbdfcd0f6162c36a7638a1ea8cc65"}, - {file = "scipy-1.11.4.tar.gz", hash = "sha256:90a2b78e7f5733b9de748f589f09225013685f9b218275257f8a8168ededaeaa"}, -] - -[package.dependencies] -numpy = ">=1.21.6,<1.28.0" - -[package.extras] -dev = ["click", "cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"] -doc = ["jupytext", "matplotlib (>2)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"] -test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] - [[package]] name = "six" version = "1.16.0" @@ -2065,17 +1976,6 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] -[[package]] -name = "threadpoolctl" -version = "3.2.0" -description = "threadpoolctl" -optional = false -python-versions = ">=3.8" -files = [ - {file = "threadpoolctl-3.2.0-py3-none-any.whl", hash = "sha256:2b7818516e423bdaebb97c723f86a7c6b0a83d3f3b0970328d66f4d9104dc032"}, - {file = "threadpoolctl-3.2.0.tar.gz", hash = "sha256:c96a0ba3bdddeaca37dc4cc7344aafad41cdb8c313f74fdfe387a867bba93355"}, -] - [[package]] name = "tokenize-rt" version = "5.2.0" @@ -2322,4 +2222,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "7e705f5c5f2a8bba630031c0ff6752972e7cddc8ec95f3fb05b5be2ad7962268" \ No newline at end of file +content-hash = "f2735c243faa3d788c0f6268d6cb550648ed0d1fffec27a084344dafa4590a80" diff --git a/pyproject.toml b/pyproject.toml index c491ed1c4dc3a6a3416ee53865adbc8c0aa8a978..e45e5f17d0356cce8a2cfe5a33d9fa0529c170c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,6 @@ cohere = "^4.32" numpy = "^1.25.2" pinecone-text = "^0.7.0" colorlog = "^6.8.0" -scikit-learn = "^1.3.2" [tool.poetry.group.dev.dependencies] diff --git a/semantic_router/encoders/tfidf.py b/semantic_router/encoders/tfidf.py index 226e9dd06d6a8f81cdaa82ee778131d3c669b8e8..e7c5782f10a8fb1aeb1b6a124d493ac8e31af358 100644 --- a/semantic_router/encoders/tfidf.py +++ b/semantic_router/encoders/tfidf.py @@ -1,33 +1,62 @@ -from sklearn.feature_extraction.text import TfidfVectorizer +import numpy as np +from collections import Counter from semantic_router.encoders import BaseEncoder from semantic_router.schema import Route +from numpy.linalg import norm class TfidfEncoder(BaseEncoder): - vectorizer: TfidfVectorizer | None = None + idf: dict | None = None + word_index: dict | None = None def __init__(self, name: str = "tfidf"): super().__init__(name=name) - self.vectorizer = TfidfVectorizer() + self.word_index = None + self.idf = None def __call__(self, docs: list[str]) -> list[list[float]]: - if self.vectorizer is None: + if self.word_index is None or self.idf is None: raise ValueError("Vectorizer is not initialized.") if len(docs) == 0: raise ValueError("No documents to encode.") - embeds = self.vectorizer.transform(docs).toarray() - return embeds.tolist() + tf = self._compute_tf(docs) + tfidf = tf * self.idf + return tfidf.tolist() def fit(self, routes: list[Route]): - if self.vectorizer is None: - raise ValueError("Vectorizer is not initialized.") - docs = self._get_all_utterances(routes) - self.vectorizer.fit(docs) - - def _get_all_utterances(self, routes: list[Route]) -> list[str]: - utterances = [] + docs = [] for route in routes: for utterance in route.utterances: - utterances.append(utterance) - return utterances + docs.append(utterance) + self.word_index = self._build_word_index(docs) + self.idf = self._compute_idf(docs) + + def _build_word_index(self, docs: list[str]) -> dict: + words = set() + for doc in docs: + for word in doc.split(): + words.add(word) + word_index = {word: i for i, word in enumerate(words)} + return word_index + + def _compute_tf(self, docs: list[str]) -> np.ndarray: + tf = np.zeros((len(docs), len(self.word_index))) + for i, doc in enumerate(docs): + word_counts = Counter(doc.split()) + for word, count in word_counts.items(): + if word in self.word_index: + tf[i, self.word_index[word]] = count + # L2 normalization + tf = tf / norm(tf, axis=1, keepdims=True) + return tf + + def _compute_idf(self, docs: list[str]) -> np.ndarray: + idf = np.zeros(len(self.word_index)) + for doc in docs: + words = set(doc.split()) + for word in words: + if word in self.word_index: + idf[self.word_index[word]] += 1 + idf = np.log(len(docs) / (idf + 1)) + return idf diff --git a/tests/unit/encoders/test_tfidf.py b/tests/unit/encoders/test_tfidf.py new file mode 100644 index 0000000000000000000000000000000000000000..93a966391e77b8645c45d52471fba80ea38f2d2c --- /dev/null +++ b/tests/unit/encoders/test_tfidf.py @@ -0,0 +1,61 @@ +import pytest +from semantic_router.encoders import TfidfEncoder +from semantic_router.schema import Route + + +@pytest.fixture +def tfidf_encoder(): + return TfidfEncoder() + + +class TestTfidfEncoder: + def test_initialization(self, tfidf_encoder): + assert tfidf_encoder.word_index is None + assert tfidf_encoder.idf is None + + def test_fit(self, tfidf_encoder): + routes = [ + Route( + name="test_route", + utterances=["some docs", "and more docs", "and even more docs"], + ) + ] + tfidf_encoder.fit(routes) + assert tfidf_encoder.word_index is not None + assert tfidf_encoder.idf is not None + + def test_call_method(self, tfidf_encoder): + routes = [ + Route( + name="test_route", + utterances=["some docs", "and more docs", "and even more docs"], + ) + ] + tfidf_encoder.fit(routes) + result = tfidf_encoder(["test"]) + assert isinstance(result, list), "Result should be a list" + assert all( + isinstance(sublist, list) for sublist in result + ), "Each item in result should be a list" + + def test_call_method_no_docs(self, tfidf_encoder): + with pytest.raises(ValueError): + tfidf_encoder([]) + + def test_call_method_no_word(self, tfidf_encoder): + routes = [ + Route( + name="test_route", + utterances=["some docs", "and more docs", "and even more docs"], + ) + ] + tfidf_encoder.fit(routes) + result = tfidf_encoder(["doc with fake word gta5jabcxyz"]) + assert isinstance(result, list), "Result should be a list" + assert all( + isinstance(sublist, list) for sublist in result + ), "Each item in result should be a list" + + def test_call_method_with_uninitialized_model(self, tfidf_encoder): + with pytest.raises(ValueError): + tfidf_encoder(["test"]) diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py index 0a5dba6c49d0b3f017b902fc33ac8b660c34a810..ee7d8f6b484e000c1a22cff5a72907482b3385d1 100644 --- a/tests/unit/test_hybrid_layer.py +++ b/tests/unit/test_hybrid_layer.py @@ -45,11 +45,13 @@ def bm25_encoder(mocker): mocker.patch.object(BM25Encoder, "__call__", side_effect=mock_encoder_call) return BM25Encoder(name="test-bm25-encoder") + @pytest.fixture def tfidf_encoder(mocker): mocker.patch.object(TfidfEncoder, "__call__", side_effect=mock_encoder_call) return TfidfEncoder(name="test-tfidf-encoder") + @pytest.fixture def routes(): return [ @@ -60,21 +62,31 @@ def routes(): class TestHybridRouteLayer: def test_initialization(self, openai_encoder, bm25_encoder, routes): - route_layer = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder, routes=routes) + route_layer = HybridRouteLayer( + dense_encoder=openai_encoder, sparse_encoder=bm25_encoder, routes=routes + ) assert route_layer.index is not None and route_layer.categories is not None assert route_layer.score_threshold == 0.82 assert len(route_layer.index) == 5 assert len(set(route_layer.categories)) == 2 - def test_initialization_different_encoders(self, cohere_encoder, openai_encoder, bm25_encoder): - route_layer_cohere = HybridRouteLayer(dense_encoder=cohere_encoder, sparse_encoder=bm25_encoder) + def test_initialization_different_encoders( + self, cohere_encoder, openai_encoder, bm25_encoder + ): + route_layer_cohere = HybridRouteLayer( + dense_encoder=cohere_encoder, sparse_encoder=bm25_encoder + ) assert route_layer_cohere.score_threshold == 0.3 - route_layer_openai = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder) + route_layer_openai = HybridRouteLayer( + dense_encoder=openai_encoder, sparse_encoder=bm25_encoder + ) assert route_layer_openai.score_threshold == 0.82 def test_add_route(self, openai_encoder, bm25_encoder): - route_layer = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder) + route_layer = HybridRouteLayer( + dense_encoder=openai_encoder, sparse_encoder=bm25_encoder + ) route = Route(name="Route 3", utterances=["Yes", "No"]) route_layer.add(route) assert route_layer.index is not None and route_layer.categories is not None @@ -82,7 +94,9 @@ class TestHybridRouteLayer: assert len(set(route_layer.categories)) == 1 def test_add_multiple_routes(self, openai_encoder, bm25_encoder, routes): - route_layer = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder) + route_layer = HybridRouteLayer( + dense_encoder=openai_encoder, sparse_encoder=bm25_encoder + ) for route in routes: route_layer.add(route) assert route_layer.index is not None and route_layer.categories is not None @@ -97,11 +111,15 @@ class TestHybridRouteLayer: assert query_result in ["Route 1", "Route 2"] def test_query_with_no_index(self, openai_encoder, bm25_encoder): - route_layer = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder) + route_layer = HybridRouteLayer( + dense_encoder=openai_encoder, sparse_encoder=bm25_encoder + ) assert route_layer("Anything") is None def test_semantic_classify(self, openai_encoder, bm25_encoder, routes): - route_layer = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder, routes=routes) + route_layer = HybridRouteLayer( + dense_encoder=openai_encoder, sparse_encoder=bm25_encoder, routes=routes + ) classification, score = route_layer._semantic_classify( [ {"route": "Route 1", "score": 0.9}, @@ -111,8 +129,12 @@ class TestHybridRouteLayer: assert classification == "Route 1" assert score == [0.9] - def test_semantic_classify_multiple_routes(self, openai_encoder, bm25_encoder, routes): - route_layer = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder, routes=routes) + def test_semantic_classify_multiple_routes( + self, openai_encoder, bm25_encoder, routes + ): + route_layer = HybridRouteLayer( + dense_encoder=openai_encoder, sparse_encoder=bm25_encoder, routes=routes + ) classification, score = route_layer._semantic_classify( [ {"route": "Route 1", "score": 0.9}, @@ -124,12 +146,16 @@ class TestHybridRouteLayer: assert score == [0.9, 0.8] def test_pass_threshold(self, openai_encoder, bm25_encoder): - route_layer = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder) + route_layer = HybridRouteLayer( + dense_encoder=openai_encoder, sparse_encoder=bm25_encoder + ) assert not route_layer._pass_threshold([], 0.5) assert route_layer._pass_threshold([0.6, 0.7], 0.5) def test_failover_score_threshold(self, base_encoder, bm25_encoder): - route_layer = HybridRouteLayer(dense_encoder=base_encoder, sparse_encoder=bm25_encoder) + route_layer = HybridRouteLayer( + dense_encoder=base_encoder, sparse_encoder=bm25_encoder + ) assert route_layer.score_threshold == 0.82