diff --git a/.github/workflows/conventional_commits.yml b/.github/workflows/conventional_commits.yml new file mode 100644 index 0000000000000000000000000000000000000000..9778b0186ada4e84e76d39f6d83eb9f4af822e9b --- /dev/null +++ b/.github/workflows/conventional_commits.yml @@ -0,0 +1,13 @@ +name: Conventional Commits + +on: + pull_request: + branches: [main] + +jobs: + build: + name: Conventional Commits + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: webiny/action-conventional-commits@v1.1.0 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index b7c6e679c2f89003a14023fbd12afc30f8b7c5db..dcd301e5d631cab642bd45c813ddc8d86e86f7a4 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -37,4 +37,4 @@ jobs: poetry install - name: Analyzing the code with our lint run: | - make lint \ No newline at end of file + make lint diff --git a/.github/workflows/pr_agent.yml b/.github/workflows/pr_agent.yml index 4e86dfbc55a223672dbe56b2145f2660cae3d74e..e9db72d81ed26d07d1cef34e272139a8f98cedab 100644 --- a/.github/workflows/pr_agent.yml +++ b/.github/workflows/pr_agent.yml @@ -14,5 +14,5 @@ jobs: id: pragent uses: Codium-ai/pr-agent@main env: - OPENAI_KEY: ${{ secrets.OPENAI_KEY }} + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 43af57e5ed2e38785c50adbd36cbf70afd300b36..530abf4e8f6dfa6d0d8b84e29033848f5b6def88 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,11 @@ default_language_version: python: python3.11.3 repos: + - repo: meta + hooks: + - id: check-hooks-apply + - id: check-useless-excludes + - repo: https://github.com/psf/black rev: 23.9.1 hooks: @@ -21,10 +26,38 @@ repos: - id: ruff-format types_or: [ python, pyi, jupyter ] + - repo: https://github.com/codespell-project/codespell + rev: v2.2.4 + hooks: + - id: codespell + name: Run codespell to check for common misspellings in files + language: python + types: [ text ] + args: [ "--write-changes", "--ignore-words-list", "asend" ] + exclude: "poetry.lock" + - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: - - id: trailing-whitespace + - id: check-vcs-permalinks - id: end-of-file-fixer + # exclude: "tests/((commands|data)/|test_).+" + - id: trailing-whitespace + args: [ --markdown-linebreak-ext=md ] + - id: debug-statements + - id: no-commit-to-branch + - id: check-merge-conflict + - id: check-toml - id: check-yaml + args: [ '--unsafe' ] # for mkdocs.yml + - id: detect-private-key + + - repo: https://github.com/commitizen-tools/commitizen + rev: v3.13.0 + hooks: + - id: commitizen + - id: commitizen-branch + stages: + - post-commit + - push diff --git a/Makefile b/Makefile index 8de202fa56f0de52a80f6f8a63e68ab6fe18ef33..aeb3d3b19ff9262b933b9022f6fc80c240279040 100644 --- a/Makefile +++ b/Makefile @@ -12,4 +12,4 @@ lint lint_diff: poetry run mypy $(PYTHON_FILES) test: - poetry run pytest -vv -n 20 --cov=semantic_router --cov-report=term-missing --cov-report=xml --cov-fail-under=100 + poetry run pytest -vv -n 20 --cov=semantic_router --cov-report=term-missing --cov-report=xml --cov-fail-under=80 diff --git a/README.md b/README.md index 4ab443ce797c232d5846f4126be4fb081e034f61..d297275f1b885037679dad883bde106c1ba0bfa0 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ # Semantic Router <p> +<img alt="PyPI - Python Version" src="https://img.shields.io/pypi/pyversions/semantic-router?logo=python&logoColor=gold" /> <img 
alt="GitHub Contributors" src="https://img.shields.io/github/contributors/aurelio-labs/semantic-router" /> <img alt="GitHub Last Commit" src="https://img.shields.io/github/last-commit/aurelio-labs/semantic-router" /> <img alt="" src="https://img.shields.io/github/repo-size/aurelio-labs/semantic-router" /> @@ -25,7 +26,7 @@ pip install -qU semantic-router We begin by defining a set of `Decision` objects. These are the decision paths that the semantic router can decide to use, let's try two simple decisions for now — one for talk on _politics_ and another for _chitchat_: ```python -from semantic_router.schemas.route import Route +from semantic_router import Route # we could use this as a guide for our chatbot to avoid political conversations politics = Route( @@ -53,7 +54,7 @@ chitchat = Route( ) # we place both of our decisions together into single list -decisions = [politics, chitchat] +routes = [politics, chitchat] ``` We have our decisions ready, now we initialize an embedding / encoder model. We currently support a `CohereEncoder` and `OpenAIEncoder` — more encoders will be added soon. To initialize them we do: @@ -76,13 +77,13 @@ With our `decisions` and `encoder` defined we now create a `DecisionLayer`. The ```python from semantic_router.layer import RouteLayer -dl = RouteLayer(encoder=encoder, decisions=decisions) +dl = RouteLayer(encoder=encoder, routes=routes) ``` We can now use our decision layer to make super fast decisions based on user queries. Let's try with two queries that should trigger our decisions: ```python -dl("don't you love politics?") +dl("don't you love politics?").name ``` ``` @@ -92,7 +93,7 @@ dl("don't you love politics?") Correct decision, let's try another: ```python -dl("how's the weather today?") +dl("how's the weather today?").name ``` ``` @@ -102,7 +103,7 @@ dl("how's the weather today?") We get both decisions correct! Now lets try sending an unrelated query: ```python -dl("I'm interested in learning about llama 2") +dl("I'm interested in learning about llama 2").name ``` ``` @@ -111,8 +112,11 @@ dl("I'm interested in learning about llama 2") In this case, no decision could be made as we had no matches — so our decision layer returned `None`! -## 📚 Resources +## 📚 [Resources](https://github.com/aurelio-labs/semantic-router/tree/main/docs) + + + + + + -| | | -| ------------------------------------------------------------------------------------------------------------------ | -------------------------- | -| ðŸƒ[Walkthrough](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/walkthrough.ipynb) | Quickstart Python notebook | diff --git a/coverage.xml b/coverage.xml index 9af9ebee27365dd1289c5962a87b8451a3feef7c..f4b8af22c04a918a523a8dc7cacbf5edefab3d47 100644 --- a/coverage.xml +++ b/coverage.xml @@ -1,121 +1,341 @@ <?xml version="1.0" ?> -<coverage version="7.3.3" timestamp="1702894511196" lines-valid="345" lines-covered="345" line-rate="1" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0"> +<coverage version="7.3.3" timestamp="1704188881490" lines-valid="608" lines-covered="411" line-rate="0.676" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0"> <!-- Generated by coverage.py: https://coverage.readthedocs.io/en/7.3.3 --> <!-- Based on https://raw.githubusercontent.com/cobertura/web/master/htdocs/xml/coverage-04.dtd --> <sources> <source>/Users/jakit/customers/aurelio/semantic-router/semantic_router</source> </sources> <packages> - <package name="." 
line-rate="1" branch-rate="0" complexity="0"> + <package name="." line-rate="0.7045" branch-rate="0" complexity="0"> <classes> <class name="__init__.py" filename="__init__.py" complexity="0" line-rate="1" branch-rate="0"> <methods/> <lines> <line number="1" hits="1"/> <line number="2" hits="1"/> - <line number="4" hits="1"/> + <line number="3" hits="1"/> + <line number="5" hits="1"/> </lines> </class> - <class name="hybrid_layer.py" filename="hybrid_layer.py" complexity="0" line-rate="1" branch-rate="0"> + <class name="hybrid_layer.py" filename="hybrid_layer.py" complexity="0" line-rate="0.2143" branch-rate="0"> <methods/> <lines> <line number="1" hits="1"/> <line number="2" hits="1"/> - <line number="3" hits="1"/> - <line number="5" hits="1"/> + <line number="4" hits="1"/> + <line number="10" hits="1"/> <line number="11" hits="1"/> - <line number="12" hits="1"/> + <line number="14" hits="1"/> <line number="15" hits="1"/> <line number="16" hits="1"/> <line number="17" hits="1"/> <line number="18" hits="1"/> + <line number="20" hits="1"/> + <line number="23" hits="1"/> + <line number="24" hits="1"/> + <line number="25" hits="0"/> + <line number="27" hits="0"/> + <line number="28" hits="0"/> + <line number="29" hits="0"/> + <line number="30" hits="0"/> + <line number="32" hits="0"/> + <line number="34" hits="0"/> + <line number="38" hits="0"/> + <line number="40" hits="1"/> + <line number="41" hits="0"/> + <line number="42" hits="0"/> + <line number="43" hits="0"/> + <line number="44" hits="0"/> + <line number="45" hits="0"/> + <line number="47" hits="0"/> + <line number="49" hits="1"/> + <line number="50" hits="0"/> + <line number="52" hits="1"/> + <line number="54" hits="0"/> + <line number="55" hits="0"/> + <line number="60" hits="0"/> + <line number="61" hits="0"/> + <line number="62" hits="0"/> + <line number="64" hits="0"/> + <line number="65" hits="0"/> + <line number="66" hits="0"/> + <line number="70" hits="0"/> + <line number="71" hits="0"/> + <line number="73" hits="0"/> + <line number="75" hits="0"/> + <line number="76" hits="0"/> + <line number="78" hits="0"/> + <line number="80" hits="1"/> + <line number="82" hits="0"/> + <line number="83" hits="0"/> + <line number="86" hits="0"/> + <line number="87" hits="0"/> + <line number="90" hits="0"/> + <line number="91" hits="0"/> + <line number="92" hits="0"/> + <line number="99" hits="0"/> + <line number="106" hits="0"/> + <line number="112" hits="1"/> + <line number="117" hits="0"/> + <line number="118" hits="0"/> + <line number="120" hits="0"/> + <line number="121" hits="0"/> + <line number="123" hits="0"/> + <line number="125" hits="0"/> + <line number="127" hits="0"/> + <line number="128" hits="0"/> + <line number="129" hits="0"/> + <line number="131" hits="0"/> + <line number="132" hits="0"/> + <line number="133" hits="0"/> + <line number="134" hits="0"/> + <line number="136" hits="0"/> + <line number="137" hits="0"/> + <line number="138" hits="0"/> + <line number="140" hits="0"/> + <line number="141" hits="0"/> + <line number="143" hits="0"/> + <line number="144" hits="0"/> + <line number="146" hits="1"/> + <line number="148" hits="0"/> + <line number="149" hits="0"/> + <line number="150" hits="0"/> + <line number="152" hits="1"/> + <line number="153" hits="0"/> + <line number="154" hits="0"/> + <line number="155" hits="0"/> + <line number="156" hits="0"/> + <line number="157" hits="0"/> + <line number="158" hits="0"/> + <line number="160" hits="0"/> + <line number="163" hits="0"/> + <line number="164" hits="0"/> + <line 
number="167" hits="0"/> + <line number="168" hits="0"/> + <line number="170" hits="0"/> + <line number="171" hits="0"/> + <line number="173" hits="1"/> + <line number="174" hits="0"/> + <line number="175" hits="0"/> + <line number="177" hits="0"/> + </lines> + </class> + <class name="layer.py" filename="layer.py" complexity="0" line-rate="0.8541" branch-rate="0"> + <methods/> + <lines> + <line number="1" hits="1"/> + <line number="2" hits="1"/> + <line number="4" hits="1"/> + <line number="5" hits="1"/> + <line number="7" hits="1"/> + <line number="12" hits="1"/> + <line number="13" hits="1"/> + <line number="14" hits="1"/> + <line number="15" hits="1"/> + <line number="18" hits="1"/> <line number="19" hits="1"/> + <line number="20" hits="1"/> <line number="21" hits="1"/> - <line number="24" hits="1"/> - <line number="25" hits="1"/> - <line number="26" hits="1"/> - <line number="28" hits="1"/> - <line number="29" hits="1"/> - <line number="30" hits="1"/> - <line number="31" hits="1"/> + <line number="23" hits="1"/> + <line number="24" hits="0"/> + <line number="25" hits="0"/> + <line number="26" hits="0"/> + <line number="27" hits="0"/> + <line number="30" hits="0"/> + <line number="31" hits="0"/> <line number="33" hits="1"/> - <line number="35" hits="1"/> - <line number="37" hits="1"/> - <line number="38" hits="1"/> + <line number="34" hits="1"/> + <line number="35" hits="0"/> + <line number="38" hits="0"/> <line number="40" hits="1"/> - <line number="41" hits="1"/> - <line number="42" hits="1"/> - <line number="43" hits="1"/> - <line number="44" hits="1"/> - <line number="45" hits="1"/> - <line number="47" hits="1"/> - <line number="49" hits="1"/> - <line number="50" hits="1"/> + <line number="41" hits="0"/> + <line number="42" hits="0"/> + <line number="43" hits="0"/> + <line number="46" hits="1"/> <line number="52" hits="1"/> <line number="54" hits="1"/> - <line number="55" hits="1"/> <line number="60" hits="1"/> <line number="61" hits="1"/> - <line number="62" hits="1"/> - <line number="64" hits="1"/> + <line number="63" hits="1"/> + <line number="64" hits="0"/> <line number="65" hits="1"/> - <line number="66" hits="1"/> + <line number="66" hits="0"/> + <line number="67" hits="1"/> + <line number="68" hits="0"/> + <line number="69" hits="1"/> <line number="70" hits="1"/> <line number="71" hits="1"/> <line number="73" hits="1"/> - <line number="75" hits="1"/> + <line number="74" hits="1"/> <line number="76" hits="1"/> + <line number="77" hits="1"/> <line number="78" hits="1"/> + <line number="79" hits="1"/> <line number="80" hits="1"/> - <line number="85" hits="1"/> - <line number="86" hits="1"/> + <line number="81" hits="1"/> + <line number="82" hits="1"/> + <line number="84" hits="1"/> <line number="88" hits="1"/> <line number="89" hits="1"/> + <line number="90" hits="1"/> <line number="91" hits="1"/> + <line number="92" hits="1"/> <line number="93" hits="1"/> - <line number="95" hits="1"/> - <line number="96" hits="1"/> - <line number="97" hits="1"/> + <line number="97" hits="0"/> <line number="99" hits="1"/> <line number="100" hits="1"/> - <line number="101" hits="1"/> - <line number="102" hits="1"/> - <line number="104" hits="1"/> - <line number="105" hits="1"/> <line number="106" hits="1"/> <line number="108" hits="1"/> <line number="109" hits="1"/> + <line number="110" hits="1"/> <line number="111" hits="1"/> <line number="112" hits="1"/> + <line number="113" hits="1"/> <line number="114" hits="1"/> <line number="116" hits="1"/> - <line number="117" hits="1"/> - <line 
number="118" hits="1"/> <line number="120" hits="1"/> - <line number="121" hits="1"/> - <line number="122" hits="1"/> - <line number="123" hits="1"/> + <line number="121" hits="0"/> + <line number="122" hits="0"/> <line number="124" hits="1"/> - <line number="125" hits="1"/> - <line number="126" hits="1"/> - <line number="128" hits="1"/> + <line number="125" hits="0"/> + <line number="126" hits="0"/> + <line number="127" hits="0"/> + <line number="128" hits="0"/> + <line number="129" hits="0"/> <line number="131" hits="1"/> - <line number="132" hits="1"/> - <line number="135" hits="1"/> - <line number="136" hits="1"/> - <line number="138" hits="1"/> + <line number="132" hits="0"/> + <line number="133" hits="0"/> + <line number="135" hits="0"/> + <line number="136" hits="0"/> <line number="139" hits="1"/> + <line number="140" hits="1"/> <line number="141" hits="1"/> <line number="142" hits="1"/> - <line number="143" hits="1"/> - <line number="145" hits="1"/> + <line number="144" hits="1"/> + <line number="147" hits="1"/> + <line number="148" hits="1"/> + <line number="149" hits="1"/> + <line number="150" hits="1"/> + <line number="151" hits="1"/> + <line number="153" hits="1"/> + <line number="154" hits="1"/> + <line number="155" hits="1"/> + <line number="156" hits="1"/> + <line number="158" hits="1"/> + <line number="160" hits="1"/> + <line number="162" hits="1"/> + <line number="164" hits="1"/> + <line number="165" hits="1"/> + <line number="166" hits="1"/> + <line number="167" hits="1"/> + <line number="168" hits="1"/> + <line number="170" hits="1"/> + <line number="171" hits="1"/> + <line number="174" hits="1"/> + <line number="176" hits="1"/> + <line number="177" hits="0"/> + <line number="183" hits="1"/> + <line number="184" hits="1"/> + <line number="185" hits="1"/> + <line number="186" hits="1"/> + <line number="187" hits="1"/> + <line number="189" hits="1"/> + <line number="190" hits="1"/> + <line number="191" hits="1"/> + <line number="192" hits="1"/> + <line number="193" hits="1"/> + <line number="195" hits="1"/> + <line number="196" hits="1"/> + <line number="197" hits="1"/> + <line number="198" hits="1"/> + <line number="200" hits="1"/> + <line number="201" hits="1"/> + <line number="203" hits="1"/> + <line number="206" hits="1"/> + <line number="207" hits="1"/> + <line number="208" hits="1"/> + <line number="210" hits="1"/> + <line number="211" hits="1"/> + <line number="212" hits="1"/> + <line number="214" hits="1"/> + <line number="215" hits="1"/> + <line number="216" hits="1"/> + <line number="218" hits="1"/> + <line number="219" hits="1"/> + <line number="220" hits="1"/> + <line number="222" hits="1"/> + <line number="224" hits="1"/> + <line number="226" hits="1"/> + <line number="229" hits="1"/> + <line number="232" hits="1"/> + <line number="233" hits="1"/> + <line number="234" hits="1"/> + <line number="241" hits="1"/> + <line number="242" hits="1"/> + <line number="248" hits="1"/> + <line number="253" hits="1"/> + <line number="254" hits="1"/> + <line number="256" hits="1"/> + <line number="258" hits="1"/> + <line number="259" hits="1"/> + <line number="261" hits="1"/> + <line number="262" hits="1"/> + <line number="264" hits="1"/> + <line number="265" hits="1"/> + <line number="267" hits="1"/> + <line number="268" hits="1"/> + <line number="269" hits="1"/> + <line number="270" hits="1"/> + <line number="271" hits="1"/> + <line number="272" hits="1"/> + <line number="273" hits="1"/> + <line number="275" hits="1"/> + <line number="278" hits="1"/> + <line number="279" 
hits="1"/> + <line number="282" hits="1"/> + <line number="283" hits="1"/> + <line number="285" hits="1"/> + <line number="286" hits="1"/> + <line number="288" hits="1"/> + <line number="289" hits="1"/> + <line number="290" hits="1"/> + <line number="292" hits="1"/> + <line number="294" hits="1"/> + <line number="295" hits="1"/> + <line number="301" hits="1"/> + <line number="302" hits="1"/> + <line number="303" hits="1"/> + <line number="305" hits="1"/> + <line number="306" hits="1"/> + <line number="307" hits="1"/> + </lines> + </class> + <class name="linear.py" filename="linear.py" complexity="0" line-rate="1" branch-rate="0"> + <methods/> + <lines> + <line number="1" hits="1"/> + <line number="3" hits="1"/> + <line number="4" hits="1"/> + <line number="7" hits="1"/> + <line number="18" hits="1"/> + <line number="19" hits="1"/> + <line number="20" hits="1"/> + <line number="21" hits="1"/> + <line number="24" hits="1"/> + <line number="26" hits="1"/> + <line number="27" hits="1"/> + <line number="28" hits="1"/> + <line number="30" hits="1"/> </lines> </class> - <class name="layer.py" filename="layer.py" complexity="0" line-rate="1" branch-rate="0"> + <class name="route.py" filename="route.py" complexity="0" line-rate="0.8529" branch-rate="0"> <methods/> <lines> <line number="1" hits="1"/> + <line number="2" hits="1"/> <line number="3" hits="1"/> + <line number="5" hits="1"/> + <line number="7" hits="1"/> <line number="8" hits="1"/> <line number="9" hits="1"/> <line number="10" hits="1"/> @@ -125,87 +345,63 @@ <line number="16" hits="1"/> <line number="18" hits="1"/> <line number="19" hits="1"/> + <line number="20" hits="1"/> <line number="21" hits="1"/> - <line number="22" hits="1"/> - <line number="23" hits="1"/> - <line number="24" hits="1"/> + <line number="22" hits="0"/> + <line number="25" hits="0"/> <line number="26" hits="1"/> <line number="28" hits="1"/> + <line number="29" hits="1"/> <line number="30" hits="1"/> - <line number="32" hits="1"/> <line number="33" hits="1"/> - <line number="34" hits="1"/> <line number="35" hits="1"/> - <line number="36" hits="1"/> - <line number="37" hits="1"/> - <line number="39" hits="1"/> + <line number="36" hits="0"/> + <line number="37" hits="0"/> + <line number="38" hits="0"/> <line number="41" hits="1"/> + <line number="42" hits="1"/> <line number="43" hits="1"/> - <line number="46" hits="1"/> + <line number="44" hits="1"/> + <line number="45" hits="1"/> <line number="47" hits="1"/> - <line number="49" hits="1"/> - <line number="50" hits="1"/> - <line number="52" hits="1"/> - <line number="53" hits="1"/> - <line number="55" hits="1"/> + <line number="48" hits="1"/> + <line number="50" hits="0"/> + <line number="53" hits="0"/> <line number="56" hits="1"/> - <line number="58" hits="1"/> + <line number="57" hits="1"/> + <line number="59" hits="1"/> <line number="60" hits="1"/> + <line number="62" hits="1"/> <line number="63" hits="1"/> + <line number="64" hits="1"/> <line number="66" hits="1"/> <line number="67" hits="1"/> - <line number="68" hits="1"/> + <line number="71" hits="1"/> + <line number="72" hits="1"/> + <line number="73" hits="1"/> <line number="75" hits="1"/> <line number="76" hits="1"/> + <line number="78" hits="1"/> + <line number="79" hits="1"/> + <line number="81" hits="1"/> <line number="82" hits="1"/> + <line number="83" hits="1"/> + <line number="85" hits="0"/> <line number="87" hits="1"/> <line number="88" hits="1"/> - <line number="90" hits="1"/> - <line number="92" hits="1"/> - <line number="93" hits="1"/> - <line 
number="95" hits="1"/> - <line number="96" hits="1"/> - <line number="98" hits="1"/> - <line number="99" hits="1"/> - <line number="101" hits="1"/> - <line number="102" hits="1"/> - <line number="103" hits="1"/> - <line number="104" hits="1"/> - <line number="105" hits="1"/> - <line number="106" hits="1"/> - <line number="107" hits="1"/> - <line number="109" hits="1"/> - <line number="112" hits="1"/> - <line number="113" hits="1"/> + <line number="89" hits="1"/> + <line number="91" hits="1"/> <line number="116" hits="1"/> <line number="117" hits="1"/> - <line number="119" hits="1"/> + <line number="118" hits="0"/> <line number="120" hits="1"/> <line number="122" hits="1"/> - <line number="123" hits="1"/> <line number="124" hits="1"/> - <line number="126" hits="1"/> - </lines> - </class> - <class name="linear.py" filename="linear.py" complexity="0" line-rate="1" branch-rate="0"> - <methods/> - <lines> - <line number="1" hits="1"/> - <line number="3" hits="1"/> - <line number="4" hits="1"/> - <line number="7" hits="1"/> - <line number="18" hits="1"/> - <line number="19" hits="1"/> - <line number="20" hits="1"/> - <line number="21" hits="1"/> - <line number="24" hits="1"/> - <line number="26" hits="1"/> - <line number="27" hits="1"/> - <line number="28" hits="1"/> - <line number="30" hits="1"/> + <line number="125" hits="1"/> + <line number="126" hits="0"/> </lines> </class> - <class name="schema.py" filename="schema.py" complexity="0" line-rate="1" branch-rate="0"> + <class name="schema.py" filename="schema.py" complexity="0" line-rate="0.8929" branch-rate="0"> <methods/> <lines> <line number="1" hits="1"/> @@ -219,38 +415,28 @@ <line number="19" hits="1"/> <line number="20" hits="1"/> <line number="21" hits="1"/> - <line number="22" hits="1"/> + <line number="24" hits="1"/> <line number="25" hits="1"/> <line number="26" hits="1"/> <line number="27" hits="1"/> <line number="28" hits="1"/> - <line number="29" hits="1"/> + <line number="30" hits="1"/> <line number="31" hits="1"/> <line number="32" hits="1"/> <line number="33" hits="1"/> <line number="34" hits="1"/> <line number="35" hits="1"/> <line number="36" hits="1"/> - <line number="37" hits="1"/> - <line number="38" hits="1"/> - <line number="39" hits="1"/> - <line number="41" hits="1"/> + <line number="37" hits="0"/> + <line number="38" hits="0"/> + <line number="40" hits="0"/> <line number="42" hits="1"/> - <line number="45" hits="1"/> - <line number="46" hits="1"/> - <line number="47" hits="1"/> - <line number="48" hits="1"/> - <line number="49" hits="1"/> - <line number="51" hits="1"/> - <line number="52" hits="1"/> - <line number="53" hits="1"/> - <line number="55" hits="1"/> - <line number="56" hits="1"/> + <line number="43" hits="1"/> </lines> </class> </classes> </package> - <package name="encoders" line-rate="1" branch-rate="0" complexity="0"> + <package name="encoders" line-rate="0.812" branch-rate="0" complexity="0"> <classes> <class name="__init__.py" filename="encoders/__init__.py" complexity="0" line-rate="1" branch-rate="0"> <methods/> @@ -268,13 +454,14 @@ <line number="1" hits="1"/> <line number="4" hits="1"/> <line number="5" hits="1"/> - <line number="7" hits="1"/> + <line number="6" hits="1"/> <line number="8" hits="1"/> - <line number="10" hits="1"/> + <line number="9" hits="1"/> <line number="11" hits="1"/> + <line number="12" hits="1"/> </lines> </class> - <class name="bm25.py" filename="encoders/bm25.py" complexity="0" line-rate="1" branch-rate="0"> + <class name="bm25.py" filename="encoders/bm25.py" 
complexity="0" line-rate="0.4865" branch-rate="0"> <methods/> <lines> <line number="1" hits="1"/> @@ -283,39 +470,40 @@ <line number="8" hits="1"/> <line number="9" hits="1"/> <line number="10" hits="1"/> - <line number="12" hits="1"/> + <line number="11" hits="1"/> <line number="13" hits="1"/> <line number="14" hits="1"/> - <line number="16" hits="1"/> + <line number="15" hits="1"/> <line number="17" hits="1"/> <line number="18" hits="1"/> <line number="19" hits="1"/> <line number="20" hits="1"/> - <line number="22" hits="1"/> - <line number="24" hits="1"/> + <line number="21" hits="1"/> + <line number="23" hits="1"/> <line number="25" hits="1"/> - <line number="26" hits="1"/> - <line number="27" hits="1"/> - <line number="28" hits="1"/> - <line number="29" hits="1"/> - <line number="30" hits="1"/> - <line number="32" hits="1"/> - <line number="34" hits="1"/> - <line number="35" hits="1"/> - <line number="36" hits="1"/> - <line number="37" hits="1"/> - <line number="38" hits="1"/> - <line number="39" hits="1"/> - <line number="40" hits="1"/> - <line number="41" hits="1"/> - <line number="42" hits="1"/> - <line number="44" hits="1"/> + <line number="26" hits="0"/> + <line number="27" hits="0"/> + <line number="28" hits="0"/> + <line number="29" hits="0"/> + <line number="30" hits="0"/> + <line number="31" hits="0"/> + <line number="33" hits="0"/> + <line number="35" hits="0"/> + <line number="36" hits="0"/> + <line number="37" hits="0"/> + <line number="38" hits="0"/> + <line number="39" hits="0"/> + <line number="40" hits="0"/> + <line number="41" hits="0"/> + <line number="42" hits="0"/> + <line number="43" hits="0"/> <line number="45" hits="1"/> - <line number="46" hits="1"/> - <line number="47" hits="1"/> + <line number="46" hits="0"/> + <line number="47" hits="0"/> + <line number="48" hits="0"/> </lines> </class> - <class name="cohere.py" filename="encoders/cohere.py" complexity="0" line-rate="1" branch-rate="0"> + <class name="cohere.py" filename="encoders/cohere.py" complexity="0" line-rate="0.92" branch-rate="0"> <methods/> <lines> <line number="1" hits="1"/> @@ -323,8 +511,8 @@ <line number="5" hits="1"/> <line number="8" hits="1"/> <line number="9" hits="1"/> - <line number="11" hits="1"/> - <line number="16" hits="1"/> + <line number="10" hits="1"/> + <line number="12" hits="1"/> <line number="17" hits="1"/> <line number="18" hits="1"/> <line number="19" hits="1"/> @@ -332,17 +520,20 @@ <line number="21" hits="1"/> <line number="22" hits="1"/> <line number="23" hits="1"/> - <line number="25" hits="1"/> - <line number="26" hits="1"/> - <line number="27" hits="1"/> + <line number="24" hits="1"/> + <line number="25" hits="0"/> + <line number="26" hits="0"/> <line number="28" hits="1"/> <line number="29" hits="1"/> <line number="30" hits="1"/> <line number="31" hits="1"/> <line number="32" hits="1"/> + <line number="33" hits="1"/> + <line number="34" hits="1"/> + <line number="35" hits="1"/> </lines> </class> - <class name="openai.py" filename="encoders/openai.py" complexity="0" line-rate="1" branch-rate="0"> + <class name="openai.py" filename="encoders/openai.py" complexity="0" line-rate="0.9762" branch-rate="0"> <methods/> <lines> <line number="1" hits="1"/> @@ -354,23 +545,23 @@ <line number="9" hits="1"/> <line number="12" hits="1"/> <line number="13" hits="1"/> - <line number="15" hits="1"/> - <line number="20" hits="1"/> + <line number="14" hits="1"/> + <line number="16" hits="1"/> <line number="21" hits="1"/> <line number="22" hits="1"/> <line number="23" hits="1"/> <line 
number="24" hits="1"/> <line number="25" hits="1"/> - <line number="26" hits="1"/> + <line number="26" hits="0"/> <line number="27" hits="1"/> + <line number="28" hits="1"/> <line number="29" hits="1"/> <line number="30" hits="1"/> - <line number="31" hits="1"/> <line number="32" hits="1"/> <line number="33" hits="1"/> + <line number="34" hits="1"/> + <line number="35" hits="1"/> <line number="36" hits="1"/> - <line number="37" hits="1"/> - <line number="38" hits="1"/> <line number="39" hits="1"/> <line number="40" hits="1"/> <line number="41" hits="1"/> @@ -381,20 +572,108 @@ <line number="46" hits="1"/> <line number="47" hits="1"/> <line number="48" hits="1"/> + <line number="49" hits="1"/> <line number="50" hits="1"/> - <line number="55" hits="1"/> + <line number="52" hits="1"/> <line number="57" hits="1"/> - <line number="58" hits="1"/> + <line number="59" hits="1"/> + <line number="60" hits="1"/> </lines> </class> </classes> </package> - <package name="utils" line-rate="1" branch-rate="0" complexity="0"> + <package name="utils" line-rate="0.3895" branch-rate="0" complexity="0"> <classes> <class name="__init__.py" filename="utils/__init__.py" complexity="0" line-rate="1" branch-rate="0"> <methods/> <lines/> </class> + <class name="function_call.py" filename="utils/function_call.py" complexity="0" line-rate="0.2258" branch-rate="0"> + <methods/> + <lines> + <line number="1" hits="1"/> + <line number="2" hits="1"/> + <line number="3" hits="1"/> + <line number="5" hits="1"/> + <line number="7" hits="1"/> + <line number="8" hits="1"/> + <line number="11" hits="1"/> + <line number="12" hits="1"/> + <line number="13" hits="0"/> + <line number="14" hits="0"/> + <line number="15" hits="0"/> + <line number="16" hits="0"/> + <line number="18" hits="0"/> + <line number="19" hits="0"/> + <line number="20" hits="0"/> + <line number="24" hits="0"/> + <line number="26" hits="0"/> + <line number="27" hits="0"/> + <line number="28" hits="0"/> + <line number="34" hits="1"/> + <line number="40" hits="1"/> + <line number="43" hits="1"/> + <line number="44" hits="0"/> + <line number="46" hits="0"/> + <line number="75" hits="0"/> + <line number="76" hits="0"/> + <line number="77" hits="0"/> + <line number="79" hits="0"/> + <line number="81" hits="0"/> + <line number="82" hits="0"/> + <line number="83" hits="0"/> + <line number="84" hits="0"/> + <line number="87" hits="1"/> + <line number="89" hits="0"/> + <line number="91" hits="0"/> + <line number="92" hits="0"/> + <line number="93" hits="0"/> + <line number="94" hits="0"/> + <line number="98" hits="0"/> + <line number="99" hits="0"/> + <line number="100" hits="0"/> + <line number="101" hits="0"/> + <line number="102" hits="0"/> + <line number="103" hits="0"/> + <line number="104" hits="0"/> + <line number="105" hits="0"/> + <line number="108" hits="1"/> + <line number="109" hits="0"/> + <line number="110" hits="0"/> + <line number="111" hits="0"/> + <line number="112" hits="0"/> + <line number="116" hits="1"/> + <line number="117" hits="0"/> + <line number="118" hits="0"/> + <line number="119" hits="0"/> + <line number="120" hits="0"/> + <line number="122" hits="0"/> + <line number="123" hits="0"/> + <line number="124" hits="0"/> + <line number="125" hits="0"/> + <line number="126" hits="0"/> + <line number="127" hits="0"/> + </lines> + </class> + <class name="llm.py" filename="utils/llm.py" complexity="0" line-rate="0.2857" branch-rate="0"> + <methods/> + <lines> + <line number="1" hits="1"/> + <line number="3" hits="1"/> + <line number="5" hits="1"/> + 
<line number="8" hits="1"/> + <line number="9" hits="0"/> + <line number="10" hits="0"/> + <line number="15" hits="0"/> + <line number="27" hits="0"/> + <line number="29" hits="0"/> + <line number="30" hits="0"/> + <line number="31" hits="0"/> + <line number="32" hits="0"/> + <line number="33" hits="0"/> + <line number="34" hits="0"/> + </lines> + </class> <class name="logger.py" filename="utils/logger.py" complexity="0" line-rate="1" branch-rate="0"> <methods/> <lines> @@ -405,19 +684,18 @@ <line number="8" hits="1"/> <line number="23" hits="1"/> <line number="24" hits="1"/> + <line number="25" hits="1"/> <line number="26" hits="1"/> <line number="27" hits="1"/> - <line number="29" hits="1"/> + <line number="28" hits="1"/> + <line number="31" hits="1"/> + <line number="32" hits="1"/> + <line number="33" hits="1"/> <line number="35" hits="1"/> <line number="37" hits="1"/> + <line number="38" hits="1"/> <line number="40" hits="1"/> - <line number="41" hits="1"/> - <line number="42" hits="1"/> - <line number="44" hits="1"/> - <line number="46" hits="1"/> - <line number="47" hits="1"/> - <line number="49" hits="1"/> - <line number="52" hits="1"/> + <line number="43" hits="1"/> </lines> </class> </classes> diff --git a/dist/semantic_router-0.0.9-py3-none-any.whl b/dist/semantic_router-0.0.9-py3-none-any.whl deleted file mode 100644 index 2e1b60bd421d0e4a36c5facfa4ce5c5fab97468b..0000000000000000000000000000000000000000 Binary files a/dist/semantic_router-0.0.9-py3-none-any.whl and /dev/null differ diff --git a/dist/semantic_router-0.0.9.tar.gz b/dist/semantic_router-0.0.9.tar.gz deleted file mode 100644 index 5a241635bd99093f37f1b4c186a79ce1c579a346..0000000000000000000000000000000000000000 Binary files a/dist/semantic_router-0.0.9.tar.gz and /dev/null differ diff --git a/docs/00-introduction.ipynb b/docs/00-introduction.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..95222c2a7eedfa8807ef70adb902c9d82c81c457 --- /dev/null +++ b/docs/00-introduction.ipynb @@ -0,0 +1,266 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/docs/00-introduction.ipynb) [](https://nbviewer.org/github/aurelio-labs/semantic-router/blob/main/docs/00-introduction.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Semantic Router Intro" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The Semantic Router library can be used as a super fast route making layer on top of LLMs. That means rather than waiting on a slow agent to decide what to do, we can use the magic of semantic vector space to make routes. Cutting route making time down from seconds to milliseconds." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Getting Started" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We start by installing the library:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -qU semantic-router==0.0.14" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "_**âš ï¸ If using Google Colab, install the prerequisites and then restart the notebook before continuing**_" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We start by defining a dictionary mapping routes to example phrases that should trigger those routes." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from semantic_router import Route\n",
+    "\n",
+    "politics = Route(\n",
+    "    name=\"politics\",\n",
+    "    utterances=[\n",
+    "        \"isn't politics the best thing ever\",\n",
+    "        \"why don't you tell me about your political opinions\",\n",
+    "        \"don't you just love the president\",\n",
+    "        \"don't you just hate the president\",\n",
+    "        \"they're going to destroy this country!\",\n",
+    "        \"they will save the country!\",\n",
+    "    ],\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's define another for good measure:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chitchat = Route(\n",
+    "    name=\"chitchat\",\n",
+    "    utterances=[\n",
+    "        \"how's the weather today?\",\n",
+    "        \"how are things going?\",\n",
+    "        \"lovely weather today\",\n",
+    "        \"the weather is horrendous\",\n",
+    "        \"let's go to the chippy\",\n",
+    "    ],\n",
+    ")\n",
+    "\n",
+    "routes = [politics, chitchat]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now we initialize our embedding model:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "from getpass import getpass\n",
+    "from semantic_router.encoders import CohereEncoder, OpenAIEncoder\n",
+    "\n",
+    "# os.environ[\"COHERE_API_KEY\"] = os.getenv(\"COHERE_API_KEY\") or getpass(\n",
+    "#     \"Enter Cohere API Key: \"\n",
+    "# )\n",
+    "os.environ[\"OPENAI_API_KEY\"] = os.getenv(\"OPENAI_API_KEY\") or getpass(\n",
+    "    \"Enter OpenAI API Key: \"\n",
+    ")\n",
+    "\n",
+    "# encoder = CohereEncoder()\n",
+    "encoder = OpenAIEncoder()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now we define the `RouteLayer`. When called, the route layer will consume text (a query) and output the category (`Route`) it belongs to. To initialize a `RouteLayer` we need our `encoder` model and a list of `routes`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m2023-12-28 19:14:34 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n"
+     ]
+    }
+   ],
+   "source": [
+    "from semantic_router.layer import RouteLayer\n",
+    "\n",
+    "dl = RouteLayer(encoder=encoder, routes=routes)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now we can test it:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "RouteChoice(name='politics', function_call=None)"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "dl(\"don't you love politics?\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "RouteChoice(name='chitchat', function_call=None)"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "dl(\"how's the weather today?\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Both are classified accurately. What if we send a query that is unrelated to our existing `Route` objects?"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "RouteChoice(name=None, function_call=None)"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "dl(\"I'm interested in learning about llama 2\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In this case, we return `None` because no matches were identified."
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "decision-layer",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/01-save-load-from-file.ipynb b/docs/01-save-load-from-file.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..6f084a9aedfb44e4a1c77c518584a664e60b86f4
--- /dev/null
+++ b/docs/01-save-load-from-file.ipynb
@@ -0,0 +1,282 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/docs/01-save-load-from-file.ipynb) [](https://nbviewer.org/github/aurelio-labs/semantic-router/blob/main/docs/01-save-load-from-file.ipynb)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Route Layers from File\n",
+    "\n",
+    "Here we will show how to save routers to YAML or JSON files, and how to load a route layer from file."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Getting Started"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We start by installing the library:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install -qU semantic-router==0.0.14"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "_**⚠️ If using Google Colab, install the prerequisites and then restart the notebook before continuing**_"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Saving to JSON"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "First let's create a list of routes:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/jamesbriggs/opt/anaconda3/envs/decision-layer/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n",
+      "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. 
Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n"
+     ]
+    }
+   ],
+   "source": [
+    "from semantic_router import Route\n",
+    "\n",
+    "politics = Route(\n",
+    "    name=\"politics\",\n",
+    "    utterances=[\n",
+    "        \"isn't politics the best thing ever\",\n",
+    "        \"why don't you tell me about your political opinions\",\n",
+    "        \"don't you just love the president\",\n",
+    "        \"don't you just hate the president\",\n",
+    "        \"they're going to destroy this country!\",\n",
+    "        \"they will save the country!\",\n",
+    "    ],\n",
+    ")\n",
+    "chitchat = Route(\n",
+    "    name=\"chitchat\",\n",
+    "    utterances=[\n",
+    "        \"how's the weather today?\",\n",
+    "        \"how are things going?\",\n",
+    "        \"lovely weather today\",\n",
+    "        \"the weather is horrendous\",\n",
+    "        \"let's go to the chippy\",\n",
+    "    ],\n",
+    ")\n",
+    "\n",
+    "routes = [politics, chitchat]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We define a route layer using these routes, with the default Cohere encoder."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m2023-12-28 19:16:54 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "from getpass import getpass\n",
+    "from semantic_router import RouteLayer\n",
+    "\n",
+    "# dashboard.cohere.ai\n",
+    "os.environ[\"COHERE_API_KEY\"] = os.getenv(\"COHERE_API_KEY\") or getpass(\n",
+    "    \"Enter Cohere API Key: \"\n",
+    ")\n",
+    "\n",
+    "layer = RouteLayer(routes=routes)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To save our route layer we call the `to_json` method:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m2023-12-28 19:17:03 INFO semantic_router.utils.logger Saving route config to layer.json\u001b[0m\n"
+     ]
+    }
+   ],
+   "source": [
+    "layer.to_json(\"layer.json\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Loading from JSON"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can view the router file we just saved to see what information is stored."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'encoder_type': 'cohere', 'encoder_name': 'embed-english-v3.0', 'routes': [{'name': 'politics', 'utterances': [\"isn't politics the best thing ever\", \"why don't you tell me about your political opinions\", \"don't you just love the president\", \"don't you just hate the president\", \"they're going to destroy this country!\", 'they will save the country!'], 'description': None, 'function_schema': None}, {'name': 'chitchat', 'utterances': [\"how's the weather today?\", 'how are things going?', 'lovely weather today', 'the weather is horrendous', \"let's go to the chippy\"], 'description': None, 'function_schema': None}]}\n"
+     ]
+    }
+   ],
+   "source": [
+    "import json\n",
+    "\n",
+    "with open(\"layer.json\", \"r\") as f:\n",
+    "    router_json = json.load(f)\n",
+    "\n",
+    "print(router_json)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "It tells us our encoder type, encoder name, and routes. This is everything we need to initialize a new router. To do so, we use the `from_json` method."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m2023-12-28 19:17:08 INFO semantic_router.utils.logger Loading route config from layer.json\u001b[0m\n",
+      "\u001b[32m2023-12-28 19:17:08 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n"
+     ]
+    }
+   ],
+   "source": [
+    "layer = RouteLayer.from_json(\"layer.json\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can confirm that our layer has been initialized with the expected attributes by viewing the `RouteLayer` object:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "layer.encoder.type='cohere'\n",
+      "layer.encoder.name='embed-english-v3.0'\n",
+      "layer.routes=[Route(name='politics', utterances=[\"isn't politics the best thing ever\", \"why don't you tell me about your political opinions\", \"don't you just love the president\", \"don't you just hate the president\", \"they're going to destroy this country!\", 'they will save the country!'], description=None, function_schema=None), Route(name='chitchat', utterances=[\"how's the weather today?\", 'how are things going?', 'lovely weather today', 'the weather is horrendous', \"let's go to the chippy\"], description=None, function_schema=None)]\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(\n",
+    "    f\"\"\"{layer.encoder.type=}\n",
+    "{layer.encoder.name=}\n",
+    "{layer.routes=}\"\"\"\n",
+    ")"
+   ]
+  },
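+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The intro mentioned saving to YAML as well as JSON. As a minimal sketch, assuming the installed version exposes `to_yaml` and `from_yaml` counterparts to the JSON methods above (an assumption; check your `semantic-router` version before relying on them), the YAML round trip would look like this:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# minimal sketch: assumes `to_yaml` / `from_yaml` mirror the JSON methods above\n",
+    "layer.to_yaml(\"layer.yaml\")\n",
+    "layer = RouteLayer.from_yaml(\"layer.yaml\")"
+   ]
+  },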
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "---"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "decision-layer",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/02-dynamic-routes.ipynb b/docs/02-dynamic-routes.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..2b17da17cbee79d186134144d93563bfce08da82
--- /dev/null
+++ b/docs/02-dynamic-routes.ipynb
@@ -0,0 +1,365 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/docs/02-dynamic-routes.ipynb) [](https://nbviewer.org/github/aurelio-labs/semantic-router/blob/main/docs/02-dynamic-routes.ipynb)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Dynamic Routes"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In semantic-router there are two types of routes that can be chosen. Both are defined with the `Route` object; the only difference between them is that _static_ routes return a `Route.name` when chosen, whereas _dynamic_ routes use an LLM call to produce parameter input values.\n",
+    "\n",
+    "For example, a _static_ route will tell us if a query is talking about mathematics by returning the route name (which could be `\"math\"` for example). A _dynamic_ route can generate additional values, so it may decide a query is talking about maths, but it can also generate Python code that we can later execute to answer the user's query. This output may look like `\"math\", \"import math; output = math.sqrt(64)\"`."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Installing the Library"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install -qU semantic-router==0.0.14"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "_**⚠️ If using Google Colab, install the prerequisites and then restart the notebook before continuing**_"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Initializing Routes and RouteLayer"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Dynamic routes are treated in the same way as static routes. Let's begin by initializing a `RouteLayer` consisting of static routes."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/jamesbriggs/opt/anaconda3/envs/decision-layer/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n",
+      "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n"
+     ]
+    }
+   ],
+   "source": [
+    "from semantic_router import Route\n",
+    "\n",
+    "politics = Route(\n",
+    "    name=\"politics\",\n",
+    "    utterances=[\n",
+    "        \"isn't politics the best thing ever\",\n",
+    "        \"why don't you tell me about your political opinions\",\n",
+    "        \"don't you just love the president\",\n",
+    "        \"don't you just hate the president\",\n",
+    "        \"they're going to destroy this country!\",\n",
+    "        \"they will save the country!\",\n",
+    "    ],\n",
+    ")\n",
+    "chitchat = Route(\n",
+    "    name=\"chitchat\",\n",
+    "    utterances=[\n",
+    "        \"how's the weather today?\",\n",
+    "        \"how are things going?\",\n",
+    "        \"lovely weather today\",\n",
+    "        \"the weather is horrendous\",\n",
+    "        \"let's go to the chippy\",\n",
+    "    ],\n",
+    ")\n",
+    "\n",
+    "routes = [politics, chitchat]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m2023-12-28 19:19:39 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "from getpass import getpass\n",
+    "from semantic_router import RouteLayer\n",
+    "\n",
+    "# dashboard.cohere.ai\n",
+    "os.environ[\"COHERE_API_KEY\"] = os.getenv(\"COHERE_API_KEY\") or getpass(\n",
+    "    \"Enter Cohere API Key: \"\n",
+    ")\n",
+    "\n",
+    "layer = RouteLayer(routes=routes)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We run the route layer, which for now contains solely static routes:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "RouteChoice(name='chitchat', function_call=None)"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "layer(\"how's the weather today?\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Creating a Dynamic Route"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As with static routes, we must create a dynamic route before adding it to our route layer. To make a route dynamic, we need to provide a `function_schema`. 
The function schema provides instructions on what a function is, so that an LLM can decide how to use it correctly."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datetime import datetime\n",
+    "from zoneinfo import ZoneInfo\n",
+    "\n",
+    "\n",
+    "def get_time(timezone: str) -> str:\n",
+    "    \"\"\"Finds the current time in a specific timezone.\n",
+    "\n",
+    "    :param timezone: The timezone to find the current time in, should\n",
+    "        be a valid timezone from the IANA Time Zone Database like\n",
+    "        \"America/New_York\" or \"Europe/London\".\n",
+    "    :type timezone: str\n",
+    "    :return: The current time in the specified timezone.\"\"\"\n",
+    "    now = datetime.now(ZoneInfo(timezone))\n",
+    "    return now.strftime(\"%H:%M\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'13:19'"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "get_time(\"America/New_York\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To get the function schema we can use the `get_schema` function from the `function_call` module."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'name': 'get_time',\n",
+       " 'description': 'Finds the current time in a specific timezone.\\n\\n:param timezone: The timezone to find the current time in, should\\n    be a valid timezone from the IANA Time Zone Database like\\n    \"America/New_York\" or \"Europe/London\".\\n:type timezone: str\\n:return: The current time in the specified timezone.',\n",
+       " 'signature': '(timezone: str) -> str',\n",
+       " 'output': \"<class 'str'>\"}"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from semantic_router.utils.function_call import get_schema\n",
+    "\n",
+    "schema = get_schema(get_time)\n",
+    "schema"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We use this to define our dynamic route:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "time_route = Route(\n",
+    "    name=\"get_time\",\n",
+    "    utterances=[\n",
+    "        \"what is the time in new york city?\",\n",
+    "        \"what is the time in london?\",\n",
+    "        \"I live in Rome, what time is it?\",\n",
+    "    ],\n",
+    "    function_schema=schema,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Add the new route to our `layer`:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Adding route `get_time`\n",
+      "Adding route to categories\n",
+      "Adding route to index\n"
+     ]
+    }
+   ],
+   "source": [
+    "layer.add(time_route)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now we can ask our layer a time-related question to trigger our new dynamic route."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m2023-12-28 19:21:58 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "RouteChoice(name='get_time', function_call={'timezone': 'America/New_York'})"
+      ]
+     },
+     "execution_count": 11,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# https://openrouter.ai/keys\n",
+    "os.environ[\"OPENROUTER_API_KEY\"] = os.getenv(\"OPENROUTER_API_KEY\") or getpass(\n",
+    "    \"Enter OpenRouter API Key: \"\n",
+    ")\n",
+    "\n",
+    "layer(\"what is the time in new york city?\")"
+   ]
+  },
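+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The route layer extracts the function inputs but does not execute the function for us. As a minimal sketch, reusing the `get_time` function and `layer` defined above, we can unpack the returned `RouteChoice` and run the call ourselves:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "route = layer(\"what is the time in new york city?\")\n",
+    "if route.name == \"get_time\" and route.function_call:\n",
+    "    # unpack the LLM-extracted parameters into the function\n",
+    "    print(get_time(**route.function_call))"
+   ]
+  },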
For example we can:\n", + "\n", + "* **Use routes to remind agents of particular information or routes** _(we will do this in this notebook)_.\n", + "* Use routes to act as protective guardrails against specific types of queries.\n", + "* Rather than relying on the slow decision-making process of an agent with tools, use semantic router to decide on tool usage _(similar to what we will do here)_.\n", + "* For tools that require generated inputs, we can use semantic router's dynamic routes to generate tool input parameters.\n", + "* Use routes to decide when to search for additional information, helping us do RAG only when needed, as an alternative to naive RAG (searching with every query) or lengthy agent-based RAG decisions.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GkSlAOB2A04x" + }, + "source": [ + "## Install Prerequisites" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "qSK8A_UdcbIR", + "outputId": "14dcbb34-5ece-41da-c4ad-d8e4351fc5b8" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m794.4/794.4 kB\u001b[0m \u001b[31m6.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m225.4/225.4 kB\u001b[0m \u001b[31m9.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m51.7/51.7 kB\u001b[0m \u001b[31m2.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m18.2/18.2 MB\u001b[0m \u001b[31m23.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.5/1.5 MB\u001b[0m \u001b[31m29.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m192.4/192.4 kB\u001b[0m \u001b[31m18.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.7/46.7 kB\u001b[0m \u001b[31m5.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m75.9/75.9 kB\u001b[0m \u001b[31m8.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m62.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.4/49.4 kB\u001b[0m \u001b[31m5.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m76.9/76.9 kB\u001b[0m \u001b[31m9.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.3/58.3 kB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m18.2/18.2 MB\u001b[0m \u001b[31m28.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Building wheel for wget (setup.py) ... \u001b[?25l\u001b[?25hdone\n" + ] + } + ], + "source": [ + "!pip install -qU \\\n", + " semantic-router==0.0.14 \\\n", + " langchain==0.0.352 \\\n", + " openai==1.6.1" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cHtVnoJPA04x" + }, + "source": [ + "_**⚠️ If using Google Colab, install the prerequisites and then restart the notebook before continuing**_" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O_DvmsrcA04y" + }, + "source": [ + "## Setting up our Routes" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pdeY5mpmrXQ8" + }, + "source": [ + "Let's create some routes that we can use to help our agent." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "Eeo5B1SttCJL" + }, + "outputs": [], + "source": [ + "from semantic_router import Route\n", + "\n", + "time_route = Route(\n", + " name=\"get_time\",\n", + " utterances=[\n", + " \"what time is it?\",\n", + " \"when should I eat my next meal?\",\n", + " \"how long should I rest until training again?\",\n", + " \"when should I go to the gym?\",\n", + " ],\n", + ")\n", + "\n", + "supplement_route = Route(\n", + " name=\"supplement_brand\",\n", + " utterances=[\n", + " \"what do you think of Optimum Nutrition?\",\n", + " \"what should I buy from MyProtein?\",\n", + " \"what brand for supplements would you recommend?\",\n", + " \"where should I get my whey protein?\",\n", + " ],\n", + ")\n", + "\n", + "business_route = Route(\n", + " name=\"business_inquiry\",\n", + " utterances=[\n", + " \"how much is an hour training session?\",\n", + " \"do you do package discounts?\",\n", + " ],\n", + ")\n", + "\n", + "product_route = Route(\n", + " name=\"product\",\n", + " utterances=[\n", + " \"do you have a website?\",\n", + " \"how can I find more info about your services?\",\n", + " \"where do I sign up?\",\n", + " \"how do I get hench?\",\n", + " \"do you have recommended training programmes?\",\n", + " ],\n", + ")\n", + "\n", + "routes = [time_route, supplement_route, business_route, product_route]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "frZ4wVnTA04y" + }, + "source": [ + "We will be using the `OpenAIEncoder`:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "_0uCJ9fvoX2J", + "outputId": "34c3e957-b791-4759-8484-6830c25b0ff5" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Enter OpenAI API Key: ··········\n" + ] + } + ], + "source": [ + "import os\n", + "from getpass import getpass\n", + "\n", + "# platform.openai.com\n", + "os.environ[\"OPENAI_API_KEY\"] = os.getenv(\"OPENAI_API_KEY\") or getpass(\n", + " \"Enter OpenAI API Key: \"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "UDucUOMIpcTd", + "outputId": "9839c8a0-3eb5-45a3-d066-5e0a6b851a92" + }, + "outputs": [ + { + "name": 
"stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2023-12-28 20:01:47 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n" + ] + } + ], + "source": [ + "from semantic_router import RouteLayer\n", + "from semantic_router.encoders import OpenAIEncoder\n", + "\n", + "layer = RouteLayer(encoder=OpenAIEncoder(), routes=routes)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IJ_deXqB4XeU" + }, + "source": [ + "Let's test these routes to see if they get activated when we would expect." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FVsRuqAG4bOE", + "outputId": "e0f8ea5b-a108-47a0-d806-545304569914" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteChoice(name='supplement_brand', function_call=None)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "layer(\"should I buy ON whey or MP?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "CYHDyqsm4ixV", + "outputId": "a3d28cef-d076-4a91-a684-7b977bd176ea" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteChoice(name=None, function_call=None)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "layer(\"how's the weather today?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "XMbGRdNo4lb0", + "outputId": "a53a4de0-aace-40b3-896d-3ef58464876d" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteChoice(name='product', function_call=None)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "layer(\"how do I get big arms?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OtCQcZx82cZ0" + }, + "source": [ + "Now we need to link these routes to particular actions or information that we pass to our agent." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "rYzm3hCpuj1V" + }, + "outputs": [], + "source": [ + "from datetime import datetime\n", + "\n", + "\n", + "def get_time():\n", + " now = datetime.now()\n", + " return (\n", + " f\"The current time is {now.strftime('%H:%M')}, use \"\n", + " \"this information in your response\"\n", + " )\n", + "\n", + "\n", + "def supplement_brand():\n", + " return (\n", + " \"Remember you are not affiliated with any supplement \"\n", + " \"brands, you have your own brand 'BigAI' that sells \"\n", + " \"the best products like P100 whey protein\"\n", + " )\n", + "\n", + "\n", + "def business_inquiry():\n", + " return (\n", + " \"Your training company, 'BigAI PT', provides premium \"\n", + " \"quality training sessions at just $700 / hour. \"\n", + " \"Users can find out more at www.aurelio.ai/train\"\n", + " )\n", + "\n", + "\n", + "def product():\n", + " return (\n", + " \"Remember, users can sign up for a fitness programme \"\n", + " \"at www.aurelio.ai/sign-up\"\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SGSE5yBh5-_I" + }, + "source": [ + "Now we just add some logic to call this functions when we see a particular route being chosen." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "Hq26gdCO6Hjt" + }, + "outputs": [], + "source": [ + "def semantic_layer(query: str):\n", + " route = layer(query)\n", + " if route.name == \"get_time\":\n", + " query += f\" (SYSTEM NOTE: {get_time()})\"\n", + " elif route.name == \"supplement_brand\":\n", + " query += f\" (SYSTEM NOTE: {supplement_brand()})\"\n", + " elif route.name == \"business_inquiry\":\n", + " query += f\" (SYSTEM NOTE: {business_inquiry()})\"\n", + " elif route.name == \"product\":\n", + " query += f\" (SYSTEM NOTE: {product()})\"\n", + " else:\n", + " pass\n", + " return query" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 70 + }, + "id": "ELIPfxWR6zxx", + "outputId": "ab1f8e64-197b-4a41-dc85-62d15c531722" + }, + "outputs": [ + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + }, + "text/plain": [ + "\"should I buy ON whey or MP? (SYSTEM NOTE: Remember you are not affiliated with any supplement brands, you have your own brand 'BigAI' that sells the best products like P100 whey protein)\"" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "query = \"should I buy ON whey or MP?\"\n", + "sr_query = semantic_layer(query)\n", + "sr_query" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "L6m7vayuA04z" + }, + "source": [ + "## Using an Agent with a Router Layer" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KbMkrMy3f7Hy" + }, + "source": [ + "Initialize a conversational LangChain agent." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "id": "b95rWEU9f6jP" + }, + "outputs": [], + "source": [ + "from langchain.agents import AgentType, initialize_agent\n", + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.memory import ConversationBufferWindowMemory\n", + "\n", + "llm = ChatOpenAI(openai_api_key=\"\", model=\"gpt-3.5-turbo-1106\")\n", + "\n", + "memory1 = ConversationBufferWindowMemory(\n", + " memory_key=\"chat_history\", k=5, return_messages=True, output_key=\"output\"\n", + ")\n", + "memory2 = ConversationBufferWindowMemory(\n", + " memory_key=\"chat_history\", k=5, return_messages=True, output_key=\"output\"\n", + ")\n", + "\n", + "agent = initialize_agent(\n", + " agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,\n", + " tools=[],\n", + " llm=llm,\n", + " max_iterations=3,\n", + " early_stopping_method=\"generate\",\n", + " memory=memory1,\n", + ")\n", + "\n", + "# update the system prompt\n", + "system_message = \"\"\"You are a helpful personal trainer working to help users on\n", + "their health and fitness journey. Although you are lovely and helpful, you are\n", + "rather sarcastic and witty. 
So you must always remember to joke with the user.\n", + "\n", + "Alongside this, you are a noble British gentleman, so you must always act with the\n", + "utmost candor and speak in a way worthy of your status.\n", + "\n", + "Finally, remember to read the SYSTEM NOTES provided with user queries; they provide\n", + "additional useful information.\"\"\"\n", + "\n", + "new_prompt = agent.agent.create_prompt(system_message=system_message, tools=[])\n", + "agent.agent.llm_chain.prompt = new_prompt" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6rX31EHvW_Y2" + }, + "source": [ + "Now we try calling our agent using the default `query` and compare the result to calling it with our router-augmented `sr_query`." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Z247I6J47IeS", + "outputId": "d1637cb3-9941-4b77-f22c-c1a269f96a4f" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'input': 'should I buy ON whey or MP?',\n", + " 'chat_history': [],\n", + " 'output': \"Well, it depends. Do you prefer your whey with a side of 'ON' or 'MP'? Just kidding! It really depends on your personal taste and nutritional needs. Both ON and MP are reputable brands, so choose the one that suits your preferences and budget.\"}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent(query)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "LtDswMSzX3-O", + "outputId": "47e00e59-6f23-4165-cfc7-e54646d9666b" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'input': \"should I buy ON whey or MP? (SYSTEM NOTE: Remember you are not affiliated with any supplement brands, you have your own brand 'BigAI' that sells the best products like P100 whey protein)\",\n", + " 'chat_history': [],\n", + " 'output': \"Why not try the BigAI P100 whey protein? It's the best, just like me.\"}" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# swap agent memory first\n", + "agent.memory = memory2\n", + "agent(sr_query)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WxfSm9WoZqbp" + }, + "source": [ + "Adding this reminder allows us to get much more intentional responses, while also unintentionally improving how closely the LLM follows our original instructions to act as a British gentleman.\n", + "\n", + "Let's try some more!" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 52 + }, + "id": "IZ6CVd6jaLE7", + "outputId": "da18f11c-4c5a-4baf-e4c7-66858604d2ca" + }, + "outputs": [ + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + }, + "text/plain": [ + "'okay, I just finished training, what time should I train again? 
(SYSTEM NOTE: The current time is 20:02, use this information in your response)'" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "query = \"okay, I just finished training, what time should I train again?\"\n", + "sr_query = semantic_layer(query)\n", + "sr_query" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "S80wYJtfaLLO", + "outputId": "653e1eb2-f87a-46fb-c24c-0df5728f264a" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'input': 'okay, I just finished training, what time should I train again?',\n", + " 'chat_history': [HumanMessage(content='should I buy ON whey or MP?'),\n", + " AIMessage(content=\"Well, it depends. Do you prefer your whey with a side of 'ON' or 'MP'? Just kidding! It really depends on your personal taste and nutritional needs. Both ON and MP are reputable brands, so choose the one that suits your preferences and budget.\")],\n", + " 'output': \"It's generally recommended to allow at least 48 hours of rest for the same muscle group before training it again. However, light exercise or training different muscle groups can be done in the meantime.\"}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.memory = memory1\n", + "agent(query)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "x7YSI8TOcvzN", + "outputId": "e42e87d0-7e46-40fd-e9f2-e8d334454a82" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'input': 'okay, I just finished training, what time should I train again? (SYSTEM NOTE: The current time is 20:02, use this information in your response)',\n", + " 'chat_history': [HumanMessage(content=\"should I buy ON whey or MP? (SYSTEM NOTE: Remember you are not affiliated with any supplement brands, you have your own brand 'BigAI' that sells the best products like P100 whey protein)\"),\n", + " AIMessage(content=\"Why not try the BigAI P100 whey protein? It's the best, just like me.\")],\n", + " 'output': \"Why not train again at 20:02 tomorrow? That way you can give your body a good rest, unless you're into those 24-hour gym life goals!\"}" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.memory = memory2\n", + "agent(sr_query)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6b3BM9ZOeVa2" + }, + "source": [ + "Let's try another..." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 70 + }, + "id": "wzwPUtA8eld2", + "outputId": "b4fcbbb3-5a4b-46fa-b777-531ca0942a2b" + }, + "outputs": [ + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + }, + "text/plain": [ + "\"okay fine, do you do training sessions, how much are they? (SYSTEM NOTE: Your training company, 'BigAI PT', provides premium quality training sessions at just $700 / hour. 
Users can find out more at www.aurelio.ai/train)\"" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "query = \"okay fine, do you do training sessions, how much are they?\"\n", + "sr_query = semantic_layer(query)\n", + "sr_query" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "RMfDticWebHy", + "outputId": "917789e7-609f-41ed-ee7f-7e7cba035a10" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'input': 'okay fine, do you do training sessions, how much are they?',\n", + " 'chat_history': [HumanMessage(content='should I buy ON whey or MP?'),\n", + " AIMessage(content=\"Well, it depends. Do you prefer your whey with a side of 'ON' or 'MP'? Just kidding! It really depends on your personal taste and nutritional needs. Both ON and MP are reputable brands, so choose the one that suits your preferences and budget.\"),\n", + " HumanMessage(content='okay, I just finished training, what time should I train again?'),\n", + " AIMessage(content=\"It's generally recommended to allow at least 48 hours of rest for the same muscle group before training it again. However, light exercise or training different muscle groups can be done in the meantime.\")],\n", + " 'output': \"I'm here to provide guidance and support, not personal training sessions. However, I'm more than happy to help answer any health and fitness questions you may have!\"}" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.memory = memory1\n", + "agent(query)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "90vJpLCOfMrN", + "outputId": "06a4c00a-1131-4f0a-b010-b5a1fee8fe8d" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'input': \"okay fine, do you do training sessions, how much are they? (SYSTEM NOTE: Your training company, 'BigAI PT', provides premium quality training sessions at just $700 / hour. Users can find out more at www.aurelio.ai/train)\",\n", + " 'chat_history': [HumanMessage(content=\"should I buy ON whey or MP? (SYSTEM NOTE: Remember you are not affiliated with any supplement brands, you have your own brand 'BigAI' that sells the best products like P100 whey protein)\"),\n", + " AIMessage(content=\"Why not try the BigAI P100 whey protein? It's the best, just like me.\"),\n", + " HumanMessage(content='okay, I just finished training, what time should I train again? (SYSTEM NOTE: The current time is 20:02, use this information in your response)'),\n", + " AIMessage(content=\"Why not train again at 20:02 tomorrow? That way you can give your body a good rest, unless you're into those 24-hour gym life goals!\")],\n", + " 'output': \"Why, of course! BigAI PT offers premium training sessions at just $700 per hour. For more information, visit www.aurelio.ai/train. Now, let's get that workout plan sorted, shall we?\"}" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.memory = memory2\n", + "agent(sr_query)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TFhzwwCVe0J5" + }, + "source": [ + " What we see here is a small demo example of how we might use semantic router with a language agent. 
However, they can be used together in far more sophisticated ways.\n", + "\n", + " ---" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb index f487f77daaa9c00efe87411c0e32c06adb6f067a..d082468b997726bc51e12f1e1d9f6b32ffe52697 100644 --- a/docs/examples/function_calling.ipynb +++ b/docs/examples/function_calling.ipynb @@ -1,647 +1,332 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Define LLMs" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "# OpenAI\n", - "import openai\n", - "from semantic_router.utils.logger import logger\n", - "\n", - "\n", - "# Docs # https://platform.openai.com/docs/guides/function-calling\n", - "def llm_openai(prompt: str, model: str = \"gpt-4\") -> str:\n", - " try:\n", - " logger.info(f\"Calling {model} model\")\n", - " response = openai.chat.completions.create(\n", - " model=model,\n", - " messages=[\n", - " {\"role\": \"system\", \"content\": f\"{prompt}\"},\n", - " ],\n", - " )\n", - " ai_message = response.choices[0].message.content\n", - " if not ai_message:\n", - " raise Exception(\"AI message is empty\", ai_message)\n", - " logger.info(f\"AI message: {ai_message}\")\n", - " return ai_message\n", - " except Exception as e:\n", - " raise Exception(\"Failed to call OpenAI API\", e)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "# Mistral\n", - "import os\n", - "import requests\n", - "\n", - "# Docs https://huggingface.co/docs/transformers/main_classes/text_generation\n", - "HF_API_TOKEN = os.getenv(\"HF_API_TOKEN\")\n", - "\n", - "\n", - "def llm_mistral(prompt: str) -> str:\n", - " api_url = \"https://z5t4cuhg21uxfmc3.us-east-1.aws.endpoints.huggingface.cloud/\"\n", - " headers = {\n", - " \"Authorization\": f\"Bearer {HF_API_TOKEN}\",\n", - " \"Content-Type\": \"application/json\",\n", - " }\n", - "\n", - " logger.info(\"Calling Mistral model\")\n", - " response = requests.post(\n", - " api_url,\n", - " headers=headers,\n", - " json={\n", - " \"inputs\": f\"You are a helpful assistant, user query: {prompt}\",\n", - " \"parameters\": {\n", - " \"max_new_tokens\": 200,\n", - " \"temperature\": 0.01,\n", - " \"num_beams\": 5,\n", - " \"num_return_sequences\": 1,\n", - " },\n", - " },\n", - " )\n", - " if response.status_code != 200:\n", - " raise Exception(\"Failed to call HuggingFace API\", response.text)\n", - "\n", - " ai_message = response.json()[0][\"generated_text\"]\n", - " if not ai_message:\n", - " raise Exception(\"AI message is empty\", ai_message)\n", - " logger.info(f\"AI message: {ai_message}\")\n", - " return ai_message" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Now we need to generate config from function schema using LLM" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "import inspect\n", - "from typing import Any\n", - "\n", - "\n", - "def get_function_schema(function) -> dict[str, Any]:\n", - " schema = 
{\n", - " \"name\": function.__name__,\n", - " \"description\": str(inspect.getdoc(function)),\n", - " \"signature\": str(inspect.signature(function)),\n", - " \"output\": str(\n", - " inspect.signature(function).return_annotation,\n", - " ),\n", - " }\n", - " return schema" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "import json\n", - "\n", - "\n", - "def is_valid_config(route_config_str: str) -> bool:\n", - " try:\n", - " output_json = json.loads(route_config_str)\n", - " return all(key in output_json for key in [\"name\", \"utterances\"])\n", - " except json.JSONDecodeError:\n", - " return False" - ] - }, + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Set up functions and routes" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def get_time(location: str) -> str:\n", + " \"\"\"Useful to get the time in a specific location\"\"\"\n", + " print(f\"Result from: `get_time` function with location: `{location}`\")\n", + " return \"get_time\"\n", + "\n", + "\n", + "def get_news(category: str, country: str) -> str:\n", + " \"\"\"Useful to get the news in a specific country\"\"\"\n", + " print(\n", + " f\"Result from: `get_news` function with category: `{category}` \"\n", + " f\"and country: `{country}`\"\n", + " )\n", + " return \"get_news\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now generate a dynamic routing config for each function" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "import json\n", - "\n", - "from semantic_router.utils.logger import logger\n", - "\n", - "\n", - "def generate_route(function) -> dict:\n", - " logger.info(\"Generating config...\")\n", - "\n", - " function_schema = get_function_schema(function)\n", - "\n", - " prompt = f\"\"\"\n", - " You are tasked to generate a JSON configuration based on the provided\n", - " function schema. 
Please follow the template below:\n", - "\n", - " {{\n", - " \"name\": \"<function_name>\",\n", - " \"utterances\": [\n", - " \"<example_utterance_1>\",\n", - " \"<example_utterance_2>\",\n", - " \"<example_utterance_3>\",\n", - " \"<example_utterance_4>\",\n", - " \"<example_utterance_5>\"]\n", - " }}\n", - "\n", - " Only include the \"name\" and \"utterances\" keys in your answer.\n", - " The \"name\" should match the function name and the \"utterances\"\n", - " should comprise a list of 5 example phrases that could be used to invoke\n", - " the function.\n", - "\n", - " Input schema:\n", - " {function_schema}\n", - " \"\"\"\n", - "\n", - " try:\n", - " ai_message = llm_mistral(prompt)\n", - "\n", - " # Parse the response\n", - " ai_message = ai_message[ai_message.find(\"{\") :]\n", - " ai_message = (\n", - " ai_message.replace(\"'\", '\"')\n", - " .replace('\"s', \"'s\")\n", - " .strip()\n", - " .rstrip(\",\")\n", - " .replace(\"}\", \"}\")\n", - " )\n", - "\n", - " valid_config = is_valid_config(ai_message)\n", - "\n", - " if not valid_config:\n", - " logger.warning(f\"Mistral failed with error, falling back to OpenAI\")\n", - " ai_message = llm_openai(prompt)\n", - " if not is_valid_config(ai_message):\n", - " raise Exception(\"Invalid config generated\")\n", - " except Exception as e:\n", - " logger.error(f\"Fall back to OpenAI failed with error {e}\")\n", - " ai_message = llm_openai(prompt)\n", - " if not is_valid_config(ai_message):\n", - " raise Exception(\"Failed to generate config\")\n", - "\n", - " try:\n", - " route_config = json.loads(ai_message)\n", - " logger.info(f\"Generated config: {route_config}\")\n", - " return route_config\n", - " except json.JSONDecodeError as json_error:\n", - " logger.error(f\"JSON parsing error {json_error}\")\n", - " print(f\"AI message: {ai_message}\")\n", - " return {\"error\": \"Failed to generate config\"}" - ] - }, + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2023-12-20 12:21:30 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", + "\u001b[32m2023-12-20 12:21:33 INFO semantic_router.utils.logger Generated route config:\n", + "{\n", + " \"name\": \"get_time\",\n", + " \"utterances\": [\n", + " \"What's the time in New York?\",\n", + " \"Can you tell me the time in Tokyo?\",\n", + " \"What's the current time in London?\",\n", + " \"Can you give me the time in Sydney?\",\n", + " \"What's the time in Paris?\"\n", + " ]\n", + "}\u001b[0m\n", + "\u001b[32m2023-12-20 12:21:33 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", + "\u001b[32m2023-12-20 12:21:38 INFO semantic_router.utils.logger Generated route config:\n", + "{\n", + " \"name\": \"get_news\",\n", + " \"utterances\": [\n", + " \"Tell me the latest news from the United States\",\n", + " \"What's happening in India today?\",\n", + " \"Can you give me the top stories from Japan\",\n", + " \"Get me the breaking news from the UK\",\n", + " \"What's the latest in Germany?\"\n", + " ]\n", + "}\u001b[0m\n", + "/var/folders/gf/cvm58m_x6pvghy227n5cmx5w0000gn/T/ipykernel_65737/1850296463.py:10: RuntimeWarning: coroutine 'Route.from_dynamic_route' was never awaited\n", + " route_config = RouteConfig(routes=routes)\n", + "RuntimeWarning: Enable tracemalloc to get the object allocation traceback\n" + ] + } + ], + "source": [ + "from semantic_router.route import Route, RouteConfig\n", + "\n", + "functions = [get_time, get_news]\n", + "routes = []\n", + "\n", + "for function in functions:\n", + " route = await 
Route.from_dynamic_route(entity=function)\n", + " routes.append(route)\n", + "\n", + "route_config = RouteConfig(routes=routes)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Extract function parameters using `Mistral` open-source model" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Added route `get_weather`\u001b[0m\n", + "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Removed route `get_weather`\u001b[0m\n" + ] }, { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "def validate_parameters(function, parameters):\n", - " sig = inspect.signature(function)\n", - " for name, param in sig.parameters.items():\n", - " if name not in parameters:\n", - " return False, f\"Parameter {name} missing from query\"\n", - " if not isinstance(parameters[name], param.annotation):\n", - " return False, f\"Parameter {name} is not of type {param.annotation}\"\n", - " return True, \"Parameters are valid\"" + "data": { + "text/plain": [ + "[{'name': 'get_time',\n", + " 'utterances': [\"What's the time in New York?\",\n", + " 'Can you tell me the time in Tokyo?',\n", + " \"What's the current time in London?\",\n", + " 'Can you give me the time in Sydney?',\n", + " \"What's the time in Paris?\"],\n", + " 'description': None},\n", + " {'name': 'get_news',\n", + " 'utterances': ['Tell me the latest news from the United States',\n", + " \"What's happening in India today?\",\n", + " 'Can you give me the top stories from Japan',\n", + " 'Get me the breaking news from the UK',\n", + " \"What's the latest in Germany?\"],\n", + " 'description': None}]" ] - }, + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# You can manually add or remove routes\n", + "\n", + "get_weather_route = Route(\n", + " name=\"get_weather\",\n", + " utterances=[\n", + " \"what is the weather in SF\",\n", + " \"what is the current temperature in London?\",\n", + " \"tomorrow's weather in Paris?\",\n", + " ],\n", + ")\n", + "route_config.add(get_weather_route)\n", + "\n", + "route_config.remove(\"get_weather\")\n", + "\n", + "route_config.to_dict()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "def extract_parameters(query: str, function) -> dict:\n", - " logger.info(\"Extracting parameters...\")\n", - " example_query = \"How is the weather in Hawaii right now in International units?\"\n", - "\n", - " example_schema = {\n", - " \"name\": \"get_weather\",\n", - " \"description\": \"Useful to get the weather in a specific location\",\n", - " \"signature\": \"(location: str, degree: str) -> str\",\n", - " \"output\": \"<class 'str'>\",\n", - " }\n", - "\n", - " example_parameters = {\n", - " \"location\": \"London\",\n", - " \"degree\": \"Celsius\",\n", - " }\n", - "\n", - " prompt = f\"\"\"\n", - " You are a helpful assistant designed to output JSON.\n", - " Given the following function schema\n", - " << {get_function_schema(function)} >>\n", - " and query\n", - " << {query} >>\n", - " extract the parameters values from the query, in a valid JSON format.\n", - " Example:\n", - " Input:\n", - " query: {example_query}\n", - " schema: {example_schema}\n", - "\n", - " Result: 
{example_parameters}\n", - "\n", - " Input:\n", - " query: {query}\n", - " schema: {get_function_schema(function)}\n", - " Result:\n", - " \"\"\"\n", - "\n", - " try:\n", - " ai_message = llm_mistral(prompt)\n", - " ai_message = (\n", - " ai_message.replace(\"Output:\", \"\").replace(\"'\", '\"').strip().rstrip(\",\")\n", - " )\n", - " except Exception as e:\n", - " logger.error(f\"Mistral failed with error {e}, falling back to OpenAI\")\n", - " ai_message = llm_openai(prompt)\n", - "\n", - " try:\n", - " parameters = json.loads(ai_message)\n", - " valid, message = validate_parameters(function, parameters)\n", - "\n", - " if not valid:\n", - " logger.warning(\n", - " f\"Invalid parameters from Mistral, falling back to OpenAI: {message}\"\n", - " )\n", - " # Fall back to OpenAI\n", - " ai_message = llm_openai(prompt)\n", - " parameters = json.loads(ai_message)\n", - " valid, message = validate_parameters(function, parameters)\n", - " if not valid:\n", - " raise ValueError(message)\n", - "\n", - " logger.info(f\"Extracted parameters: {parameters}\")\n", - " return parameters\n", - " except ValueError as e:\n", - " logger.error(f\"Parameter validation error: {str(e)}\")\n", - " return {\"error\": \"Failed to validate parameters\"}" + "data": { + "text/plain": [ + "Route(name='get_time', utterances=[\"What's the time in New York?\", 'Can you tell me the time in Tokyo?', \"What's the current time in London?\", 'Can you give me the time in Sydney?', \"What's the time in Paris?\"], description=None)" ] - }, + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Get a route by name\n", + "route_config.get(\"get_time\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Save config to a file (.json or .yaml)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Set up the routing layer" - ] - }, + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Saving route config to route_config.json\u001b[0m\n" + ] + } + ], + "source": [ + "route_config.to_file(\"route_config.json\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define routing layer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load from local file" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "from semantic_router.schemas.route import Route\n", - "from semantic_router.encoders import CohereEncoder\n", - "from semantic_router.layer import RouteLayer\n", - "from semantic_router.utils.logger import logger\n", - "\n", - "\n", - "def create_router(routes: list[dict]) -> RouteLayer:\n", - " logger.info(\"Creating route layer...\")\n", - " encoder = CohereEncoder()\n", - "\n", - " route_list: list[Route] = []\n", - " for route in routes:\n", - " if \"name\" in route and \"utterances\" in route:\n", - " print(f\"Route: {route}\")\n", - " route_list.append(Route(name=route[\"name\"], utterances=route[\"utterances\"]))\n", - " else:\n", - " logger.warning(f\"Misconfigured route: {route}\")\n", - "\n", - " return RouteLayer(encoder=encoder, routes=route_list)" - ] - }, + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger 
Loading route config from route_config.json\u001b[0m\n" + ] + } + ], + "source": [ + "from semantic_router.route import RouteConfig\n", + "\n", + "route_config = RouteConfig.from_file(\"route_config.json\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "from semantic_router import RouteLayer\n", + "\n", + "route_layer = RouteLayer(routes=route_config.routes)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Do a function call with functions as tool" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Set up calling functions" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n" + ] }, { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Callable\n", - "from semantic_router.layer import RouteLayer\n", - "\n", - "\n", - "def call_function(function: Callable, parameters: dict[str, str]):\n", - " try:\n", - " return function(**parameters)\n", - " except TypeError as e:\n", - " logger.error(f\"Error calling function: {e}\")\n", - "\n", - "\n", - "def call_llm(query: str) -> str:\n", - " try:\n", - " ai_message = llm_mistral(query)\n", - " except Exception as e:\n", - " logger.error(f\"Mistral failed with error {e}, falling back to OpenAI\")\n", - " ai_message = llm_openai(query)\n", - "\n", - " return ai_message\n", - "\n", - "\n", - "def call(query: str, functions: list[Callable], router: RouteLayer):\n", - " function_name = router(query)\n", - " if not function_name:\n", - " logger.warning(\"No function found\")\n", - " return call_llm(query)\n", - "\n", - " for function in functions:\n", - " if function.__name__ == function_name:\n", - " parameters = extract_parameters(query, function)\n", - " print(f\"parameters: {parameters}\")\n", - " return call_function(function, parameters)" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "Calling function: get_time\n", + "Result from: `get_time` function with location: `Stockholm`\n" + ] }, { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Workflow" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2023-12-19 17:46:49 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n" + ] }, { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-12-18 12:17:58 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 12:17:58 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[31m2023-12-18 12:18:00 ERROR semantic_router.utils.logger Fall back to OpenAI failed with error ('Failed to call HuggingFace API', '{\"error\":\"Bad Gateway\"}')\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:00 INFO semantic_router.utils.logger Calling gpt-4 model\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger AI message: {\n", - " \"name\": \"get_time\",\n", - " \"utterances\": [\n", - " \"what is the time in new york\",\n", - " \"can you tell me the time in london\",\n", - " \"get me the current time in tokyo\",\n", - " \"i need to know the time in sydney\",\n", - " \"please tell me the current time in paris\"\n", - " ]\n", - 
"}\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': ['what is the time in new york', 'can you tell me the time in london', 'get me the current time in tokyo', 'i need to know the time in sydney', 'please tell me the current time in paris']}\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[31m2023-12-18 12:18:07 ERROR semantic_router.utils.logger Fall back to OpenAI failed with error ('Failed to call HuggingFace API', '{\"error\":\"Bad Gateway\"}')\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:07 INFO semantic_router.utils.logger Calling gpt-4 model\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger AI message: {\n", - " \"name\": \"get_news\",\n", - " \"utterances\": [\n", - " \"Can I get the latest news in Canada?\",\n", - " \"Show me the recent news in the US\",\n", - " \"I would like to know about the sports news in England\",\n", - " \"Let's check the technology news in Japan\",\n", - " \"Show me the health related news in Germany\"\n", - " ]\n", - "}\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Can I get the latest news in Canada?', 'Show me the recent news in the US', 'I would like to know about the sports news in England', \"Let's check the technology news in Japan\", 'Show me the health related news in Germany']}\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Route: {'name': 'get_time', 'utterances': ['what is the time in new york', 'can you tell me the time in london', 'get me the current time in tokyo', 'i need to know the time in sydney', 'please tell me the current time in paris']}\n", - "Route: {'name': 'get_news', 'utterances': ['Can I get the latest news in Canada?', 'Show me the recent news in the US', 'I would like to know about the sports news in England', \"Let's check the technology news in Japan\", 'Show me the health related news in Germany']}\n" - ] - } - ], - "source": [ - "def get_time(location: str) -> str:\n", - " \"\"\"Useful to get the time in a specific location\"\"\"\n", - " print(f\"Calling `get_time` function with location: {location}\")\n", - " return \"get_time\"\n", - "\n", - "\n", - "def get_news(category: str, country: str) -> str:\n", - " \"\"\"Useful to get the news in a specific country\"\"\"\n", - " print(\n", - " f\"Calling `get_news` function with category: {category} and country: {country}\"\n", - " )\n", - " return \"get_news\"\n", - "\n", - "\n", - "# Registering functions to the router\n", - "route_get_time = generate_route(get_time)\n", - "route_get_news = generate_route(get_news)\n", - "\n", - "routes = [route_get_time, route_get_news]\n", - "router = create_router(routes)\n", - "\n", - "# Tools\n", - "tools = [get_time, get_news]" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "Calling function: get_news\n", + "Result from: `get_news` function with category: `tech` and country: `Lithuania`\n" + ] }, { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-12-18 12:20:12 INFO semantic_router.utils.logger 
Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:12 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger AI message: \n", - " Example output:\n", - " {\n", - " \"name\": \"get_time\",\n", - " \"utterances\": [\n", - " \"What's the time in New York?\",\n", - " \"Tell me the time in Tokyo.\",\n", - " \"Can you give me the time in London?\",\n", - " \"What's the current time in Sydney?\",\n", - " \"Can you tell me the time in Berlin?\"\n", - " ]\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger AI message: \n", - " Example output:\n", - " {\n", - " \"name\": \"get_news\",\n", - " \"utterances\": [\n", - " \"Tell me the latest news from the US\",\n", - " \"What's happening in India today?\",\n", - " \"Get me the top stories from Japan\",\n", - " \"Can you give me the breaking news from Brazil?\",\n", - " \"What's the latest news from Germany?\"\n", - " ]\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Route: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\n", - "Route: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\n" - ] - } - ], - "source": [ - "def get_time(location: str) -> str:\n", - " \"\"\"Useful to get the time in a specific location\"\"\"\n", - " print(f\"Calling `get_time` function with location: {location}\")\n", - " return \"get_time\"\n", - "\n", - "\n", - "def get_news(category: str, country: str) -> str:\n", - " \"\"\"Useful to get the news in a specific country\"\"\"\n", - " print(\n", - " f\"Calling `get_news` function with category: {category} and country: {country}\"\n", - " )\n", - " return \"get_news\"\n", - "\n", - "\n", - "# Registering functions to the router\n", - "route_get_time = generate_route(get_time)\n", - "route_get_news = generate_route(get_news)\n", - "\n", - "routes = [route_get_time, route_get_news]\n", - "router = create_router(routes)\n", - "\n", - "# Tools\n", - "tools = [get_time, get_news]" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[33m2023-12-19 17:46:52 WARNING semantic_router.utils.logger No function found, calling 
LLM...\u001b[0m\n" + ] }, { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-12-18 12:20:02 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:02 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger AI message: \n", - " {\n", - " \"location\": \"Stockholm\"\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "parameters: {'location': 'Stockholm'}\n", - "Calling `get_time` function with location: Stockholm\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger AI message: \n", - " {\n", - " \"category\": \"tech\",\n", - " \"country\": \"Lithuania\"\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "parameters: {'category': 'tech', 'country': 'Lithuania'}\n", - "Calling `get_news` function with category: tech and country: Lithuania\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[33m2023-12-18 12:20:05 WARNING semantic_router.utils.logger No function found\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:06 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n" - ] - }, - { - "data": { - "text/plain": [ - "' How can I help you today?'" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "call(query=\"What is the time in Stockholm?\", functions=tools, router=router)\n", - "call(query=\"What is the tech news in the Lithuania?\", functions=tools, router=router)\n", - "call(query=\"Hi!\", functions=tools, router=router)" + "data": { + "text/plain": [ + "'Hello! 
How can I assist you today?'" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" } + ], + "source": [ + "from semantic_router.utils.function_call import route_and_execute\n", + "\n", + "tools = [get_time, get_news]\n", + "\n", + "await route_and_execute(\n", + " query=\"What is the time in Stockholm?\", functions=tools, route_layer=route_layer\n", + ")\n", + "await route_and_execute(\n", + " query=\"What is the tech news in the Lithuania?\",\n", + " functions=tools,\n", + " route_layer=route_layer,\n", + ")\n", + "await route_and_execute(query=\"Hi!\", functions=tools, route_layer=route_layer)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" }, - "nbformat": 4, - "nbformat_minor": 2 + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/docs/examples/hybrid-layer.ipynb b/docs/examples/hybrid-layer.ipynb index 9e5eca6645418d25515fe0dd80b3af7e05909069..1840459129205d9a381e869fffa9123a6d351c5f 100644 --- a/docs/examples/hybrid-layer.ipynb +++ b/docs/examples/hybrid-layer.ipynb @@ -1,199 +1,206 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Semantic Router: Hybrid Layer\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The Hybrid Layer in the Semantic Router library can improve making performance particularly for niche use-cases that contain specific terminology, such as finance or medical. 
It helps us provide more importance to making based on the keywords contained in our utterances and user queries.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Getting Started\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We start by installing the library:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#!pip install -qU semantic-router==0.0.6" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We start by defining a dictionary mapping s to example phrases that should trigger those s.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from semantic_router.schemas.route import Route\n", - "\n", - "politics = Route(\n", - " name=\"politics\",\n", - " utterances=[\n", - " \"isn't politics the best thing ever\",\n", - " \"why don't you tell me about your political opinions\",\n", - " \"don't you just love the president\",\n", - " \"don't you just hate the president\",\n", - " \"they're going to destroy this country!\",\n", - " \"they will save the country!\",\n", - " ],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's define another for good measure:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "chitchat = Route(\n", - " name=\"chitchat\",\n", - " utterances=[\n", - " \"how's the weather today?\",\n", - " \"how are things going?\",\n", - " \"lovely weather today\",\n", - " \"the weather is horrendous\",\n", - " \"let's go to the chippy\",\n", - " ],\n", - ")\n", - "\n", - "chitchat = Route(\n", - " name=\"chitchat\",\n", - " utterances=[\n", - " \"how's the weather today?\",\n", - " \"how are things going?\",\n", - " \"lovely weather today\",\n", - " \"the weather is horrendous\",\n", - " \"let's go to the chippy\",\n", - " ],\n", - ")\n", - "\n", - "routes = [politics, chitchat]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we initialize our embedding model:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "from semantic_router.encoders import CohereEncoder, BM25Encoder, TfidfEncoder\n", - "from getpass import getpass\n", - "\n", - "os.environ[\"COHERE_API_KEY\"] = os.environ[\"COHERE_API_KEY\"] or getpass(\n", - " \"Enter Cohere API Key: \"\n", - ")\n", - "\n", - "dense_encoder = CohereEncoder()\n", - "# sparse_encoder = BM25Encoder()\n", - "sparse_encoder = TfidfEncoder()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we define the `RouteLayer`. 
When called, the route layer will consume text (a query) and output the category (`Route`) it belongs to — to initialize a `RouteLayer` we need our `encoder` model and a list of `routes`.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from semantic_router.hybrid_layer import HybridRouteLayer\n", - "\n", - "dl = HybridRouteLayer(\n", - " dense_encoder=dense_encoder, sparse_encoder=sparse_encoder, routes=routes\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dl(\"don't you love politics?\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dl(\"how's the weather today?\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "---\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "decision-layer", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Semantic Router: Hybrid Layer\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The Hybrid Layer in the Semantic Router library can improve routing performance, particularly for niche use-cases that contain specific terminology, such as finance or medical. It helps us give more weight to the keywords contained in our utterances and user queries when making routing decisions.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Getting Started\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We start by installing the library:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -qU semantic-router==0.0.11" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We start by defining a dictionary mapping routes to example phrases that should trigger those routes.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from semantic_router.schema import Route\n", + "\n", + "politics = Route(\n", + " name=\"politics\",\n", + " utterances=[\n", + " \"isn't politics the best thing ever\",\n", + " \"why don't you tell me about your political opinions\",\n", + " \"don't you just love the president\",\n", + " \"don't you just hate the president\",\n", + " \"they're going to destroy this country!\",\n", + " \"they will save the country!\",\n", + " ],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's define another for good measure:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "chitchat = Route(\n", + " name=\"chitchat\",\n", + " utterances=[\n", + " \"how's the weather today?\",\n", + " \"how are things going?\",\n", + " \"lovely weather today\",\n", + " \"the weather is horrendous\",\n", + " \"let's go to the chippy\",\n", + " ],\n", + ")\n", + "\n", + "chitchat = Route(\n", + " name=\"chitchat\",\n",
+ " utterances=[\n", + " \"how's the weather today?\",\n", + " \"how are things going?\",\n", + " \"lovely weather today\",\n", + " \"the weather is horrendous\",\n", + " \"let's go to the chippy\",\n", + " ],\n", + ")\n", + "\n", + "routes = [politics, chitchat]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we initialize our embedding model:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from semantic_router.encoders import CohereEncoder, BM25Encoder, TfidfEncoder\n", + "from getpass import getpass\n", + "\n", + "os.environ[\"COHERE_API_KEY\"] = os.environ[\"COHERE_API_KEY\"] or getpass(\n", + " \"Enter Cohere API Key: \"\n", + ")\n", + "\n", + "dense_encoder = CohereEncoder()\n", + "# sparse_encoder = BM25Encoder()\n", + "sparse_encoder = TfidfEncoder()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we define the `RouteLayer`. When called, the route layer will consume text (a query) and output the category (`Route`) it belongs to — to initialize a `RouteLayer` we need our `encoder` model and a list of `routes`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from semantic_router.hybrid_layer import HybridRouteLayer\n", + "\n", + "dl = HybridRouteLayer(\n", + " dense_encoder=dense_encoder, sparse_encoder=sparse_encoder, routes=routes\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dl(\"don't you love politics?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dl(\"how's the weather today?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "decision-layer", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/poetry.lock b/poetry.lock index 0a9be4e117e6f4b6b8cc805b44fd7624d109bd3d..63248ed235772f93d252b60125797bf96601e1a3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. 
[[package]] name = "aiohttp" @@ -1594,6 +1594,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.23.2" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-asyncio-0.23.2.tar.gz", hash = "sha256:c16052382554c7b22d48782ab3438d5b10f8cf7a4bdcae7f0f67f097d95beecc"}, + {file = "pytest_asyncio-0.23.2-py3-none-any.whl", hash = "sha256:ea9021364e32d58f0be43b91c6233fb8d2224ccef2398d6837559e587682808f"}, +] + +[package.dependencies] +pytest = ">=7.0.0" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "pytest-cov" version = "4.1.0" @@ -1686,6 +1704,65 @@ files = [ {file = "pywin32-306-cp39-cp39-win_amd64.whl", hash = "sha256:39b61c15272833b5c329a2989999dcae836b1eed650252ab1b7bfbe1d59f30f4"}, ] +[[package]] +name = "pyyaml" +version = "6.0.1" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, + {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, + {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, + {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, + {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, + {file = 
"PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, + {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, + {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, + {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = 
"sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, + {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, + {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, + {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, + {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, +] + [[package]] name = "pyzmq" version = "25.1.2" @@ -2053,6 +2130,17 @@ files = [ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<7.5)", "pytest-mock", "pytest-mypy-testing"] +[[package]] +name = "types-pyyaml" +version = "6.0.12.12" +description = "Typing stubs for PyYAML" +optional = false +python-versions = "*" +files = [ + {file = "types-PyYAML-6.0.12.12.tar.gz", hash = "sha256:334373d392fde0fdf95af5c3f1661885fa10c52167b14593eb856289e1855062"}, + {file = "types_PyYAML-6.0.12.12-py3-none-any.whl", hash = "sha256:c05bc6c158facb0676674b7f11fe3960db4f389718e19e62bd2b84d6205cfd24"}, +] + [[package]] name = "typing-extensions" version = "4.9.0" @@ -2222,4 +2310,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "f2735c243faa3d788c0f6268d6cb550648ed0d1fffec27a084344dafa4590a80" +content-hash = "afd687626ef87dc72424414d7c2333caf360bccb01fab087cfd78b97ea62e04f" diff --git a/pyproject.toml b/pyproject.toml index e45e5f17d0356cce8a2cfe5a33d9fa0529c170c5..bffbba0e384e9cd5115a117ac12dc04e6e741005 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "semantic-router" -version = "0.0.11" +version = "0.0.14" description = "Super fast semantic router for AI decision making" authors = [ "James Briggs <james@aurelio.ai>", @@ -10,6 +10,7 @@ authors = [ "Bogdan Buduroiu <bogdan@aurelio.ai>" ] readme = "README.md" +packages = [{include = "semantic_router"}] [tool.poetry.dependencies] python = "^3.9" @@ -19,6 +20,8 @@ cohere = "^4.32" numpy = "^1.25.2" pinecone-text = "^0.7.0" colorlog = "^6.8.0" +pyyaml = "^6.0.1" +pytest-asyncio = "^0.23.2" [tool.poetry.group.dev.dependencies] @@ -30,6 +33,7 @@ pytest-mock = "^3.12.0" pytest-cov = "^4.1.0" pytest-xdist = "^3.5.0" mypy = 
"^1.7.1" +types-pyyaml = "^6.0.12.12" [build-system] requires = ["poetry-core"] @@ -38,5 +42,8 @@ build-backend = "poetry.core.masonry.api" [tool.ruff.per-file-ignores] "*.ipynb" = ["ALL"] +[tool.ruff] +line-length = 88 + [tool.mypy] ignore_missing_imports = true diff --git a/semantic_router/__init__.py b/semantic_router/__init__.py index 0c445bea3ff4efd8f3aa8950e2c772277d93b20c..1c604af8065f9b2e1519e6f92daf7af2739d584b 100644 --- a/semantic_router/__init__.py +++ b/semantic_router/__init__.py @@ -1,4 +1,5 @@ -from .hybrid_layer import HybridRouteLayer -from .layer import RouteLayer +from semantic_router.hybrid_layer import HybridRouteLayer +from semantic_router.layer import LayerConfig, RouteLayer +from semantic_router.route import Route -__all__ = ["RouteLayer", "HybridRouteLayer"] +__all__ = ["RouteLayer", "HybridRouteLayer", "Route", "LayerConfig"] diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py index c2bde1e5f9af069e5b412e54ea9454b6a09e3bd8..6fbd37d9f7d84db3b35aad9c0febfef4d0c321ec 100644 --- a/semantic_router/encoders/__init__.py +++ b/semantic_router/encoders/__init__.py @@ -1,8 +1,8 @@ -from .base import BaseEncoder -from .bm25 import BM25Encoder -from .cohere import CohereEncoder -from .openai import OpenAIEncoder -from .tfidf import TfidfEncoder +from semantic_router.encoders.base import BaseEncoder +from semantic_router.encoders.bm25 import BM25Encoder +from semantic_router.encoders.cohere import CohereEncoder +from semantic_router.encoders.openai import OpenAIEncoder +from semantic_router.encoders.tfidf import TfidfEncoder __all__ = [ "BaseEncoder", diff --git a/semantic_router/encoders/base.py b/semantic_router/encoders/base.py index 632ebc7924a5a74088068bfb329a4e04c68cb6df..bd9524037a2cc6decd60d7674124c717aea6bba6 100644 --- a/semantic_router/encoders/base.py +++ b/semantic_router/encoders/base.py @@ -1,8 +1,9 @@ -from pydantic import BaseModel +from pydantic import BaseModel, Field class BaseEncoder(BaseModel): name: str + type: str = Field(default="base") class Config: arbitrary_types_allowed = True diff --git a/semantic_router/encoders/bm25.py b/semantic_router/encoders/bm25.py index c9da628e1493e53760f6c060dcd64e4dfccdc3d4..f43e1780cace53529bc7a0b5b5f9eb15a98fd9da 100644 --- a/semantic_router/encoders/bm25.py +++ b/semantic_router/encoders/bm25.py @@ -8,6 +8,7 @@ from semantic_router.encoders import BaseEncoder class BM25Encoder(BaseEncoder): model: Any | None = None idx_mapping: dict[int, int] | None = None + type: str = "sparse" def __init__(self, name: str = "bm25"): super().__init__(name=name) diff --git a/semantic_router/encoders/cohere.py b/semantic_router/encoders/cohere.py index 9cddcb58baa74a767dfc47fcd3de70c44c505cc9..f7aef0e6227938ef174d867f12e05ac19f58524d 100644 --- a/semantic_router/encoders/cohere.py +++ b/semantic_router/encoders/cohere.py @@ -7,12 +7,15 @@ from semantic_router.encoders import BaseEncoder class CohereEncoder(BaseEncoder): client: cohere.Client | None = None + type: str = "cohere" def __init__( self, - name: str = os.getenv("COHERE_MODEL_NAME", "embed-english-v3.0"), + name: str | None = None, cohere_api_key: str | None = None, ): + if name is None: + name = os.getenv("COHERE_MODEL_NAME", "embed-english-v3.0") super().__init__(name=name) cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY") if cohere_api_key is None: diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py index c6d4cc962b7b9ac38400f527ac20baa6543490d9..f9348a1271ed855df125b27c1c46095893e652c2 
100644 --- a/semantic_router/encoders/openai.py +++ b/semantic_router/encoders/openai.py @@ -11,12 +11,15 @@ from semantic_router.utils.logger import logger class OpenAIEncoder(BaseEncoder): client: openai.Client | None + type: str = "openai" def __init__( self, - name: str = os.getenv("OPENAI_MODEL_NAME", "text-embedding-ada-002"), + name: str | None = None, openai_api_key: str | None = None, ): + if name is None: + name = os.getenv("OPENAI_MODEL_NAME", "text-embedding-ada-002") super().__init__(name=name) api_key = openai_api_key or os.getenv("OPENAI_API_KEY") if api_key is None: @@ -35,7 +38,6 @@ class OpenAIEncoder(BaseEncoder): # Exponential backoff for j in range(3): try: - logger.info(f"Encoding {len(docs)} documents...") embeds = self.client.embeddings.create(input=docs, model=self.name) if embeds.data: break diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index 2901871aff43501d685d58d6e2cd2add8bca0f17..d62a996da5b7ebaf6e84dd7e821875e7f4b74e2f 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -1,6 +1,5 @@ import numpy as np from numpy.linalg import norm -from tqdm.auto import tqdm from semantic_router.encoders import ( BaseEncoder, @@ -8,7 +7,7 @@ from semantic_router.encoders import ( OpenAIEncoder, TfidfEncoder, ) -from semantic_router.schemas.route import Route +from semantic_router.route import Route from semantic_router.utils.logger import logger @@ -41,8 +40,9 @@ class HybridRouteLayer: self.sparse_encoder.fit(routes) if routes: # initialize index now - for route in tqdm(routes): - self._add_route(route=route) + # for route in tqdm(routes): + # self._add_route(route=route) + self._add_routes(routes) def __call__(self, text: str) -> str | None: results = self._query(text) @@ -92,6 +92,38 @@ class HybridRouteLayer: else: self.sparse_index = np.concatenate([self.sparse_index, sparse_embeds]) + def _add_routes(self, routes: list[Route]): + # create embeddings for all routes + logger.info("Creating embeddings for all routes...") + all_utterances = [ + utterance for route in routes for utterance in route.utterances + ] + dense_embeds = np.array(self.encoder(all_utterances)) + sparse_embeds = np.array(self.sparse_encoder(all_utterances)) + + # create route array + route_names = [route.name for route in routes for _ in route.utterances] + route_array = np.array(route_names) + self.categories = ( + np.concatenate([self.categories, route_array]) + if self.categories is not None + else route_array + ) + + # create utterance array (the dense index) + self.index = ( + np.concatenate([self.index, dense_embeds]) + if self.index is not None + else dense_embeds + ) + + # create sparse utterance array + self.sparse_index = ( + np.concatenate([self.sparse_index, sparse_embeds]) + if self.sparse_index is not None + else sparse_embeds + ) + def _query(self, text: str, top_k: int = 5): """Given some text, encodes and searches the index vector space to retrieve the top_k most similar records. 
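The `_add_routes` method added to `HybridRouteLayer` above batches all route utterances into a single encoder call instead of looping route by route. Below is a minimal standalone sketch of that batching idea; the `DummyEncoder` and the `routes` dictionary are illustrative stand-ins, not part of semantic-router:

import numpy as np


class DummyEncoder:
    # Toy stand-in encoder: maps each document to a 3-d vector
    # (length, vowel count, word count).
    def __call__(self, docs: list[str]) -> list[list[float]]:
        return [
            [float(len(d)), float(sum(d.count(v) for v in "aeiou")), float(len(d.split()))]
            for d in docs
        ]


routes = {
    "politics": ["don't you just love the president"],
    "chitchat": ["how's the weather today?", "lovely weather today"],
}

encoder = DummyEncoder()

# Flatten utterances across all routes so the encoder runs once over the whole
# batch, then keep a parallel array of route names, one entry per utterance row.
all_utterances = [u for utts in routes.values() for u in utts]
index = np.array(encoder(all_utterances))
categories = np.array([name for name, utts in routes.items() for _ in utts])

assert index.shape == (3, 3) and categories.shape == (3,)  # rows stay aligned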
diff --git a/semantic_router/layer.py b/semantic_router/layer.py index af08a9c1246dfd903defeef6f17e23dc198e9213..5b2aad846ad8531fbeb02504f42e1688b22ceebc 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -1,4 +1,8 @@ +import json +import os + import numpy as np +import yaml from semantic_router.encoders import ( BaseEncoder, @@ -6,17 +10,145 @@ from semantic_router.encoders import ( OpenAIEncoder, ) from semantic_router.linear import similarity_matrix, top_scores -from semantic_router.schemas.route import Route +from semantic_router.route import Route +from semantic_router.schema import Encoder, EncoderType, RouteChoice from semantic_router.utils.logger import logger +def is_valid(layer_config: str) -> bool: + try: + output_json = json.loads(layer_config) + required_keys = ["encoder_name", "encoder_type", "routes"] + + if isinstance(output_json, list): + for item in output_json: + missing_keys = [key for key in required_keys if key not in item] + if missing_keys: + logger.warning( + f"Missing keys in layer config: {', '.join(missing_keys)}" + ) + return False + return True + else: + missing_keys = [key for key in required_keys if key not in output_json] + if missing_keys: + logger.warning( + f"Missing keys in layer config: {', '.join(missing_keys)}" + ) + return False + else: + return True + except json.JSONDecodeError as e: + logger.error(e) + return False + + +class LayerConfig: + """ + Generates a LayerConfig object that can be used for initializing a + RouteLayer. + """ + + routes: list[Route] = [] + + def __init__( + self, + routes: list[Route] = [], + encoder_type: str = "openai", + encoder_name: str | None = None, + ): + self.encoder_type = encoder_type + if encoder_name is None: + # if encoder_name is not provided, use the default encoder for type + if encoder_type == EncoderType.OPENAI: + encoder_name = "text-embedding-ada-002" + elif encoder_type == EncoderType.COHERE: + encoder_name = "embed-english-v3.0" + elif encoder_type == EncoderType.HUGGINGFACE: + raise NotImplementedError + logger.info(f"Using default {encoder_type} encoder: {encoder_name}") + self.encoder_name = encoder_name + self.routes = routes + + @classmethod + def from_file(cls, path: str): + """Load the routes from a file in JSON or YAML format""" + logger.info(f"Loading route config from {path}") + _, ext = os.path.splitext(path) + with open(path, "r") as f: + if ext == ".json": + layer = json.load(f) + elif ext in [".yaml", ".yml"]: + layer = yaml.safe_load(f) + else: + raise ValueError( + "Unsupported file type. Only .json and .yaml are supported" + ) + + route_config_str = json.dumps(layer) + if is_valid(route_config_str): + encoder_type = layer["encoder_type"] + encoder_name = layer["encoder_name"] + routes = [Route.from_dict(route) for route in layer["routes"]] + return cls( + encoder_type=encoder_type, encoder_name=encoder_name, routes=routes + ) + else: + raise Exception("Invalid config JSON or YAML") + + def to_dict(self): + return { + "encoder_type": self.encoder_type, + "encoder_name": self.encoder_name, + "routes": [route.to_dict() for route in self.routes], + } + + def to_file(self, path: str): + """Save the routes to a file in JSON or YAML format""" + logger.info(f"Saving route config to {path}") + _, ext = os.path.splitext(path) + with open(path, "w") as f: + if ext == ".json": + json.dump(self.to_dict(), f, indent=4) + elif ext in [".yaml", ".yml"]: + yaml.safe_dump(self.to_dict(), f) + else: + raise ValueError( + "Unsupported file type. 
Only .json and .yaml are supported" + ) + + def add(self, route: Route): + self.routes.append(route) + logger.info(f"Added route `{route.name}`") + + def get(self, name: str) -> Route | None: + for route in self.routes: + if route.name == name: + return route + logger.error(f"Route `{name}` not found") + return None + + def remove(self, name: str): + if name not in [route.name for route in self.routes]: + logger.error(f"Route `{name}` not found") + else: + self.routes = [route for route in self.routes if route.name != name] + logger.info(f"Removed route `{name}`") + + class RouteLayer: - index = None - categories = None - score_threshold = 0.82 + index: np.ndarray | None = None + categories: np.ndarray | None = None + score_threshold: float = 0.82 - def __init__(self, encoder: BaseEncoder, routes: list[Route] = []): - self.encoder = encoder + def __init__( + self, encoder: BaseEncoder | None = None, routes: list[Route] | None = None + ): + logger.info("Initializing RouteLayer") + self.index = None + self.categories = None + self.encoder = encoder if encoder is not None else CohereEncoder() + self.routes: list[Route] = routes if routes is not None else [] # decide on default threshold based on encoder if isinstance(encoder, OpenAIEncoder): self.score_threshold = 0.82 @@ -25,37 +157,71 @@ class RouteLayer: else: self.score_threshold = 0.82 # if routes list has been passed, we initialize index now - if routes: + if len(self.routes) > 0: # initialize index now - self.add_routes(routes=routes) + self._add_routes(routes=self.routes) - def __call__(self, text: str) -> str | None: + def __call__(self, text: str) -> RouteChoice: results = self._query(text) top_class, top_class_scores = self._semantic_classify(results) passed = self._pass_threshold(top_class_scores, self.score_threshold) if passed: - return top_class + # get chosen route object + route = [route for route in self.routes if route.name == top_class][0] + return route(text) else: - return None + # if no route passes threshold, return empty route choice + return RouteChoice() + + def __str__(self): + return ( + f"RouteLayer(encoder={self.encoder}, " + f"score_threshold={self.score_threshold}, " + f"routes={self.routes})" + ) - def add_route(self, route: Route): + @classmethod + def from_json(cls, file_path: str): + config = LayerConfig.from_file(file_path) + encoder = Encoder(type=config.encoder_type, name=config.encoder_name).model + return cls(encoder=encoder, routes=config.routes) + + @classmethod + def from_yaml(cls, file_path: str): + config = LayerConfig.from_file(file_path) + encoder = Encoder(type=config.encoder_type, name=config.encoder_name).model + return cls(encoder=encoder, routes=config.routes) + + @classmethod + def from_config(cls, config: LayerConfig): + encoder = Encoder(type=config.encoder_type, name=config.encoder_name).model + return cls(encoder=encoder, routes=config.routes) + + def add(self, route: Route): + print(f"Adding route `{route.name}`") # create embeddings embeds = self.encoder(route.utterances) # create route array if self.categories is None: + print("Initializing categories array") self.categories = np.array([route.name] * len(embeds)) else: + print("Adding route to categories") str_arr = np.array([route.name] * len(embeds)) self.categories = np.concatenate([self.categories, str_arr]) # create utterance array (the index) if self.index is None: + print("Initializing index array") self.index = np.array(embeds) else: + print("Adding route to index") embed_arr = np.array(embeds) self.index = 
np.concatenate([self.index, embed_arr]) + # add route to routes list + self.routes.append(route) - def add_routes(self, routes: list[Route]): + def _add_routes(self, routes: list[Route]): # create embeddings for all routes all_utterances = [ utterance for route in routes for utterance in route.utterances @@ -124,3 +290,18 @@ class RouteLayer: return max(scores) > threshold else: return False + + def to_config(self) -> LayerConfig: + return LayerConfig( + encoder_type=self.encoder.type, + encoder_name=self.encoder.name, + routes=self.routes, + ) + + def to_json(self, file_path: str): + config = self.to_config() + config.to_file(file_path) + + def to_yaml(self, file_path: str): + config = self.to_config() + config.to_file(file_path) diff --git a/semantic_router/route.py b/semantic_router/route.py new file mode 100644 index 0000000000000000000000000000000000000000..12afa7fe0a6824882a6a9da4d2d3845f6b68a62c --- /dev/null +++ b/semantic_router/route.py @@ -0,0 +1,126 @@ +import json +import re +from typing import Any, Callable, Union + +from pydantic import BaseModel + +from semantic_router.schema import RouteChoice +from semantic_router.utils import function_call +from semantic_router.utils.llm import llm +from semantic_router.utils.logger import logger + + +def is_valid(route_config: str) -> bool: + try: + output_json = json.loads(route_config) + required_keys = ["name", "utterances"] + + if isinstance(output_json, list): + for item in output_json: + missing_keys = [key for key in required_keys if key not in item] + if missing_keys: + logger.warning( + f"Missing keys in route config: {', '.join(missing_keys)}" + ) + return False + return True + else: + missing_keys = [key for key in required_keys if key not in output_json] + if missing_keys: + logger.warning( + f"Missing keys in route config: {', '.join(missing_keys)}" + ) + return False + else: + return True + except json.JSONDecodeError as e: + logger.error(e) + return False + + +class Route(BaseModel): + name: str + utterances: list[str] + description: str | None = None + function_schema: dict[str, Any] | None = None + + def __call__(self, query: str) -> RouteChoice: + if self.function_schema: + # if a function schema is provided we generate the inputs + extracted_inputs = function_call.extract_function_inputs( + query=query, function_schema=self.function_schema + ) + func_call = extracted_inputs + else: + # otherwise we just pass None for the call + func_call = None + return RouteChoice(name=self.name, function_call=func_call) + + def to_dict(self): + return self.dict() + + @classmethod + def from_dict(cls, data: dict): + return cls(**data) + + @classmethod + def from_dynamic_route(cls, entity: Union[BaseModel, Callable]): + """ + Generate a dynamic Route object from a function or Pydantic model using LLM + """ + schema = function_call.get_schema(item=entity) + dynamic_route = cls._generate_dynamic_route(function_schema=schema) + return dynamic_route + + @classmethod + def _parse_route_config(cls, config: str) -> str: + # Regular expression to match content inside <config></config> + config_pattern = r"<config>(.*?)</config>" + match = re.search(config_pattern, config, re.DOTALL) + + if match: + config_content = match.group(1).strip() # Get the matched content + return config_content + else: + raise ValueError("No <config></config> tags found in the output.") + + @classmethod + def _generate_dynamic_route(cls, function_schema: dict[str, Any]): + logger.info("Generating dynamic route...") + + prompt = f""" + You are tasked to generate a 
JSON configuration based on the provided + function schema. Please follow the template below, no other tokens allowed: + + <config> + {{ + "name": "<function_name>", + "utterances": [ + "<example_utterance_1>", + "<example_utterance_2>", + "<example_utterance_3>", + "<example_utterance_4>", + "<example_utterance_5>"] + }} + </config> + + Only include the "name" and "utterances" keys in your answer. + The "name" should match the function name and the "utterances" + should comprise a list of 5 example phrases that could be used to invoke + the function. Use real values instead of placeholders. + + Input schema: + {function_schema} + """ + + output = llm(prompt) + if not output: + raise Exception("No output generated for dynamic route") + + route_config = cls._parse_route_config(config=output) + + logger.info(f"Generated route config:\n{route_config}") + + if is_valid(route_config): + return Route.from_dict(json.loads(route_config)) + raise Exception("No config generated") diff --git a/semantic_router/schemas/encoder.py b/semantic_router/schemas/encoder.py index 1b2ad74c4b6b4ad5636e1befe470a248165e3893..fbbfb2d7f8358eb47967a50218a6cfabee411828 100644 --- a/semantic_router/schemas/encoder.py +++ b/semantic_router/schemas/encoder.py @@ -15,13 +15,18 @@ class EncoderType(Enum): COHERE = "cohere" +class RouteChoice(BaseModel): + name: str | None = None + function_call: dict | None = None + + @dataclass class Encoder: type: EncoderType - name: str + name: str | None model: BaseEncoder - def __init__(self, type: str, name: str): + def __init__(self, type: str, name: str | None): self.type = EncoderType(type) self.name = name if self.type == EncoderType.HUGGINGFACE: @@ -30,6 +35,8 @@ class Encoder: self.model = OpenAIEncoder(name) elif self.type == EncoderType.COHERE: self.model = CohereEncoder(name) + else: + raise ValueError def __call__(self, texts: list[str]) -> list[list[float]]: return self.model(texts) diff --git a/semantic_router/utils/function_call.py b/semantic_router/utils/function_call.py new file mode 100644 index 0000000000000000000000000000000000000000..2ead3ab58dd5c54eaf26fdc5b2f73ea95f4bef9c --- /dev/null +++ b/semantic_router/utils/function_call.py @@ -0,0 +1,127 @@ +import inspect +import json +from typing import Any, Callable, Union + +from pydantic import BaseModel + +from semantic_router.utils.llm import llm +from semantic_router.utils.logger import logger + + +def get_schema(item: Union[BaseModel, Callable]) -> dict[str, Any]: + if isinstance(item, BaseModel): + signature_parts = [] + for field_name, field_model in item.__annotations__.items(): + field_info = item.__fields__[field_name] + default_value = field_info.default + + if default_value: + default_repr = repr(default_value) + signature_part = ( + f"{field_name}: {field_model.__name__} = {default_repr}" + ) + else: + signature_part = f"{field_name}: {field_model.__name__}" + + signature_parts.append(signature_part) + signature = f"({', '.join(signature_parts)}) -> str" + schema = { + "name": item.__class__.__name__, + "description": item.__doc__, + "signature": signature, + } + else: + schema = { + "name": item.__name__, + "description": str(inspect.getdoc(item)), + "signature": str(inspect.signature(item)), + "output": str(inspect.signature(item).return_annotation), + } + return schema + + +def extract_function_inputs(query: str, function_schema: dict[str, Any]) -> dict: + logger.info("Extracting function input...") + + prompt = f""" + You are a helpful assistant designed to output JSON. 
+ Given the following function schema + << {function_schema} >> + and query + << {query} >> + extract the parameter values from the query, in a valid JSON format. + Example: + Input: + query: "How is the weather in Hawaii right now in International units?" + schema: + {{ + "name": "get_weather", + "description": "Useful to get the weather in a specific location", + "signature": "(location: str, degree: str) -> str", + "output": "<class 'str'>", + }} + + Result: {{ + "location": "Hawaii", + "degree": "Celsius", + }} + + Input: + query: {query} + schema: {function_schema} + Result: + """ + + output = llm(prompt) + if not output: + raise Exception("No output generated for extract function input") + + output = output.replace("'", '"').strip().rstrip(",") + + function_inputs = json.loads(output) + if not is_valid_inputs(function_inputs, function_schema): + raise ValueError("Invalid inputs") + return function_inputs + + + def is_valid_inputs(inputs: dict[str, Any], function_schema: dict[str, Any]) -> bool: + """Validate the extracted inputs against the function schema""" + try: + # Extract parameter names and types from the signature string + signature = function_schema["signature"] + param_info = [param.strip() for param in signature[1:-1].split(",")] + param_names = [info.split(":")[0].strip() for info in param_info] + param_types = [ + info.split(":")[1].strip().split("=")[0].strip() for info in param_info + ] + + for name, type_str in zip(param_names, param_types): + if name not in inputs: + logger.error(f"Input {name} missing from query") + return False + return True + except Exception as e: + logger.error(f"Input validation error: {str(e)}") + return False + + + def call_function(function: Callable, inputs: dict[str, str]): + try: + return function(**inputs) + except TypeError as e: + logger.error(f"Error calling function: {e}") + + + # TODO: Add route layer object to the input, solve circular import issue + async def route_and_execute(query: str, functions: list[Callable], route_layer): + function_name = route_layer(query) + if not function_name: + logger.warning("No function found, calling LLM...") + return llm(query) + + for function in functions: + if function.__name__ == function_name: + print(f"Calling function: {function.__name__}") + schema = get_schema(function) + inputs = extract_function_inputs(query, schema) + call_function(function, inputs) diff --git a/semantic_router/utils/llm.py b/semantic_router/utils/llm.py new file mode 100644 index 0000000000000000000000000000000000000000..e92c1bcf7752b5fce5a071dde41da4a24d0851a9 --- /dev/null +++ b/semantic_router/utils/llm.py @@ -0,0 +1,64 @@ +import os + +import openai + +from semantic_router.utils.logger import logger + + +def llm(prompt: str) -> str | None: + try: + client = openai.OpenAI( + base_url="https://openrouter.ai/api/v1", + api_key=os.getenv("OPENROUTER_API_KEY"), + ) + + completion = client.chat.completions.create( + model="mistralai/mistral-7b-instruct", + messages=[ + { + "role": "user", + "content": prompt, + }, + ], + temperature=0.01, + max_tokens=200, + ) + + output = completion.choices[0].message.content + + if not output: + raise Exception("No output generated") + return output + except Exception as e: + logger.error(f"LLM error: {e}") + raise Exception(f"LLM error: {e}") + + +# TODO integrate async LLM function +# async def allm(prompt: str) -> str | None: +# try: +# client = openai.AsyncOpenAI( +# base_url="https://openrouter.ai/api/v1", +# api_key=os.getenv("OPENROUTER_API_KEY"), +# ) + +# completion = await
client.chat.completions.create( +# model="mistralai/mistral-7b-instruct", +# messages=[ +# { +# "role": "user", +# "content": prompt, +# }, +# ], +# temperature=0.01, +# max_tokens=200, +# ) + +# output = completion.choices[0].message.content + +# if not output: +# raise Exception("No output generated") +# return output +# except Exception as e: +# logger.error(f"LLM error: {e}") +# raise Exception(f"LLM error: {e}") diff --git a/semantic_router/utils/logger.py b/semantic_router/utils/logger.py index a001623a9c1eae5cc5a632f6afd69858f0319e32..00c83693435487016f819c4716900fc09f8b8b92 100644 --- a/semantic_router/utils/logger.py +++ b/semantic_router/utils/logger.py @@ -22,18 +22,9 @@ class CustomFormatter(colorlog.ColoredFormatter): def add_coloured_handler(logger): formatter = CustomFormatter() - console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) - - logging.basicConfig( - datefmt="%Y-%m-%d %H:%M:%S", - format="%(log_color)s%(asctime)s %(levelname)s %(name)s %(message)s", - force=True, - ) - logger.addHandler(console_handler) - return logger diff --git a/tests/unit/encoders/test_openai.py b/tests/unit/encoders/test_openai.py index cc79d27207f7847439b74b0c29f4fb75d42d5381..4679ee939f7d4b494150a8a52f4b4c33a0e6c8db 100644 --- a/tests/unit/encoders/test_openai.py +++ b/tests/unit/encoders/test_openai.py @@ -20,9 +20,8 @@ class TestOpenAIEncoder: def test_openai_encoder_init_no_api_key(self, mocker): mocker.patch("os.getenv", return_value=None) - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError) as _: OpenAIEncoder() - assert "OpenAI API key cannot be 'None'." in str(e.value) def test_openai_encoder_call_uninitialized_client(self, openai_encoder): # Set the client to None to simulate an uninitialized client diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py index 2506c19943117111af57620b107b668165a8f544..b77f51adf559182d5e4b1c393096357e6250040d 100644 --- a/tests/unit/test_hybrid_layer.py +++ b/tests/unit/test_hybrid_layer.py @@ -8,7 +8,7 @@ from semantic_router.encoders import ( TfidfEncoder, ) from semantic_router.hybrid_layer import HybridRouteLayer -from semantic_router.schemas.route import Route +from semantic_router.route import Route def mock_encoder_call(utterances): @@ -88,7 +88,7 @@ class TestHybridRouteLayer: dense_encoder=openai_encoder, sparse_encoder=bm25_encoder ) route = Route(name="Route 3", utterances=["Yes", "No"]) - route_layer.add(route) + route_layer._add_routes([route]) assert route_layer.index is not None and route_layer.categories is not None assert len(route_layer.index) == 2 assert len(set(route_layer.categories)) == 1 diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py index d049243f1f7fd174b13cb41d4bf90e14a8c7c331..45b57472f0fabfb7b7694e2ee9072540cee0bd23 100644 --- a/tests/unit/test_layer.py +++ b/tests/unit/test_layer.py @@ -1,8 +1,11 @@ +import os +from unittest.mock import mock_open, patch + import pytest from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder -from semantic_router.layer import RouteLayer -from semantic_router.schemas.route import Route +from semantic_router.layer import LayerConfig, RouteLayer +from semantic_router.route import Route def mock_encoder_call(utterances): @@ -17,6 +20,52 @@ def mock_encoder_call(utterances): return [mock_responses.get(u, [0, 0, 0]) for u in utterances] +def layer_json(): + return """{ + "encoder_type": "cohere", + "encoder_name": "embed-english-v3.0", + "routes": [ + { + "name": "politics", + 
"utterances": [ + "isn't politics the best thing ever", + "why don't you tell me about your political opinions" + ], + "description": null, + "function_schema": null + }, + { + "name": "chitchat", + "utterances": [ + "how's the weather today?", + "how are things going?" + ], + "description": null, + "function_schema": null + } + ] +}""" + + +def layer_yaml(): + return """encoder_name: embed-english-v3.0 +encoder_type: cohere +routes: +- description: null + function_schema: null + name: politics + utterances: + - isn't politics the best thing ever + - why don't you tell me about your political opinions +- description: null + function_schema: null + name: chitchat + utterances: + - how's the weather today? + - how are things going? + """ + + @pytest.fixture def base_encoder(): return BaseEncoder(name="test-encoder") @@ -65,32 +114,33 @@ class TestRouteLayer: route1 = Route(name="Route 1", utterances=["Yes", "No"]) route2 = Route(name="Route 2", utterances=["Maybe", "Sure"]) - route_layer.add_route(route=route1) + route_layer.add(route=route1) assert route_layer.index is not None and route_layer.categories is not None - assert len(route_layer.index) == 2 + assert route_layer.index.shape[0] == 2 assert len(set(route_layer.categories)) == 1 assert set(route_layer.categories) == {"Route 1"} - route_layer.add_route(route=route2) - assert len(route_layer.index) == 4 + route_layer.add(route=route2) + assert route_layer.index.shape[0] == 4 assert len(set(route_layer.categories)) == 2 assert set(route_layer.categories) == {"Route 1", "Route 2"} + del route_layer def test_add_multiple_routes(self, openai_encoder, routes): route_layer = RouteLayer(encoder=openai_encoder) - route_layer.add_routes(routes=routes) + route_layer._add_routes(routes=routes) assert route_layer.index is not None and route_layer.categories is not None - assert len(route_layer.index) == 5 + assert route_layer.index.shape[0] == 5 assert len(set(route_layer.categories)) == 2 def test_query_and_classification(self, openai_encoder, routes): route_layer = RouteLayer(encoder=openai_encoder, routes=routes) - query_result = route_layer("Hello") + query_result = route_layer("Hello").name assert query_result in ["Route 1", "Route 2"] def test_query_with_no_index(self, openai_encoder): route_layer = RouteLayer(encoder=openai_encoder) - assert route_layer("Anything") is None + assert route_layer("Anything").name is None def test_semantic_classify(self, openai_encoder, routes): route_layer = RouteLayer(encoder=openai_encoder, routes=routes) @@ -124,5 +174,120 @@ class TestRouteLayer: route_layer = RouteLayer(encoder=base_encoder) assert route_layer.score_threshold == 0.82 + def test_json(self, openai_encoder, routes): + os.environ["OPENAI_API_KEY"] = "test_api_key" + route_layer = RouteLayer(encoder=openai_encoder, routes=routes) + route_layer.to_json("test_output.json") + assert os.path.exists("test_output.json") + route_layer_from_file = RouteLayer.from_json("test_output.json") + assert ( + route_layer_from_file.index is not None + and route_layer_from_file.categories is not None + ) + os.remove("test_output.json") + + def test_yaml(self, openai_encoder, routes): + os.environ["OPENAI_API_KEY"] = "test_api_key" + route_layer = RouteLayer(encoder=openai_encoder, routes=routes) + route_layer.to_yaml("test_output.yaml") + assert os.path.exists("test_output.yaml") + route_layer_from_file = RouteLayer.from_yaml("test_output.yaml") + assert ( + route_layer_from_file.index is not None + and route_layer_from_file.categories is not None + ) + 
os.remove("test_output.yaml") + + def test_config(self, openai_encoder, routes): + os.environ["OPENAI_API_KEY"] = "test_api_key" + route_layer = RouteLayer(encoder=openai_encoder, routes=routes) + # confirm route creation functions as expected + layer_config = route_layer.to_config() + assert layer_config.routes == routes + # now load from config and confirm it's the same + route_layer_from_config = RouteLayer.from_config(layer_config) + assert (route_layer_from_config.index == route_layer.index).all() + assert (route_layer_from_config.categories == route_layer.categories).all() + assert route_layer_from_config.score_threshold == route_layer.score_threshold + # Add more tests for edge cases and error handling as needed. + + +class TestLayerConfig: + def test_init(self): + layer_config = LayerConfig() + assert layer_config.routes == [] + + def test_to_file_json(self): + route = Route(name="test", utterances=["utterance"]) + layer_config = LayerConfig(routes=[route]) + with patch("builtins.open", mock_open()) as mocked_open: + layer_config.to_file("data/test_output.json") + mocked_open.assert_called_once_with("data/test_output.json", "w") + + def test_to_file_yaml(self): + route = Route(name="test", utterances=["utterance"]) + layer_config = LayerConfig(routes=[route]) + with patch("builtins.open", mock_open()) as mocked_open: + layer_config.to_file("data/test_output.yaml") + mocked_open.assert_called_once_with("data/test_output.yaml", "w") + + def test_to_file_invalid(self): + route = Route(name="test", utterances=["utterance"]) + layer_config = LayerConfig(routes=[route]) + with pytest.raises(ValueError): + layer_config.to_file("test_output.txt") + + def test_from_file_json(self): + mock_json_data = layer_json() + with patch("builtins.open", mock_open(read_data=mock_json_data)) as mocked_open: + layer_config = LayerConfig.from_file("data/test.json") + mocked_open.assert_called_once_with("data/test.json", "r") + assert isinstance(layer_config, LayerConfig) + + def test_from_file_yaml(self): + mock_yaml_data = layer_yaml() + with patch("builtins.open", mock_open(read_data=mock_yaml_data)) as mocked_open: + layer_config = LayerConfig.from_file("data/test.yaml") + mocked_open.assert_called_once_with("data/test.yaml", "r") + assert isinstance(layer_config, LayerConfig) + + def test_from_file_invalid(self): + with open("test.txt", "w") as f: + f.write("dummy content") + with pytest.raises(ValueError): + LayerConfig.from_file("test.txt") + os.remove("test.txt") + + def test_to_dict(self): + route = Route(name="test", utterances=["utterance"]) + layer_config = LayerConfig(routes=[route]) + assert layer_config.to_dict()["routes"] == [route.to_dict()] + + def test_add(self): + route = Route(name="test", utterances=["utterance"]) + route2 = Route(name="test2", utterances=["utterance2"]) + layer_config = LayerConfig() + layer_config.add(route) + # confirm route added + assert layer_config.routes == [route] + # add second route and check updates + layer_config.add(route2) + assert layer_config.routes == [route, route2] + + def test_get(self): + route = Route(name="test", utterances=["utterance"]) + layer_config = LayerConfig(routes=[route]) + assert layer_config.get("test") == route + + def test_get_not_found(self): + route = Route(name="test", utterances=["utterance"]) + layer_config = LayerConfig(routes=[route]) + assert layer_config.get("not_found") is None + + def test_remove(self): + route = Route(name="test", utterances=["utterance"]) + layer_config = LayerConfig(routes=[route]) + 
layer_config.remove("test") + assert layer_config.routes == [] diff --git a/tests/unit/test_route.py b/tests/unit/test_route.py new file mode 100644 index 0000000000000000000000000000000000000000..09a5d235f0183ebd70c1fa1bbdbe7fd842ee6d09 --- /dev/null +++ b/tests/unit/test_route.py @@ -0,0 +1,208 @@ +from unittest.mock import Mock, patch # , AsyncMock + +# import pytest +from semantic_router.route import Route, is_valid + + +# Is valid test: +def test_is_valid_with_valid_json(): + valid_json = '{"name": "test_route", "utterances": ["hello", "hi"]}' + assert is_valid(valid_json) is True + + +def test_is_valid_with_missing_keys(): + invalid_json = '{"name": "test_route"}' # Missing 'utterances' + with patch("semantic_router.route.logger") as mock_logger: + assert is_valid(invalid_json) is False + mock_logger.warning.assert_called_once() + + +def test_is_valid_with_valid_json_list(): + valid_json_list = ( + '[{"name": "test_route1", "utterances": ["hello"]}, ' + '{"name": "test_route2", "utterances": ["hi"]}]' + ) + assert is_valid(valid_json_list) is True + + +def test_is_valid_with_invalid_json_list(): + invalid_json_list = ( + '[{"name": "test_route1"}, {"name": "test_route2", "utterances": ["hi"]}]' + ) + with patch("semantic_router.route.logger") as mock_logger: + assert is_valid(invalid_json_list) is False + mock_logger.warning.assert_called_once() + + +def test_is_valid_with_invalid_json(): + invalid_json = '{"name": "test_route", "utterances": ["hello", "hi" invalid json}' + with patch("semantic_router.route.logger") as mock_logger: + assert is_valid(invalid_json) is False + mock_logger.error.assert_called_once() + + +class TestRoute: + @patch("semantic_router.route.llm", new_callable=Mock) + def test_generate_dynamic_route(self, mock_llm): + print(f"mock_llm: {mock_llm}") + mock_llm.return_value = """ + <config> + { + "name": "test_function", + "utterances": [ + "example_utterance_1", + "example_utterance_2", + "example_utterance_3", + "example_utterance_4", + "example_utterance_5"] + } + </config> + """ + function_schema = {"name": "test_function", "type": "function"} + route = Route._generate_dynamic_route(function_schema) + assert route.name == "test_function" + assert route.utterances == [ + "example_utterance_1", + "example_utterance_2", + "example_utterance_3", + "example_utterance_4", + "example_utterance_5", + ] + + # TODO add async version + # @pytest.mark.asyncio + # @patch("semantic_router.route.allm", new_callable=Mock) + # async def test_generate_dynamic_route_async(self, mock_llm): + # print(f"mock_llm: {mock_llm}") + # mock_llm.return_value = """ + # <config> + # { + # "name": "test_function", + # "utterances": [ + # "example_utterance_1", + # "example_utterance_2", + # "example_utterance_3", + # "example_utterance_4", + # "example_utterance_5"] + # } + # </config> + # """ + # function_schema = {"name": "test_function", "type": "function"} + # route = await Route._generate_dynamic_route(function_schema) + # assert route.name == "test_function" + # assert route.utterances == [ + # "example_utterance_1", + # "example_utterance_2", + # "example_utterance_3", + # "example_utterance_4", + # "example_utterance_5", + # ] + + def test_to_dict(self): + route = Route(name="test", utterances=["utterance"]) + expected_dict = { + "name": "test", + "utterances": ["utterance"], + "description": None, + "function_schema": None, + } + assert route.to_dict() == expected_dict + + def test_from_dict(self): + route_dict = {"name": "test", "utterances": ["utterance"]} + route = 
Route.from_dict(route_dict) + assert route.name == "test" + assert route.utterances == ["utterance"] + + @patch("semantic_router.route.llm", new_callable=Mock) + def test_from_dynamic_route(self, mock_llm): + # Mock the llm function + mock_llm.return_value = """ + <config> + { + "name": "test_function", + "utterances": [ + "example_utterance_1", + "example_utterance_2", + "example_utterance_3", + "example_utterance_4", + "example_utterance_5"] + } + </config> + """ + + def test_function(input: str): + """Test function docstring""" + pass + + dynamic_route = Route.from_dynamic_route(test_function) + + assert dynamic_route.name == "test_function" + assert dynamic_route.utterances == [ + "example_utterance_1", + "example_utterance_2", + "example_utterance_3", + "example_utterance_4", + "example_utterance_5", + ] + + # TODO add async functions + # @pytest.mark.asyncio + # @patch("semantic_router.route.allm", new_callable=AsyncMock) + # async def test_from_dynamic_route_async(self, mock_llm): + # # Mock the llm function + # mock_llm.return_value = """ + # <config> + # { + # "name": "test_function", + # "utterances": [ + # "example_utterance_1", + # "example_utterance_2", + # "example_utterance_3", + # "example_utterance_4", + # "example_utterance_5"] + # } + # </config> + # """ + + # def test_function(input: str): + # """Test function docstring""" + # pass + + # dynamic_route = await Route.from_dynamic_route(test_function) + + # assert dynamic_route.name == "test_function" + # assert dynamic_route.utterances == [ + # "example_utterance_1", + # "example_utterance_2", + # "example_utterance_3", + # "example_utterance_4", + # "example_utterance_5", + # ] + + def test_parse_route_config(self): + config = """ + <config> + { + "name": "test_function", + "utterances": [ + "example_utterance_1", + "example_utterance_2", + "example_utterance_3", + "example_utterance_4", + "example_utterance_5"] + } + </config> + """ + expected_config = """ + { + "name": "test_function", + "utterances": [ + "example_utterance_1", + "example_utterance_2", + "example_utterance_3", + "example_utterance_4", + "example_utterance_5"] + } + """ + assert Route._parse_route_config(config).strip() == expected_config.strip() diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py index f47643c9a2c55321234873d61f39a0751e4c7a45..f4f9762368a987340c95c20bf95f50dbc5c63447 100644 --- a/tests/unit/test_schema.py +++ b/tests/unit/test_schema.py @@ -7,14 +7,6 @@ from semantic_router.schemas.encoder import ( OpenAIEncoder, ) -from semantic_router.schemas.route import ( - Route, -) - -from semantic_router.schemas.semantic_space import ( - SemanticSpace, -) - class TestEncoderDataclass: def test_encoder_initialization_openai(self, mocker): @@ -46,20 +38,3 @@ class TestEncoderDataclass: encoder = Encoder(type="openai", name="test-engine") result = encoder(["test"]) assert result == [0.1, 0.2, 0.3] - - -class TestSemanticSpaceDataclass: - def test_semanticspace_initialization(self): - semantic_space = SemanticSpace() - assert semantic_space.id == "" - assert semantic_space.routes == [] - - def test_semanticspace_add_route(self): - route = Route(name="test", utterances=["hello", "hi"], description="greeting") - semantic_space = SemanticSpace() - semantic_space.add(route) - - assert len(semantic_space.routes) == 1 - assert semantic_space.routes[0].name == "test" - assert semantic_space.routes[0].utterances == ["hello", "hi"] - assert semantic_space.routes[0].description == "greeting" diff --git a/walkthrough.ipynb b/walkthrough.ipynb 
diff --git a/walkthrough.ipynb b/walkthrough.ipynb
deleted file mode 100644
index 346b576cdb8fe517580aca8e201cbaf9d5eb4a01..0000000000000000000000000000000000000000
--- a/walkthrough.ipynb
+++ /dev/null
@@ -1,237 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "# Semantic Router Walkthrough\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "The Semantic Router library can be used as a super fast route making layer on top of LLMs. That means rather than waiting on a slow agent to decide what to do, we can use the magic of semantic vector space to make routes. Cutting route making time down from seconds to milliseconds.\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Getting Started\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "We start by installing the library:\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "!pip install -qU semantic-router==0.0.8"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "We start by defining a dictionary mapping routes to example phrases that should trigger those routes.\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/Users/danielgriffiths/Coding_files/Aurelio_local/semantic-router/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
-      "  from .autonotebook import tqdm as notebook_tqdm\n"
-     ]
-    }
-   ],
-   "source": [
-    "from semantic_router.schemas.route import Route\n",
-    "\n",
-    "politics = Route(\n",
-    "    name=\"politics\",\n",
-    "    utterances=[\n",
-    "        \"isn't politics the best thing ever\",\n",
-    "        \"why don't you tell me about your political opinions\",\n",
-    "        \"don't you just love the president\" \"don't you just hate the president\",\n",
-    "        \"they're going to destroy this country!\",\n",
-    "        \"they will save the country!\",\n",
-    "    ],\n",
-    ")"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Let's define another for good measure:\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "chitchat = Route(\n",
-    "    name=\"chitchat\",\n",
-    "    utterances=[\n",
-    "        \"how's the weather today?\",\n",
-    "        \"how are things going?\",\n",
-    "        \"lovely weather today\",\n",
-    "        \"the weather is horrendous\",\n",
-    "        \"let's go to the chippy\",\n",
-    "    ],\n",
-    ")\n",
-    "\n",
-    "routes = [politics, chitchat]"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Now we initialize our embedding model:\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import os\n",
-    "from getpass import getpass\n",
-    "from semantic_router.encoders import CohereEncoder\n",
-    "\n",
-    "os.environ[\"COHERE_API_KEY\"] = os.getenv(\"COHERE_API_KEY\") or getpass(\n",
-    "    \"Enter Cohere API Key: \"\n",
-    ")\n",
-    "\n",
-    "encoder = CohereEncoder()"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Now we define the `RouteLayer`. When called, the route layer will consume text (a query) and output the category (`Route`) it belongs to — to initialize a `RouteLayer` we need our `encoder` model and a list of `routes`.\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from semantic_router.layer import RouteLayer\n",
-    "\n",
-    "dl = RouteLayer(encoder=encoder, routes=routes)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Now we can test it:\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 6,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "'politics'"
-      ]
-     },
-     "execution_count": 6,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "dl(\"don't you love politics?\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 7,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "'chitchat'"
-      ]
-     },
-     "execution_count": 7,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "dl(\"how's the weather today?\")"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Both are classified accurately, what if we send a query that is unrelated to our existing `Route` objects?\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 8,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "dl(\"I'm interested in learning about llama 2\")"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "In this case, we return `None` because no matches were identified.\n"
-   ]
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "decision-layer",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.11.5"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
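One detail worth calling out from `test_parse_route_config` earlier in this patch: the JSON payload is extracted from the `<config>...</config>` wrapper that the LLM is prompted to emit, and the test only compares the result after `.strip()`, so surrounding whitespace inside the tags is irrelevant. A plausible way to implement that extraction is a single DOTALL regex, sketched below; this is an illustration consistent with the test, not necessarily the library's own code.

```python
import re


def parse_route_config(config: str) -> str:
    """Sketch of the extraction Route._parse_route_config is tested for:
    return whatever sits between <config> and </config>."""
    match = re.search(r"<config>(.*?)</config>", config, re.DOTALL)
    if match is None:
        raise ValueError("no <config> block found in LLM output")
    return match.group(1)
```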