diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 09dc40d305813f3444d710896bd0ca4af484d6bf..616eaea1ffa16177bc8a6961a8846c96276abd13 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -39,3 +39,10 @@ jobs: - name: Pytest run: | make test + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v2 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + with: + file: ./coverage.xml + fail_ci_if_error: true diff --git a/.gitignore b/.gitignore index 5e807c4d1ed56d35548cc859123a3cc2666acbcb..807674fa1ab9fee059c2942a22139cdd77f60eca 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,5 @@ mac.env # Code coverage history .coverage +.coverage.* +.pytest_cache diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 03b6163c6274e2526a1a5262ca5716cf9b00739f..43af57e5ed2e38785c50adbd36cbf70afd300b36 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,8 +16,11 @@ repos: rev: v0.0.290 hooks: - id: ruff - types_or: [python, pyi, jupyter] - + types_or: [ python, pyi, jupyter ] + args: [ --fix ] + - id: ruff-format + types_or: [ python, pyi, jupyter ] + - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 diff --git a/Makefile b/Makefile index 372221c63b5f6ba50b28a1c171645b44d4e94f0e..573998e954e3ae3084cb057f0d0749a58b98090d 100644 --- a/Makefile +++ b/Makefile @@ -11,4 +11,4 @@ lint lint_diff: poetry run ruff . test: - poetry run pytest -vv --cov=semantic_router --cov-report=term-missing --cov-fail-under=100 + poetry run pytest -vv -n auto --cov=semantic_router --cov-report=term-missing --cov-report=xml --cov-fail-under=100 diff --git a/README.md b/README.md index 5a4725c9a5ff6de59b68f451d53dcd7d99f05489..9dac42225f967e558b3cac590e3dd4a0bf7a782b 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,14 @@ [](https://aurelio.ai) # Semantic Router +<p> +<img alt="GitHub Contributors" src="https://img.shields.io/github/contributors/aurelio-labs/semantic-router" /> +<img alt="GitHub Last Commit" src="https://img.shields.io/github/last-commit/aurelio-labs/semantic-router" /> +<img alt="" src="https://img.shields.io/github/repo-size/aurelio-labs/semantic-router" /> +<img alt="GitHub Issues" src="https://img.shields.io/github/issues/aurelio-labs/semantic-router" /> +<img alt="GitHub Pull Requests" src="https://img.shields.io/github/issues-pr/aurelio-labs/semantic-router" /> +<img alt="Github License" src="https://img.shields.io/badge/License-MIT-yellow.svg" /> +</p> Semantic Router is a superfast decision layer for your LLMs and agents. Rather than waiting for slow LLM generations to make tool-use decisions, we use the magic of semantic vector space to make those decisions — _routing_ our requests using _semantic_ meaning. @@ -23,11 +31,10 @@ politics = Decision( utterances=[ "isn't politics the best thing ever", "why don't you tell me about your political opinions", - "don't you just love the president" - "don't you just hate the president", + "don't you just love the president" "don't you just hate the president", "they're going to destroy this country!", - "they will save the country!" - ] + "they will save the country!", + ], ) # this could be used as an indicator to our chatbot to switch to a more @@ -39,8 +46,8 @@ chitchat = Decision( "how are things going?", "lovely weather today", "the weather is horrendous", - "let's go to the chippy" - ] + "let's go to the chippy", + ], ) # we place both of our decisions together into single list @@ -97,13 +104,13 @@ dl("I'm interested in learning about llama 2") ``` ``` -[Out]: +[Out]: ``` In this case, no decision could be made as we had no matches — so our decision layer returned `None`! ## 📚 Resources -| | | -| --- | --- | -| 🃠[Walkthrough](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/walkthrough.ipynb) | Quickstart Python notebook | +| | | +| --------------------------------------------------------------------------------------------------------------- | -------------------------- | +| ðŸƒ[Walkthrough](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/walkthrough.ipynb) | Quickstart Python notebook | diff --git a/coverage.xml b/coverage.xml new file mode 100644 index 0000000000000000000000000000000000000000..65441b3247366c54ae80f8dc6a4fefcc67f16016 --- /dev/null +++ b/coverage.xml @@ -0,0 +1,383 @@ +<?xml version="1.0" ?> +<coverage version="7.3.2" timestamp="1702457433568" lines-valid="311" lines-covered="311" line-rate="1" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0"> + <!-- Generated by coverage.py: https://coverage.readthedocs.io/en/7.3.2 --> + <!-- Based on https://raw.githubusercontent.com/cobertura/web/master/htdocs/xml/coverage-04.dtd --> + <sources> + <source>/Users/jakit/customers/aurelio/semantic-router/semantic_router</source> + </sources> + <packages> + <package name="." line-rate="1" branch-rate="0" complexity="0"> + <classes> + <class name="__init__.py" filename="__init__.py" complexity="0" line-rate="1" branch-rate="0"> + <methods/> + <lines> + <line number="1" hits="1"/> + <line number="3" hits="1"/> + </lines> + </class> + <class name="layer.py" filename="layer.py" complexity="0" line-rate="1" branch-rate="0"> + <methods/> + <lines> + <line number="1" hits="1"/> + <line number="2" hits="1"/> + <line number="3" hits="1"/> + <line number="5" hits="1"/> + <line number="11" hits="1"/> + <line number="12" hits="1"/> + <line number="15" hits="1"/> + <line number="16" hits="1"/> + <line number="17" hits="1"/> + <line number="18" hits="1"/> + <line number="20" hits="1"/> + <line number="21" hits="1"/> + <line number="23" hits="1"/> + <line number="24" hits="1"/> + <line number="25" hits="1"/> + <line number="26" hits="1"/> + <line number="28" hits="1"/> + <line number="30" hits="1"/> + <line number="32" hits="1"/> + <line number="34" hits="1"/> + <line number="35" hits="1"/> + <line number="36" hits="1"/> + <line number="37" hits="1"/> + <line number="38" hits="1"/> + <line number="39" hits="1"/> + <line number="41" hits="1"/> + <line number="46" hits="1"/> + <line number="48" hits="1"/> + <line number="51" hits="1"/> + <line number="52" hits="1"/> + <line number="54" hits="1"/> + <line number="55" hits="1"/> + <line number="57" hits="1"/> + <line number="58" hits="1"/> + <line number="60" hits="1"/> + <line number="61" hits="1"/> + <line number="63" hits="1"/> + <line number="65" hits="1"/> + <line number="68" hits="1"/> + <line number="71" hits="1"/> + <line number="74" hits="1"/> + <line number="75" hits="1"/> + <line number="82" hits="1"/> + <line number="83" hits="1"/> + <line number="89" hits="1"/> + <line number="94" hits="1"/> + <line number="95" hits="1"/> + <line number="97" hits="1"/> + <line number="99" hits="1"/> + <line number="100" hits="1"/> + <line number="102" hits="1"/> + <line number="103" hits="1"/> + <line number="107" hits="1"/> + <line number="109" hits="1"/> + <line number="110" hits="1"/> + <line number="111" hits="1"/> + <line number="112" hits="1"/> + <line number="113" hits="1"/> + <line number="114" hits="1"/> + <line number="115" hits="1"/> + <line number="117" hits="1"/> + <line number="120" hits="1"/> + <line number="123" hits="1"/> + <line number="126" hits="1"/> + <line number="128" hits="1"/> + <line number="129" hits="1"/> + <line number="130" hits="1"/> + <line number="132" hits="1"/> + <line number="135" hits="1"/> + <line number="136" hits="1"/> + <line number="137" hits="1"/> + <line number="138" hits="1"/> + <line number="139" hits="1"/> + <line number="141" hits="1"/> + <line number="144" hits="1"/> + <line number="145" hits="1"/> + <line number="146" hits="1"/> + <line number="148" hits="1"/> + <line number="149" hits="1"/> + <line number="150" hits="1"/> + <line number="151" hits="1"/> + <line number="153" hits="1"/> + <line number="155" hits="1"/> + <line number="157" hits="1"/> + <line number="158" hits="1"/> + <line number="160" hits="1"/> + <line number="161" hits="1"/> + <line number="162" hits="1"/> + <line number="163" hits="1"/> + <line number="164" hits="1"/> + <line number="165" hits="1"/> + <line number="167" hits="1"/> + <line number="169" hits="1"/> + <line number="170" hits="1"/> + <line number="172" hits="1"/> + <line number="174" hits="1"/> + <line number="175" hits="1"/> + <line number="180" hits="1"/> + <line number="181" hits="1"/> + <line number="182" hits="1"/> + <line number="184" hits="1"/> + <line number="185" hits="1"/> + <line number="186" hits="1"/> + <line number="190" hits="1"/> + <line number="191" hits="1"/> + <line number="193" hits="1"/> + <line number="195" hits="1"/> + <line number="196" hits="1"/> + <line number="198" hits="1"/> + <line number="200" hits="1"/> + <line number="205" hits="1"/> + <line number="206" hits="1"/> + <line number="208" hits="1"/> + <line number="209" hits="1"/> + <line number="211" hits="1"/> + <line number="213" hits="1"/> + <line number="215" hits="1"/> + <line number="216" hits="1"/> + <line number="217" hits="1"/> + <line number="219" hits="1"/> + <line number="220" hits="1"/> + <line number="221" hits="1"/> + <line number="222" hits="1"/> + <line number="224" hits="1"/> + <line number="225" hits="1"/> + <line number="226" hits="1"/> + <line number="228" hits="1"/> + <line number="229" hits="1"/> + <line number="233" hits="1"/> + <line number="235" hits="1"/> + <line number="237" hits="1"/> + <line number="238" hits="1"/> + <line number="239" hits="1"/> + <line number="241" hits="1"/> + <line number="242" hits="1"/> + <line number="243" hits="1"/> + <line number="244" hits="1"/> + <line number="245" hits="1"/> + <line number="246" hits="1"/> + <line number="247" hits="1"/> + <line number="249" hits="1"/> + <line number="252" hits="1"/> + <line number="255" hits="1"/> + <line number="258" hits="1"/> + <line number="260" hits="1"/> + <line number="261" hits="1"/> + <line number="262" hits="1"/> + <line number="264" hits="1"/> + </lines> + </class> + <class name="linear.py" filename="linear.py" complexity="0" line-rate="1" branch-rate="0"> + <methods/> + <lines> + <line number="1" hits="1"/> + <line number="3" hits="1"/> + <line number="4" hits="1"/> + <line number="7" hits="1"/> + <line number="18" hits="1"/> + <line number="19" hits="1"/> + <line number="20" hits="1"/> + <line number="21" hits="1"/> + <line number="24" hits="1"/> + <line number="26" hits="1"/> + <line number="27" hits="1"/> + <line number="28" hits="1"/> + <line number="30" hits="1"/> + </lines> + </class> + <class name="schema.py" filename="schema.py" complexity="0" line-rate="1" branch-rate="0"> + <methods/> + <lines> + <line number="1" hits="1"/> + <line number="3" hits="1"/> + <line number="4" hits="1"/> + <line number="6" hits="1"/> + <line number="13" hits="1"/> + <line number="14" hits="1"/> + <line number="15" hits="1"/> + <line number="16" hits="1"/> + <line number="19" hits="1"/> + <line number="20" hits="1"/> + <line number="21" hits="1"/> + <line number="22" hits="1"/> + <line number="25" hits="1"/> + <line number="26" hits="1"/> + <line number="27" hits="1"/> + <line number="28" hits="1"/> + <line number="29" hits="1"/> + <line number="31" hits="1"/> + <line number="32" hits="1"/> + <line number="33" hits="1"/> + <line number="34" hits="1"/> + <line number="35" hits="1"/> + <line number="36" hits="1"/> + <line number="37" hits="1"/> + <line number="38" hits="1"/> + <line number="39" hits="1"/> + <line number="41" hits="1"/> + <line number="42" hits="1"/> + <line number="45" hits="1"/> + <line number="46" hits="1"/> + <line number="47" hits="1"/> + <line number="48" hits="1"/> + <line number="49" hits="1"/> + <line number="51" hits="1"/> + <line number="52" hits="1"/> + <line number="53" hits="1"/> + <line number="55" hits="1"/> + <line number="56" hits="1"/> + </lines> + </class> + </classes> + </package> + <package name="encoders" line-rate="1" branch-rate="0" complexity="0"> + <classes> + <class name="__init__.py" filename="encoders/__init__.py" complexity="0" line-rate="1" branch-rate="0"> + <methods/> + <lines> + <line number="1" hits="1"/> + <line number="2" hits="1"/> + <line number="3" hits="1"/> + <line number="4" hits="1"/> + <line number="6" hits="1"/> + </lines> + </class> + <class name="base.py" filename="encoders/base.py" complexity="0" line-rate="1" branch-rate="0"> + <methods/> + <lines> + <line number="1" hits="1"/> + <line number="4" hits="1"/> + <line number="5" hits="1"/> + <line number="7" hits="1"/> + <line number="8" hits="1"/> + <line number="10" hits="1"/> + <line number="11" hits="1"/> + </lines> + </class> + <class name="bm25.py" filename="encoders/bm25.py" complexity="0" line-rate="1" branch-rate="0"> + <methods/> + <lines> + <line number="1" hits="1"/> + <line number="3" hits="1"/> + <line number="6" hits="1"/> + <line number="7" hits="1"/> + <line number="8" hits="1"/> + <line number="10" hits="1"/> + <line number="11" hits="1"/> + <line number="13" hits="1"/> + <line number="14" hits="1"/> + <line number="19" hits="1"/> + <line number="20" hits="1"/> + <line number="21" hits="1"/> + <line number="22" hits="1"/> + <line number="23" hits="1"/> + <line number="25" hits="1"/> + <line number="27" hits="1"/> + <line number="28" hits="1"/> + <line number="29" hits="1"/> + <line number="30" hits="1"/> + <line number="31" hits="1"/> + <line number="32" hits="1"/> + <line number="33" hits="1"/> + <line number="34" hits="1"/> + <line number="36" hits="1"/> + <line number="37" hits="1"/> + <line number="39" hits="1"/> + <line number="40" hits="1"/> + </lines> + </class> + <class name="cohere.py" filename="encoders/cohere.py" complexity="0" line-rate="1" branch-rate="0"> + <methods/> + <lines> + <line number="1" hits="1"/> + <line number="3" hits="1"/> + <line number="5" hits="1"/> + <line number="8" hits="1"/> + <line number="9" hits="1"/> + <line number="11" hits="1"/> + <line number="16" hits="1"/> + <line number="17" hits="1"/> + <line number="18" hits="1"/> + <line number="19" hits="1"/> + <line number="20" hits="1"/> + <line number="21" hits="1"/> + <line number="22" hits="1"/> + <line number="23" hits="1"/> + <line number="25" hits="1"/> + <line number="26" hits="1"/> + <line number="27" hits="1"/> + <line number="28" hits="1"/> + <line number="29" hits="1"/> + <line number="30" hits="1"/> + <line number="31" hits="1"/> + <line number="32" hits="1"/> + </lines> + </class> + <class name="openai.py" filename="encoders/openai.py" complexity="0" line-rate="1" branch-rate="0"> + <methods/> + <lines> + <line number="1" hits="1"/> + <line number="2" hits="1"/> + <line number="4" hits="1"/> + <line number="5" hits="1"/> + <line number="7" hits="1"/> + <line number="8" hits="1"/> + <line number="11" hits="1"/> + <line number="12" hits="1"/> + <line number="13" hits="1"/> + <line number="14" hits="1"/> + <line number="15" hits="1"/> + <line number="16" hits="1"/> + <line number="18" hits="1"/> + <line number="22" hits="1"/> + <line number="23" hits="1"/> + <line number="26" hits="1"/> + <line number="27" hits="1"/> + <line number="28" hits="1"/> + <line number="29" hits="1"/> + <line number="30" hits="1"/> + <line number="31" hits="1"/> + <line number="32" hits="1"/> + <line number="33" hits="1"/> + <line number="34" hits="1"/> + <line number="35" hits="1"/> + <line number="36" hits="1"/> + <line number="37" hits="1"/> + <line number="39" hits="1"/> + <line number="40" hits="1"/> + </lines> + </class> + </classes> + </package> + <package name="utils" line-rate="1" branch-rate="0" complexity="0"> + <classes> + <class name="logger.py" filename="utils/logger.py" complexity="0" line-rate="1" branch-rate="0"> + <methods/> + <lines> + <line number="1" hits="1"/> + <line number="3" hits="1"/> + <line number="6" hits="1"/> + <line number="7" hits="1"/> + <line number="8" hits="1"/> + <line number="23" hits="1"/> + <line number="24" hits="1"/> + <line number="26" hits="1"/> + <line number="27" hits="1"/> + <line number="29" hits="1"/> + <line number="35" hits="1"/> + <line number="37" hits="1"/> + <line number="40" hits="1"/> + <line number="41" hits="1"/> + <line number="42" hits="1"/> + <line number="44" hits="1"/> + <line number="46" hits="1"/> + <line number="47" hits="1"/> + <line number="49" hits="1"/> + <line number="52" hits="1"/> + </lines> + </class> + </classes> + </package> + </packages> +</coverage> diff --git a/poetry.lock b/poetry.lock index 307c43b0b14801ba4563046e85e95dff5320be58..3bedc8de3502b726b4fdbc6d765f0c7b1ffacea9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -181,43 +181,49 @@ files = [ [[package]] name = "black" -version = "23.11.0" +version = "23.12.0" description = "The uncompromising code formatter." optional = false python-versions = ">=3.8" files = [ - {file = "black-23.11.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dbea0bb8575c6b6303cc65017b46351dc5953eea5c0a59d7b7e3a2d2f433a911"}, - {file = "black-23.11.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:412f56bab20ac85927f3a959230331de5614aecda1ede14b373083f62ec24e6f"}, - {file = "black-23.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d136ef5b418c81660ad847efe0e55c58c8208b77a57a28a503a5f345ccf01394"}, - {file = "black-23.11.0-cp310-cp310-win_amd64.whl", hash = "sha256:6c1cac07e64433f646a9a838cdc00c9768b3c362805afc3fce341af0e6a9ae9f"}, - {file = "black-23.11.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cf57719e581cfd48c4efe28543fea3d139c6b6f1238b3f0102a9c73992cbb479"}, - {file = "black-23.11.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:698c1e0d5c43354ec5d6f4d914d0d553a9ada56c85415700b81dc90125aac244"}, - {file = "black-23.11.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:760415ccc20f9e8747084169110ef75d545f3b0932ee21368f63ac0fee86b221"}, - {file = "black-23.11.0-cp311-cp311-win_amd64.whl", hash = "sha256:58e5f4d08a205b11800332920e285bd25e1a75c54953e05502052738fe16b3b5"}, - {file = "black-23.11.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:45aa1d4675964946e53ab81aeec7a37613c1cb71647b5394779e6efb79d6d187"}, - {file = "black-23.11.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4c44b7211a3a0570cc097e81135faa5f261264f4dfaa22bd5ee2875a4e773bd6"}, - {file = "black-23.11.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a9acad1451632021ee0d146c8765782a0c3846e0e0ea46659d7c4f89d9b212b"}, - {file = "black-23.11.0-cp38-cp38-win_amd64.whl", hash = "sha256:fc7f6a44d52747e65a02558e1d807c82df1d66ffa80a601862040a43ec2e3142"}, - {file = "black-23.11.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7f622b6822f02bfaf2a5cd31fdb7cd86fcf33dab6ced5185c35f5db98260b055"}, - {file = "black-23.11.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:250d7e60f323fcfc8ea6c800d5eba12f7967400eb6c2d21ae85ad31c204fb1f4"}, - {file = "black-23.11.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5133f5507007ba08d8b7b263c7aa0f931af5ba88a29beacc4b2dc23fcefe9c06"}, - {file = "black-23.11.0-cp39-cp39-win_amd64.whl", hash = "sha256:421f3e44aa67138ab1b9bfbc22ee3780b22fa5b291e4db8ab7eee95200726b07"}, - {file = "black-23.11.0-py3-none-any.whl", hash = "sha256:54caaa703227c6e0c87b76326d0862184729a69b73d3b7305b6288e1d830067e"}, - {file = "black-23.11.0.tar.gz", hash = "sha256:4c68855825ff432d197229846f971bc4d6666ce90492e5b02013bcaca4d9ab05"}, + {file = "black-23.12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:67f19562d367468ab59bd6c36a72b2c84bc2f16b59788690e02bbcb140a77175"}, + {file = "black-23.12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bbd75d9f28a7283b7426160ca21c5bd640ca7cd8ef6630b4754b6df9e2da8462"}, + {file = "black-23.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:593596f699ca2dcbbbdfa59fcda7d8ad6604370c10228223cd6cf6ce1ce7ed7e"}, + {file = "black-23.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:12d5f10cce8dc27202e9a252acd1c9a426c83f95496c959406c96b785a92bb7d"}, + {file = "black-23.12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e73c5e3d37e5a3513d16b33305713237a234396ae56769b839d7c40759b8a41c"}, + {file = "black-23.12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ba09cae1657c4f8a8c9ff6cfd4a6baaf915bb4ef7d03acffe6a2f6585fa1bd01"}, + {file = "black-23.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ace64c1a349c162d6da3cef91e3b0e78c4fc596ffde9413efa0525456148873d"}, + {file = "black-23.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:72db37a2266b16d256b3ea88b9affcdd5c41a74db551ec3dd4609a59c17d25bf"}, + {file = "black-23.12.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:fdf6f23c83078a6c8da2442f4d4eeb19c28ac2a6416da7671b72f0295c4a697b"}, + {file = "black-23.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:39dda060b9b395a6b7bf9c5db28ac87b3c3f48d4fdff470fa8a94ab8271da47e"}, + {file = "black-23.12.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7231670266ca5191a76cb838185d9be59cfa4f5dd401b7c1c70b993c58f6b1b5"}, + {file = "black-23.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:193946e634e80bfb3aec41830f5d7431f8dd5b20d11d89be14b84a97c6b8bc75"}, + {file = "black-23.12.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bcf91b01ddd91a2fed9a8006d7baa94ccefe7e518556470cf40213bd3d44bbbc"}, + {file = "black-23.12.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:996650a89fe5892714ea4ea87bc45e41a59a1e01675c42c433a35b490e5aa3f0"}, + {file = "black-23.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bdbff34c487239a63d86db0c9385b27cdd68b1bfa4e706aa74bb94a435403672"}, + {file = "black-23.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:97af22278043a6a1272daca10a6f4d36c04dfa77e61cbaaf4482e08f3640e9f0"}, + {file = "black-23.12.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ead25c273adfad1095a8ad32afdb8304933efba56e3c1d31b0fee4143a1e424a"}, + {file = "black-23.12.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c71048345bdbced456cddf1622832276d98a710196b842407840ae8055ade6ee"}, + {file = "black-23.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81a832b6e00eef2c13b3239d514ea3b7d5cc3eaa03d0474eedcbbda59441ba5d"}, + {file = "black-23.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:6a82a711d13e61840fb11a6dfecc7287f2424f1ca34765e70c909a35ffa7fb95"}, + {file = "black-23.12.0-py3-none-any.whl", hash = "sha256:a7c07db8200b5315dc07e331dda4d889a56f6bf4db6a9c2a526fa3166a81614f"}, + {file = "black-23.12.0.tar.gz", hash = "sha256:330a327b422aca0634ecd115985c1c7fd7bdb5b5a2ef8aa9888a82e2ebe9437a"}, ] [package.dependencies] click = ">=8.0.0" +ipython = {version = ">=7.8.0", optional = true, markers = "extra == \"jupyter\""} mypy-extensions = ">=0.4.3" packaging = ">=22.0" pathspec = ">=0.9.0" platformdirs = ">=2" +tokenize-rt = {version = ">=3.2.0", optional = true, markers = "extra == \"jupyter\""} tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} [package.extras] colorama = ["colorama (>=0.4.3)"] -d = ["aiohttp (>=3.7.4)"] +d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] @@ -592,6 +598,20 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "execnet" +version = "2.0.2" +description = "execnet: rapid multi-Python deployment" +optional = false +python-versions = ">=3.7" +files = [ + {file = "execnet-2.0.2-py3-none-any.whl", hash = "sha256:88256416ae766bc9e8895c76a87928c0012183da3cc4fc18016e6f050e025f41"}, + {file = "execnet-2.0.2.tar.gz", hash = "sha256:cc59bc4423742fd71ad227122eb0dd44db51efb3dc4095b45ac9a08c770096af"}, +] + +[package.extras] +testing = ["hatch", "pre-commit", "pytest", "tox"] + [[package]] name = "executing" version = "2.0.1" @@ -1453,6 +1473,26 @@ pytest = ">=5.0" [package.extras] dev = ["pre-commit", "pytest-asyncio", "tox"] +[[package]] +name = "pytest-xdist" +version = "3.5.0" +description = "pytest xdist plugin for distributed testing, most importantly across multiple CPUs" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-xdist-3.5.0.tar.gz", hash = "sha256:cbb36f3d67e0c478baa57fa4edc8843887e0f6cfc42d677530a36d7472b32d8a"}, + {file = "pytest_xdist-3.5.0-py3-none-any.whl", hash = "sha256:d075629c7e00b611df89f490a5063944bee7a4362a5ff11c7cc7824a03dfce24"}, +] + +[package.dependencies] +execnet = ">=1.1" +pytest = ">=6.2.0" + +[package.extras] +psutil = ["psutil (>=3.0)"] +setproctitle = ["setproctitle"] +testing = ["filelock"] + [[package]] name = "python-dateutil" version = "2.8.2" @@ -1769,6 +1809,17 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] +[[package]] +name = "tokenize-rt" +version = "5.2.0" +description = "A wrapper around the stdlib `tokenize` which roundtrips." +optional = false +python-versions = ">=3.8" +files = [ + {file = "tokenize_rt-5.2.0-py2.py3-none-any.whl", hash = "sha256:b79d41a65cfec71285433511b50271b05da3584a1da144a0752e9c621a285289"}, + {file = "tokenize_rt-5.2.0.tar.gz", hash = "sha256:9fe80f8a5c1edad2d3ede0f37481cc0cc1538a2f442c9c2f9e4feacd2792d054"}, +] + [[package]] name = "tomli" version = "2.0.1" @@ -2004,4 +2055,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "64e772051ca3411e09defc8ab06235a7c3e39f9bf60e58fb06b25317c5a34053" +content-hash = "b17b9fd9486d6c744c41a31ab54f7871daba1e2d4166fda228033c5858f6f9d8" diff --git a/pyproject.toml b/pyproject.toml index 4dec2ef4a359f664fe936b4927b2571dd2966da4..b21cd485683a42c26724bba6c32724b197ec518e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,11 +24,15 @@ colorlog = "^6.8.0" [tool.poetry.group.dev.dependencies] ipykernel = "^6.26.0" ruff = "^0.1.5" -black = "^23.11.0" +black = {extras = ["jupyter"], version = "^23.12.0"} pytest = "^7.4.3" pytest-mock = "^3.12.0" pytest-cov = "^4.1.0" +pytest-xdist = "^3.5.0" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.ruff.per-file-ignores] +"*.ipynb" = ["E402"] diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py index 0c86ce7c47290d1490565837bb8a98a5631941f2..30ad624a2104ec3c48b8684615b89e3a35d43be2 100644 --- a/semantic_router/encoders/__init__.py +++ b/semantic_router/encoders/__init__.py @@ -1,6 +1,6 @@ from .base import BaseEncoder +from .bm25 import BM25Encoder from .cohere import CohereEncoder from .openai import OpenAIEncoder -from .bm25 import BM25Encoder __all__ = ["BaseEncoder", "CohereEncoder", "OpenAIEncoder", "BM25Encoder"] diff --git a/semantic_router/encoders/cohere.py b/semantic_router/encoders/cohere.py index 00ef722d3dfba2e28747d3905f58a870b0ffdc37..34331d23ffc46d317c883343b0b42a8bd03b7381 100644 --- a/semantic_router/encoders/cohere.py +++ b/semantic_router/encoders/cohere.py @@ -9,7 +9,9 @@ class CohereEncoder(BaseEncoder): client: cohere.Client | None def __init__( - self, name: str = "embed-english-v3.0", cohere_api_key: str | None = None + self, + name: str = os.getenv("COHERE_MODEL_NAME", "embed-english-v3.0"), + cohere_api_key: str | None = None, ): super().__init__(name=name) cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY") diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py index b828c2e571d9dbc361dfadb8738b56bf3866cae3..858e5b7ad8faa17bd687f4aa42c26770f814fd99 100644 --- a/semantic_router/encoders/openai.py +++ b/semantic_router/encoders/openai.py @@ -2,7 +2,7 @@ import os from time import sleep import openai -from openai.error import RateLimitError, ServiceUnavailableError, OpenAIError +from openai.error import OpenAIError, RateLimitError, ServiceUnavailableError from semantic_router.encoders import BaseEncoder from semantic_router.utils.logger import logger diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 832fb9cd7ef34f390233354f3edad6d51141d3a3..1bb900fb78884becca16a37988751786d356a408 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -29,7 +29,7 @@ class DecisionLayer: # if decisions list has been passed, we initialize index now if decisions: # initialize index now - self._add_decisions(decisions=decisions) + self.add_decisions(decisions=decisions) def __call__(self, text: str) -> str | None: results = self._query(text) @@ -40,10 +40,10 @@ class DecisionLayer: else: return None - def add(self, decision: Decision): - self._add_decision(decision=decision) + # def add(self, decision: Decision): + # self.add_decision(decision=decision) - def _add_decision(self, decision: Decision): + def add_decision(self, decision: Decision): # create embeddings embeds = self.encoder(decision.utterances) @@ -60,7 +60,7 @@ class DecisionLayer: embed_arr = np.array(embeds) self.index = np.concatenate([self.index, embed_arr]) - def _add_decisions(self, decisions: list[Decision]): + def add_decisions(self, decisions: list[Decision]): # create embeddings for all decisions all_utterances = [ utterance for decision in decisions for utterance in decision.utterances @@ -197,9 +197,6 @@ class HybridDecisionLayer: else: self.sparse_index = np.concatenate([self.sparse_index, sparse_embeds]) - def _add_decisions(self, decisions: list[Decision]): - raise NotImplementedError - def _query(self, text: str, top_k: int = 5): """Given some text, encodes and searches the index vector space to retrieve the top_k most similar records. diff --git a/tests/unit/encoders/test_cohere.py b/tests/unit/encoders/test_cohere.py index 7f7ddf281244132c130e027fedc99e00829770de..0f7607af39ce4ad5b6cc4ccfe13beea174745648 100644 --- a/tests/unit/encoders/test_cohere.py +++ b/tests/unit/encoders/test_cohere.py @@ -34,8 +34,52 @@ class TestCohereEncoder: ), "Each item in result should be a list" cohere_encoder.client.embed.assert_called_once() - def test_call_with_uninitialized_client(self, mocker): + def test_returns_list_of_embeddings_for_valid_input(self, cohere_encoder, mocker): + mock_embed = mocker.MagicMock() + mock_embed.embeddings = [[0.1, 0.2, 0.3]] + cohere_encoder.client.embed.return_value = mock_embed + + result = cohere_encoder(["test"]) + assert isinstance(result, list), "Result should be a list" + assert all( + isinstance(sublist, list) for sublist in result + ), "Each item in result should be a list" + cohere_encoder.client.embed.assert_called_once() + + def test_handles_multiple_inputs_correctly(self, cohere_encoder, mocker): + mock_embed = mocker.MagicMock() + mock_embed.embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + cohere_encoder.client.embed.return_value = mock_embed + + result = cohere_encoder(["test1", "test2"]) + assert isinstance(result, list), "Result should be a list" + assert all( + isinstance(sublist, list) for sublist in result + ), "Each item in result should be a list" + cohere_encoder.client.embed.assert_called_once() + + def test_raises_value_error_if_api_key_is_none(self, mocker, monkeypatch): + monkeypatch.delenv("COHERE_API_KEY", raising=False) + mocker.patch("cohere.Client") + with pytest.raises(ValueError): + CohereEncoder() + + def test_raises_value_error_if_cohere_client_fails_to_initialize(self, mocker): + mocker.patch( + "cohere.Client", side_effect=Exception("Failed to initialize client") + ) + with pytest.raises(ValueError): + CohereEncoder(cohere_api_key="test_api_key") + + def test_raises_value_error_if_cohere_client_is_not_initialized(self, mocker): mocker.patch("cohere.Client", return_value=None) encoder = CohereEncoder(cohere_api_key="test_api_key") with pytest.raises(ValueError): encoder(["test"]) + + def test_call_method_raises_error_on_api_failure(self, cohere_encoder, mocker): + mocker.patch.object( + cohere_encoder.client, "embed", side_effect=Exception("API call failed") + ) + with pytest.raises(ValueError): + cohere_encoder(["test"]) diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py index a746c4ec8f220f0896abb5c415749cb7b62a5796..8c0c9729379e1abfa8c67278a1ef5a092c0859be 100644 --- a/tests/unit/test_layer.py +++ b/tests/unit/test_layer.py @@ -4,7 +4,9 @@ from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder from semantic_router.layer import ( DecisionLayer, HybridDecisionLayer, -) # Replace with the actual module name +) + +# Replace with the actual module name from semantic_router.schema import Decision @@ -49,8 +51,12 @@ class TestDecisionLayer: def test_initialization(self, openai_encoder, decisions): decision_layer = DecisionLayer(encoder=openai_encoder, decisions=decisions) assert decision_layer.score_threshold == 0.82 - assert len(decision_layer.index) == 5 - assert len(set(decision_layer.categories)) == 2 + assert len(decision_layer.index) if decision_layer.index is not None else 0 == 5 + assert ( + len(set(decision_layer.categories)) + if decision_layer.categories is not None + else 0 == 2 + ) def test_initialization_different_encoders(self, cohere_encoder, openai_encoder): decision_layer_cohere = DecisionLayer(encoder=cohere_encoder) @@ -61,15 +67,28 @@ class TestDecisionLayer: def test_add_decision(self, openai_encoder): decision_layer = DecisionLayer(encoder=openai_encoder) - decision = Decision(name="Decision 3", utterances=["Yes", "No"]) - decision_layer.add(decision) + decision1 = Decision(name="Decision 1", utterances=["Yes", "No"]) + decision2 = Decision(name="Decision 2", utterances=["Maybe", "Sure"]) + + decision_layer.add_decision(decision=decision1) + assert ( + decision_layer.index is not None and decision_layer.categories is not None + ) assert len(decision_layer.index) == 2 assert len(set(decision_layer.categories)) == 1 + assert set(decision_layer.categories) == {"Decision 1"} + + decision_layer.add_decision(decision=decision2) + assert len(decision_layer.index) == 4 + assert len(set(decision_layer.categories)) == 2 + assert set(decision_layer.categories) == {"Decision 1", "Decision 2"} def test_add_multiple_decisions(self, openai_encoder, decisions): decision_layer = DecisionLayer(encoder=openai_encoder) - for decision in decisions: - decision_layer.add(decision) + decision_layer.add_decisions(decisions=decisions) + assert ( + decision_layer.index is not None and decision_layer.categories is not None + ) assert len(decision_layer.index) == 5 assert len(set(decision_layer.categories)) == 2 @@ -121,6 +140,9 @@ class TestHybridDecisionLayer: encoder=openai_encoder, decisions=decisions ) assert decision_layer.score_threshold == 0.82 + assert ( + decision_layer.index is not None and decision_layer.categories is not None + ) assert len(decision_layer.index) == 5 assert len(set(decision_layer.categories)) == 2 @@ -135,6 +157,9 @@ class TestHybridDecisionLayer: decision_layer = HybridDecisionLayer(encoder=openai_encoder) decision = Decision(name="Decision 3", utterances=["Yes", "No"]) decision_layer.add(decision) + assert ( + decision_layer.index is not None and decision_layer.categories is not None + ) assert len(decision_layer.index) == 2 assert len(set(decision_layer.categories)) == 1 @@ -142,6 +167,9 @@ class TestHybridDecisionLayer: decision_layer = HybridDecisionLayer(encoder=openai_encoder) for decision in decisions: decision_layer.add(decision) + assert ( + decision_layer.index is not None and decision_layer.categories is not None + ) assert len(decision_layer.index) == 5 assert len(set(decision_layer.categories)) == 2 diff --git a/walkthrough.ipynb b/walkthrough.ipynb index dcd024932df97ca7fe1a1824cd0cd1e5fd0bad87..2e9570b881948caab8ac1ef25be4795147f858ca 100644 --- a/walkthrough.ipynb +++ b/walkthrough.ipynb @@ -46,19 +46,9 @@ }, { "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/jamesbriggs/opt/anaconda3/envs/decision-layer/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "from semantic_router.schema import Decision\n", "\n", @@ -67,11 +57,10 @@ " utterances=[\n", " \"isn't politics the best thing ever\",\n", " \"why don't you tell me about your political opinions\",\n", - " \"don't you just love the president\"\n", - " \"don't you just hate the president\",\n", + " \"don't you just love the president\" \"don't you just hate the president\",\n", " \"they're going to destroy this country!\",\n", - " \"they will save the country!\"\n", - " ]\n", + " \"they will save the country!\",\n", + " ],\n", ")" ] }, @@ -84,7 +73,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -95,8 +84,8 @@ " \"how are things going?\",\n", " \"lovely weather today\",\n", " \"the weather is horrendous\",\n", - " \"let's go to the chippy\"\n", - " ]\n", + " \"let's go to the chippy\",\n", + " ],\n", ")\n", "\n", "decisions = [politics, chitchat]" @@ -111,16 +100,17 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "from semantic_router.encoders import CohereEncoder\n", - "from getpass import getpass\n", "import os\n", + "from getpass import getpass\n", + "from semantic_router.encoders import CohereEncoder\n", "\n", - "os.environ[\"COHERE_API_KEY\"] = os.getenv(\"COHERE_API_KEY\") or \\\n", - " getpass(\"Enter Cohere API Key: \")\n", + "os.environ[\"COHERE_API_KEY\"] = os.getenv(\"COHERE_API_KEY\") or getpass(\n", + " \"Enter Cohere API Key: \"\n", + ")\n", "\n", "encoder = CohereEncoder()" ] @@ -134,7 +124,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -152,40 +142,18 @@ }, { "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'politics'" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "dl(\"don't you love politics?\")" ] }, { "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'chitchat'" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "dl(\"how's the weather today?\")" ] @@ -199,7 +167,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -230,7 +198,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.11.3" } }, "nbformat": 4,