diff --git a/Makefile b/Makefile index 8de202fa56f0de52a80f6f8a63e68ab6fe18ef33..aeb3d3b19ff9262b933b9022f6fc80c240279040 100644 --- a/Makefile +++ b/Makefile @@ -12,4 +12,4 @@ lint lint_diff: poetry run mypy $(PYTHON_FILES) test: - poetry run pytest -vv -n 20 --cov=semantic_router --cov-report=term-missing --cov-report=xml --cov-fail-under=100 + poetry run pytest -vv -n 20 --cov=semantic_router --cov-report=term-missing --cov-report=xml --cov-fail-under=80 diff --git a/coverage.xml b/coverage.xml index 899b23e42c1c74d01a00bacf9a22de78426e9f2b..27e175c33ef0897bd44b76a15c97bbb33f9afecc 100644 --- a/coverage.xml +++ b/coverage.xml @@ -1,19 +1,20 @@ <?xml version="1.0" ?> -<coverage version="7.3.3" timestamp="1702906788381" lines-valid="352" lines-covered="352" line-rate="1" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0"> +<coverage version="7.3.3" timestamp="1703085147401" lines-valid="544" lines-covered="470" line-rate="0.864" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0"> <!-- Generated by coverage.py: https://coverage.readthedocs.io/en/7.3.3 --> <!-- Based on https://raw.githubusercontent.com/cobertura/web/master/htdocs/xml/coverage-04.dtd --> <sources> <source>/Users/jakit/customers/aurelio/semantic-router/semantic_router</source> </sources> <packages> - <package name="." line-rate="1" branch-rate="0" complexity="0"> + <package name="." 
line-rate="0.9527" branch-rate="0" complexity="0"> <classes> <class name="__init__.py" filename="__init__.py" complexity="0" line-rate="1" branch-rate="0"> <methods/> <lines> <line number="1" hits="1"/> <line number="2" hits="1"/> - <line number="4" hits="1"/> + <line number="3" hits="1"/> + <line number="5" hits="1"/> </lines> </class> <class name="hybrid_layer.py" filename="hybrid_layer.py" complexity="0" line-rate="1" branch-rate="0"> @@ -24,175 +25,185 @@ <line number="4" hits="1"/> <line number="10" hits="1"/> <line number="11" hits="1"/> - <line number="14" hits="1"/> - <line number="15" hits="1"/> + <line number="13" hits="1"/> <line number="16" hits="1"/> <line number="17" hits="1"/> <line number="18" hits="1"/> + <line number="19" hits="1"/> <line number="20" hits="1"/> - <line number="23" hits="1"/> - <line number="24" hits="1"/> + <line number="22" hits="1"/> <line number="25" hits="1"/> + <line number="26" hits="1"/> <line number="27" hits="1"/> - <line number="28" hits="1"/> <line number="29" hits="1"/> <line number="30" hits="1"/> + <line number="31" hits="1"/> <line number="32" hits="1"/> <line number="34" hits="1"/> + <line number="36" hits="1"/> <line number="38" hits="1"/> - <line number="40" hits="1"/> + <line number="39" hits="1"/> <line number="41" hits="1"/> <line number="42" hits="1"/> <line number="43" hits="1"/> <line number="44" hits="1"/> <line number="45" hits="1"/> - <line number="47" hits="1"/> - <line number="49" hits="1"/> + <line number="46" hits="1"/> + <line number="48" hits="1"/> <line number="50" hits="1"/> - <line number="52" hits="1"/> - <line number="54" hits="1"/> + <line number="51" hits="1"/> + <line number="53" hits="1"/> <line number="55" hits="1"/> - <line number="60" hits="1"/> + <line number="56" hits="1"/> <line number="61" hits="1"/> <line number="62" hits="1"/> - <line number="64" hits="1"/> + <line number="63" hits="1"/> <line number="65" hits="1"/> <line number="66" hits="1"/> - <line number="70" hits="1"/> + 
<line number="67" hits="1"/> <line number="71" hits="1"/> - <line number="73" hits="1"/> - <line number="75" hits="1"/> + <line number="72" hits="1"/> + <line number="74" hits="1"/> <line number="76" hits="1"/> - <line number="78" hits="1"/> - <line number="80" hits="1"/> - <line number="82" hits="1"/> - <line number="83" hits="1"/> + <line number="77" hits="1"/> + <line number="79" hits="1"/> + <line number="81" hits="1"/> <line number="86" hits="1"/> <line number="87" hits="1"/> + <line number="89" hits="1"/> <line number="90" hits="1"/> - <line number="91" hits="1"/> <line number="92" hits="1"/> - <line number="99" hits="1"/> + <line number="94" hits="1"/> + <line number="96" hits="1"/> + <line number="97" hits="1"/> + <line number="98" hits="1"/> + <line number="100" hits="1"/> + <line number="101" hits="1"/> + <line number="102" hits="1"/> + <line number="103" hits="1"/> + <line number="105" hits="1"/> <line number="106" hits="1"/> + <line number="107" hits="1"/> + <line number="109" hits="1"/> + <line number="110" hits="1"/> <line number="112" hits="1"/> + <line number="113" hits="1"/> + <line number="115" hits="1"/> <line number="117" hits="1"/> <line number="118" hits="1"/> - <line number="120" hits="1"/> + <line number="119" hits="1"/> <line number="121" hits="1"/> <line number="123" hits="1"/> <line number="125" hits="1"/> + <line number="126" hits="1"/> <line number="127" hits="1"/> - <line number="128" hits="1"/> <line number="129" hits="1"/> - <line number="131" hits="1"/> <line number="132" hits="1"/> <line number="133" hits="1"/> - <line number="134" hits="1"/> <line number="136" hits="1"/> <line number="137" hits="1"/> - <line number="138" hits="1"/> + <line number="139" hits="1"/> <line number="140" hits="1"/> - <line number="141" hits="1"/> + <line number="142" hits="1"/> <line number="143" hits="1"/> <line number="144" hits="1"/> <line number="146" hits="1"/> - <line number="148" hits="1"/> - <line number="149" hits="1"/> - <line number="150" 
hits="1"/> - <line number="152" hits="1"/> - <line number="153" hits="1"/> - <line number="154" hits="1"/> - <line number="155" hits="1"/> - <line number="156" hits="1"/> - <line number="157" hits="1"/> - <line number="158" hits="1"/> - <line number="160" hits="1"/> - <line number="163" hits="1"/> - <line number="164" hits="1"/> - <line number="167" hits="1"/> - <line number="168" hits="1"/> - <line number="170" hits="1"/> - <line number="171" hits="1"/> - <line number="173" hits="1"/> - <line number="174" hits="1"/> - <line number="175" hits="1"/> - <line number="177" hits="1"/> </lines> </class> - <class name="layer.py" filename="layer.py" complexity="0" line-rate="1" branch-rate="0"> + <class name="layer.py" filename="layer.py" complexity="0" line-rate="0.8791" branch-rate="0"> <methods/> <lines> <line number="1" hits="1"/> <line number="3" hits="1"/> - <line number="8" hits="1"/> - <line number="9" hits="1"/> - <line number="10" hits="1"/> - <line number="13" hits="1"/> + <line number="4" hits="1"/> + <line number="6" hits="1"/> + <line number="11" hits="1"/> + <line number="12" hits="1"/> <line number="14" hits="1"/> - <line number="15" hits="1"/> - <line number="16" hits="1"/> + <line number="17" hits="1"/> <line number="18" hits="1"/> <line number="19" hits="1"/> - <line number="21" hits="1"/> + <line number="20" hits="1"/> <line number="22" hits="1"/> <line number="23" hits="1"/> <line number="24" hits="1"/> - <line number="26" hits="1"/> + <line number="25" hits="1"/> + <line number="27" hits="1"/> <line number="28" hits="1"/> + <line number="29" hits="1"/> <line number="30" hits="1"/> <line number="32" hits="1"/> - <line number="33" hits="1"/> <line number="34" hits="1"/> - <line number="35" hits="1"/> <line number="36" hits="1"/> - <line number="37" hits="1"/> + <line number="38" hits="1"/> <line number="39" hits="1"/> + <line number="40" hits="1"/> <line number="41" hits="1"/> + <line number="42" hits="1"/> <line number="43" hits="1"/> - <line 
number="46" hits="1"/> + <line number="45" hits="1"/> <line number="47" hits="1"/> - <line number="49" hits="1"/> - <line number="50" hits="1"/> - <line number="52" hits="1"/> - <line number="53" hits="1"/> + <line number="48" hits="1"/> + <line number="49" hits="0"/> + <line number="50" hits="0"/> + <line number="51" hits="0"/> + <line number="52" hits="0"/> + <line number="54" hits="1"/> <line number="55" hits="1"/> - <line number="56" hits="1"/> - <line number="58" hits="1"/> - <line number="60" hits="1"/> + <line number="56" hits="0"/> + <line number="57" hits="0"/> + <line number="58" hits="0"/> + <line number="59" hits="0"/> + <line number="61" hits="1"/> <line number="63" hits="1"/> <line number="66" hits="1"/> <line number="67" hits="1"/> - <line number="68" hits="1"/> + <line number="69" hits="1"/> + <line number="70" hits="1"/> + <line number="72" hits="1"/> + <line number="73" hits="1"/> <line number="75" hits="1"/> <line number="76" hits="1"/> - <line number="82" hits="1"/> + <line number="78" hits="1"/> + <line number="80" hits="1"/> + <line number="83" hits="1"/> + <line number="86" hits="1"/> <line number="87" hits="1"/> <line number="88" hits="1"/> - <line number="90" hits="1"/> - <line number="92" hits="1"/> - <line number="93" hits="1"/> <line number="95" hits="1"/> <line number="96" hits="1"/> - <line number="98" hits="1"/> - <line number="99" hits="1"/> - <line number="101" hits="1"/> <line number="102" hits="1"/> - <line number="103" hits="1"/> - <line number="104" hits="1"/> - <line number="105" hits="1"/> - <line number="106" hits="1"/> <line number="107" hits="1"/> - <line number="109" hits="1"/> + <line number="108" hits="1"/> + <line number="110" hits="1"/> <line number="112" hits="1"/> <line number="113" hits="1"/> + <line number="115" hits="1"/> <line number="116" hits="1"/> - <line number="117" hits="1"/> + <line number="118" hits="1"/> <line number="119" hits="1"/> - <line number="120" hits="1"/> + <line number="121" hits="1"/> <line 
number="122" hits="1"/> <line number="123" hits="1"/> <line number="124" hits="1"/> + <line number="125" hits="1"/> <line number="126" hits="1"/> + <line number="127" hits="1"/> + <line number="129" hits="1"/> + <line number="132" hits="1"/> + <line number="133" hits="1"/> + <line number="136" hits="1"/> + <line number="137" hits="1"/> + <line number="139" hits="1"/> + <line number="140" hits="1"/> + <line number="142" hits="1"/> + <line number="143" hits="1"/> + <line number="144" hits="1"/> + <line number="146" hits="1"/> + <line number="148" hits="1"/> + <line number="149" hits="0"/> + <line number="150" hits="0"/> + <line number="151" hits="0"/> </lines> </class> <class name="linear.py" filename="linear.py" complexity="0" line-rate="1" branch-rate="0"> @@ -213,13 +224,124 @@ <line number="30" hits="1"/> </lines> </class> - <class name="schema.py" filename="schema.py" complexity="0" line-rate="1" branch-rate="0"> + <class name="route.py" filename="route.py" complexity="0" line-rate="0.9528" branch-rate="0"> <methods/> <lines> <line number="1" hits="1"/> + <line number="2" hits="1"/> <line number="3" hits="1"/> <line number="4" hits="1"/> <line number="6" hits="1"/> + <line number="7" hits="1"/> + <line number="9" hits="1"/> + <line number="10" hits="1"/> + <line number="11" hits="1"/> + <line number="14" hits="1"/> + <line number="15" hits="1"/> + <line number="16" hits="1"/> + <line number="17" hits="1"/> + <line number="19" hits="1"/> + <line number="20" hits="1"/> + <line number="21" hits="1"/> + <line number="22" hits="1"/> + <line number="23" hits="1"/> + <line number="26" hits="1"/> + <line number="27" hits="1"/> + <line number="29" hits="1"/> + <line number="30" hits="1"/> + <line number="31" hits="1"/> + <line number="34" hits="1"/> + <line number="36" hits="1"/> + <line number="37" hits="1"/> + <line number="38" hits="1"/> + <line number="39" hits="1"/> + <line number="42" hits="1"/> + <line number="43" hits="1"/> + <line number="44" hits="1"/> + <line 
number="45" hits="1"/> + <line number="47" hits="1"/> + <line number="48" hits="1"/> + <line number="50" hits="1"/> + <line number="51" hits="1"/> + <line number="52" hits="1"/> + <line number="54" hits="1"/> + <line number="55" hits="1"/> + <line number="59" hits="1"/> + <line number="60" hits="1"/> + <line number="61" hits="1"/> + <line number="63" hits="1"/> + <line number="64" hits="1"/> + <line number="66" hits="1"/> + <line number="67" hits="1"/> + <line number="69" hits="1"/> + <line number="70" hits="1"/> + <line number="71" hits="1"/> + <line number="73" hits="0"/> + <line number="75" hits="1"/> + <line number="76" hits="1"/> + <line number="77" hits="1"/> + <line number="79" hits="1"/> + <line number="104" hits="1"/> + <line number="105" hits="1"/> + <line number="106" hits="0"/> + <line number="108" hits="1"/> + <line number="110" hits="1"/> + <line number="112" hits="1"/> + <line number="113" hits="1"/> + <line number="114" hits="0"/> + <line number="117" hits="1"/> + <line number="122" hits="1"/> + <line number="124" hits="1"/> + <line number="125" hits="1"/> + <line number="127" hits="1"/> + <line number="128" hits="1"/> + <line number="130" hits="1"/> + <line number="131" hits="1"/> + <line number="132" hits="1"/> + <line number="133" hits="1"/> + <line number="134" hits="1"/> + <line number="135" hits="1"/> + <line number="136" hits="1"/> + <line number="138" hits="1"/> + <line number="142" hits="1"/> + <line number="143" hits="1"/> + <line number="144" hits="1"/> + <line number="145" hits="1"/> + <line number="147" hits="0"/> + <line number="149" hits="1"/> + <line number="150" hits="1"/> + <line number="152" hits="1"/> + <line number="154" hits="1"/> + <line number="155" hits="1"/> + <line number="156" hits="1"/> + <line number="157" hits="1"/> + <line number="158" hits="1"/> + <line number="159" hits="1"/> + <line number="160" hits="1"/> + <line number="162" hits="1"/> + <line number="166" hits="1"/> + <line number="167" hits="1"/> + <line 
number="168" hits="1"/> + <line number="170" hits="1"/> + <line number="171" hits="1"/> + <line number="172" hits="1"/> + <line number="173" hits="1"/> + <line number="174" hits="1"/> + <line number="175" hits="1"/> + <line number="177" hits="1"/> + <line number="178" hits="1"/> + <line number="179" hits="0"/> + <line number="181" hits="1"/> + <line number="182" hits="1"/> + </lines> + </class> + <class name="schema.py" filename="schema.py" complexity="0" line-rate="1" branch-rate="0"> + <methods/> + <lines> + <line number="1" hits="1"/> + <line number="3" hits="1"/> + <line number="5" hits="1"/> + <line number="6" hits="1"/> <line number="13" hits="1"/> <line number="14" hits="1"/> <line number="15" hits="1"/> @@ -228,32 +350,28 @@ <line number="20" hits="1"/> <line number="21" hits="1"/> <line number="22" hits="1"/> + <line number="23" hits="1"/> <line number="25" hits="1"/> <line number="26" hits="1"/> <line number="27" hits="1"/> <line number="28" hits="1"/> <line number="29" hits="1"/> + <line number="30" hits="1"/> <line number="31" hits="1"/> <line number="32" hits="1"/> <line number="33" hits="1"/> - <line number="34" hits="1"/> <line number="35" hits="1"/> <line number="36" hits="1"/> - <line number="37" hits="1"/> - <line number="38" hits="1"/> <line number="39" hits="1"/> + <line number="40" hits="1"/> <line number="41" hits="1"/> <line number="42" hits="1"/> + <line number="43" hits="1"/> <line number="45" hits="1"/> <line number="46" hits="1"/> <line number="47" hits="1"/> - <line number="48" hits="1"/> <line number="49" hits="1"/> - <line number="51" hits="1"/> - <line number="52" hits="1"/> - <line number="53" hits="1"/> - <line number="55" hits="1"/> - <line number="56" hits="1"/> + <line number="50" hits="1"/> </lines> </class> </classes> @@ -397,12 +515,98 @@ </class> </classes> </package> - <package name="utils" line-rate="1" branch-rate="0" complexity="0"> + <package name="utils" line-rate="0.3958" branch-rate="0" complexity="0"> <classes> 
<class name="__init__.py" filename="utils/__init__.py" complexity="0" line-rate="1" branch-rate="0"> <methods/> <lines/> </class> + <class name="function_call.py" filename="utils/function_call.py" complexity="0" line-rate="0.2258" branch-rate="0"> + <methods/> + <lines> + <line number="1" hits="1"/> + <line number="2" hits="1"/> + <line number="3" hits="1"/> + <line number="5" hits="1"/> + <line number="7" hits="1"/> + <line number="8" hits="1"/> + <line number="11" hits="1"/> + <line number="12" hits="1"/> + <line number="13" hits="0"/> + <line number="14" hits="0"/> + <line number="15" hits="0"/> + <line number="16" hits="0"/> + <line number="18" hits="0"/> + <line number="19" hits="0"/> + <line number="20" hits="0"/> + <line number="24" hits="0"/> + <line number="26" hits="0"/> + <line number="27" hits="0"/> + <line number="28" hits="0"/> + <line number="34" hits="1"/> + <line number="40" hits="1"/> + <line number="43" hits="1"/> + <line number="44" hits="0"/> + <line number="46" hits="0"/> + <line number="75" hits="0"/> + <line number="76" hits="0"/> + <line number="77" hits="0"/> + <line number="79" hits="0"/> + <line number="81" hits="0"/> + <line number="82" hits="0"/> + <line number="83" hits="0"/> + <line number="84" hits="0"/> + <line number="87" hits="1"/> + <line number="89" hits="0"/> + <line number="91" hits="0"/> + <line number="92" hits="0"/> + <line number="93" hits="0"/> + <line number="94" hits="0"/> + <line number="98" hits="0"/> + <line number="99" hits="0"/> + <line number="100" hits="0"/> + <line number="101" hits="0"/> + <line number="102" hits="0"/> + <line number="103" hits="0"/> + <line number="104" hits="0"/> + <line number="105" hits="0"/> + <line number="108" hits="1"/> + <line number="109" hits="0"/> + <line number="110" hits="0"/> + <line number="111" hits="0"/> + <line number="112" hits="0"/> + <line number="116" hits="1"/> + <line number="117" hits="0"/> + <line number="118" hits="0"/> + <line number="119" hits="0"/> + <line 
number="120" hits="0"/> + <line number="122" hits="0"/> + <line number="123" hits="0"/> + <line number="124" hits="0"/> + <line number="125" hits="0"/> + <line number="126" hits="0"/> + <line number="127" hits="0"/> + </lines> + </class> + <class name="llm.py" filename="utils/llm.py" complexity="0" line-rate="0.2857" branch-rate="0"> + <methods/> + <lines> + <line number="1" hits="1"/> + <line number="3" hits="1"/> + <line number="5" hits="1"/> + <line number="8" hits="1"/> + <line number="9" hits="0"/> + <line number="10" hits="0"/> + <line number="15" hits="0"/> + <line number="27" hits="0"/> + <line number="29" hits="0"/> + <line number="30" hits="0"/> + <line number="31" hits="0"/> + <line number="32" hits="0"/> + <line number="33" hits="0"/> + <line number="34" hits="0"/> + </lines> + </class> <class name="logger.py" filename="utils/logger.py" complexity="0" line-rate="1" branch-rate="0"> <methods/> <lines> diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb index 466b8f0af552168c19d8bc355ad73e58835628f9..d082468b997726bc51e12f1e1d9f6b32ffe52697 100644 --- a/docs/examples/function_calling.ipynb +++ b/docs/examples/function_calling.ipynb @@ -4,396 +4,201 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Define LLMs" + "### Set up functions and routes" ] }, { "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/jakit/customers/aurelio/semantic-router/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "# OpenAI\n", - "import openai\n", - "from semantic_router.utils.logger import logger\n", - "\n", - "\n", - "# Docs # https://platform.openai.com/docs/guides/function-calling\n", - "def llm_openai(prompt: str, model: str = \"gpt-4\") -> str:\n", - " try:\n", - " logger.info(f\"Calling {model} model\")\n", - " response = openai.chat.completions.create(\n", - " model=model,\n", - " messages=[\n", - " {\"role\": \"system\", \"content\": f\"{prompt}\"},\n", - " ],\n", - " )\n", - " ai_message = response.choices[0].message.content\n", - " if not ai_message:\n", - " raise Exception(\"AI message is empty\", ai_message)\n", - " logger.info(f\"AI message: {ai_message}\")\n", - " return ai_message\n", - " except Exception as e:\n", - " raise Exception(\"Failed to call OpenAI API\", e)" - ] - }, - { - "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ - "# Mistral\n", - "import os\n", - "import requests\n", - "\n", - "# Docs https://huggingface.co/docs/transformers/main_classes/text_generation\n", - "HF_API_TOKEN = os.getenv(\"HF_API_TOKEN\")\n", - "\n", + "def get_time(location: str) -> str:\n", + " \"\"\"Useful to get the time in a specific location\"\"\"\n", + " print(f\"Result from: `get_time` function with location: `{location}`\")\n", + " return \"get_time\"\n", "\n", - "def llm_mistral(prompt: str) -> str:\n", - " api_url = \"https://z5t4cuhg21uxfmc3.us-east-1.aws.endpoints.huggingface.cloud/\"\n", - " headers = {\n", - " \"Authorization\": f\"Bearer {HF_API_TOKEN}\",\n", - " \"Content-Type\": \"application/json\",\n", - " }\n", "\n", - " logger.info(\"Calling Mistral model\")\n", - " response = requests.post(\n", - " api_url,\n", - " headers=headers,\n", - " json={\n", - " \"inputs\": f\"You are a helpful assistant, user query: {prompt}\",\n", - " 
\"parameters\": {\n", - " \"max_new_tokens\": 200,\n", - " \"temperature\": 0.01,\n", - " \"num_beams\": 5,\n", - " \"num_return_sequences\": 1,\n", - " },\n", - " },\n", + "def get_news(category: str, country: str) -> str:\n", + " \"\"\"Useful to get the news in a specific country\"\"\"\n", + " print(\n", + " f\"Result from: `get_news` function with category: `{category}` \"\n", + " f\"and country: `{country}`\"\n", " )\n", - " if response.status_code != 200:\n", - " raise Exception(\"Failed to call HuggingFace API\", response.text)\n", - "\n", - " ai_message = response.json()[0][\"generated_text\"]\n", - " if not ai_message:\n", - " raise Exception(\"AI message is empty\", ai_message)\n", - " logger.info(f\"AI message: {ai_message}\")\n", - " return ai_message" + " return \"get_news\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Now we need to generate config from function schema using LLM" + "Now generate a dynamic routing config for each function" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2023-12-20 12:21:30 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", + "\u001b[32m2023-12-20 12:21:33 INFO semantic_router.utils.logger Generated route config:\n", + "{\n", + " \"name\": \"get_time\",\n", + " \"utterances\": [\n", + " \"What's the time in New York?\",\n", + " \"Can you tell me the time in Tokyo?\",\n", + " \"What's the current time in London?\",\n", + " \"Can you give me the time in Sydney?\",\n", + " \"What's the time in Paris?\"\n", + " ]\n", + "}\u001b[0m\n", + "\u001b[32m2023-12-20 12:21:33 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", + "\u001b[32m2023-12-20 12:21:38 INFO semantic_router.utils.logger Generated route config:\n", + "{\n", + " \"name\": \"get_news\",\n", + " \"utterances\": [\n", + " \"Tell me the latest news 
from the United States\",\n", + " \"What's happening in India today?\",\n", + " \"Can you give me the top stories from Japan\",\n", + " \"Get me the breaking news from the UK\",\n", + " \"What's the latest in Germany?\"\n", + " ]\n", + "}\u001b[0m\n", + "/var/folders/gf/cvm58m_x6pvghy227n5cmx5w0000gn/T/ipykernel_65737/1850296463.py:10: RuntimeWarning: coroutine 'Route.from_dynamic_route' was never awaited\n", + " route_config = RouteConfig(routes=routes)\n", + "RuntimeWarning: Enable tracemalloc to get the object allocation traceback\n" + ] + } + ], "source": [ - "import inspect\n", - "from typing import Any\n", - "\n", + "from semantic_router.route import Route, RouteConfig\n", "\n", - "def get_function_schema(function) -> dict[str, Any]:\n", - " schema = {\n", - " \"name\": function.__name__,\n", - " \"description\": str(inspect.getdoc(function)),\n", - " \"signature\": str(inspect.signature(function)),\n", - " \"output\": str(\n", - " inspect.signature(function).return_annotation,\n", - " ),\n", - " }\n", - " return schema" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "import json\n", + "functions = [get_time, get_news]\n", + "routes = []\n", "\n", + "for function in functions:\n", + " route = await Route.from_dynamic_route(entity=function)\n", + " routes.append(route)\n", "\n", - "def is_valid_config(route_config_str: str) -> bool:\n", - " try:\n", - " output_json = json.loads(route_config_str)\n", - " return all(key in output_json for key in [\"name\", \"utterances\"])\n", - " except json.JSONDecodeError:\n", - " return False" + "route_config = RouteConfig(routes=routes)" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Added route `get_weather`\u001b[0m\n", + "\u001b[32m2023-12-19 17:46:43 
INFO semantic_router.utils.logger Removed route `get_weather`\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "[{'name': 'get_time',\n", + " 'utterances': [\"What's the time in New York?\",\n", + " 'Can you tell me the time in Tokyo?',\n", + " \"What's the current time in London?\",\n", + " 'Can you give me the time in Sydney?',\n", + " \"What's the time in Paris?\"],\n", + " 'description': None},\n", + " {'name': 'get_news',\n", + " 'utterances': ['Tell me the latest news from the United States',\n", + " \"What's happening in India today?\",\n", + " 'Can you give me the top stories from Japan',\n", + " 'Get me the breaking news from the UK',\n", + " \"What's the latest in Germany?\"],\n", + " 'description': None}]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "import json\n", - "\n", - "from semantic_router.utils.logger import logger\n", - "\n", - "\n", - "def generate_route(function) -> dict:\n", - " logger.info(\"Generating config...\")\n", - "\n", - " function_schema = get_function_schema(function)\n", - "\n", - " prompt = f\"\"\"\n", - " You are tasked to generate a JSON configuration based on the provided\n", - " function schema. 
Please follow the template below:\n", + "# You can manually add or remove routes\n", "\n", - " {{\n", - " \"name\": \"<function_name>\",\n", - " \"utterances\": [\n", - " \"<example_utterance_1>\",\n", - " \"<example_utterance_2>\",\n", - " \"<example_utterance_3>\",\n", - " \"<example_utterance_4>\",\n", - " \"<example_utterance_5>\"]\n", - " }}\n", + "get_weather_route = Route(\n", + " name=\"get_weather\",\n", + " utterances=[\n", + " \"what is the weather in SF\",\n", + " \"what is the current temperature in London?\",\n", + " \"tomorrow's weather in Paris?\",\n", + " ],\n", + ")\n", + "route_config.add(get_weather_route)\n", "\n", - " Only include the \"name\" and \"utterances\" keys in your answer.\n", - " The \"name\" should match the function name and the \"utterances\"\n", - " should comprise a list of 5 example phrases that could be used to invoke\n", - " the function.\n", + "route_config.remove(\"get_weather\")\n", "\n", - " Input schema:\n", - " {function_schema}\n", - " \"\"\"\n", - "\n", - " try:\n", - " ai_message = llm_mistral(prompt)\n", - "\n", - " # Parse the response\n", - " ai_message = ai_message[ai_message.find(\"{\") :]\n", - " ai_message = (\n", - " ai_message.replace(\"'\", '\"')\n", - " .replace('\"s', \"'s\")\n", - " .strip()\n", - " .rstrip(\",\")\n", - " .replace(\"}\", \"}\")\n", - " )\n", - "\n", - " valid_config = is_valid_config(ai_message)\n", - "\n", - " if not valid_config:\n", - " logger.warning(f\"Mistral failed with error, falling back to OpenAI\")\n", - " ai_message = llm_openai(prompt)\n", - " if not is_valid_config(ai_message):\n", - " raise Exception(\"Invalid config generated\")\n", - " except Exception as e:\n", - " logger.error(f\"Fall back to OpenAI failed with error {e}\")\n", - " ai_message = llm_openai(prompt)\n", - " if not is_valid_config(ai_message):\n", - " raise Exception(\"Failed to generate config\")\n", - "\n", - " try:\n", - " route_config = json.loads(ai_message)\n", - " logger.info(f\"Generated config: 
{route_config}\")\n", - " return route_config\n", - " except json.JSONDecodeError as json_error:\n", - " logger.error(f\"JSON parsing error {json_error}\")\n", - " print(f\"AI message: {ai_message}\")\n", - " return {\"error\": \"Failed to generate config\"}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Extract function parameters using `Mistral` open-source model" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "def validate_parameters(function, parameters):\n", - " sig = inspect.signature(function)\n", - " for name, param in sig.parameters.items():\n", - " if name not in parameters:\n", - " return False, f\"Parameter {name} missing from query\"\n", - " if not isinstance(parameters[name], param.annotation):\n", - " return False, f\"Parameter {name} is not of type {param.annotation}\"\n", - " return True, \"Parameters are valid\"" + "route_config.to_dict()" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "Route(name='get_time', utterances=[\"What's the time in New York?\", 'Can you tell me the time in Tokyo?', \"What's the current time in London?\", 'Can you give me the time in Sydney?', \"What's the time in Paris?\"], description=None)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "def extract_parameters(query: str, function) -> dict:\n", - " logger.info(\"Extracting parameters...\")\n", - " example_query = \"How is the weather in Hawaii right now in International units?\"\n", - "\n", - " example_schema = {\n", - " \"name\": \"get_weather\",\n", - " \"description\": \"Useful to get the weather in a specific location\",\n", - " \"signature\": \"(location: str, degree: str) -> str\",\n", - " \"output\": \"<class 'str'>\",\n", - " }\n", - "\n", - " example_parameters = {\n", - " \"location\": 
\"London\",\n", - " \"degree\": \"Celsius\",\n", - " }\n", - "\n", - " prompt = f\"\"\"\n", - " You are a helpful assistant designed to output JSON.\n", - " Given the following function schema\n", - " << {get_function_schema(function)} >>\n", - " and query\n", - " << {query} >>\n", - " extract the parameters values from the query, in a valid JSON format.\n", - " Example:\n", - " Input:\n", - " query: {example_query}\n", - " schema: {example_schema}\n", - "\n", - " Result: {example_parameters}\n", - "\n", - " Input:\n", - " query: {query}\n", - " schema: {get_function_schema(function)}\n", - " Result:\n", - " \"\"\"\n", - "\n", - " try:\n", - " ai_message = llm_mistral(prompt)\n", - " ai_message = (\n", - " ai_message.replace(\"Output:\", \"\").replace(\"'\", '\"').strip().rstrip(\",\")\n", - " )\n", - " except Exception as e:\n", - " logger.error(f\"Mistral failed with error {e}, falling back to OpenAI\")\n", - " ai_message = llm_openai(prompt)\n", - "\n", - " try:\n", - " parameters = json.loads(ai_message)\n", - " valid, message = validate_parameters(function, parameters)\n", - "\n", - " if not valid:\n", - " logger.warning(\n", - " f\"Invalid parameters from Mistral, falling back to OpenAI: {message}\"\n", - " )\n", - " # Fall back to OpenAI\n", - " ai_message = llm_openai(prompt)\n", - " parameters = json.loads(ai_message)\n", - " valid, message = validate_parameters(function, parameters)\n", - " if not valid:\n", - " raise ValueError(message)\n", - "\n", - " logger.info(f\"Extracted parameters: {parameters}\")\n", - " return parameters\n", - " except ValueError as e:\n", - " logger.error(f\"Parameter validation error: {str(e)}\")\n", - " return {\"error\": \"Failed to validate parameters\"}" + "# Get a route by name\n", + "route_config.get(\"get_time\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Set up the routing layer" + "Save config to a file (.json or .yaml)" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, 
"metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Saving route config to route_config.json\u001b[0m\n" + ] + } + ], "source": [ - "from semantic_router.schema import Route\n", - "from semantic_router.encoders import CohereEncoder\n", - "from semantic_router.layer import RouteLayer\n", - "from semantic_router.utils.logger import logger\n", - "\n", - "\n", - "def create_router(routes: list[dict]) -> RouteLayer:\n", - " logger.info(\"Creating route layer...\")\n", - " encoder = CohereEncoder()\n", - "\n", - " route_list: list[Route] = []\n", - " for route in routes:\n", - " if \"name\" in route and \"utterances\" in route:\n", - " print(f\"Route: {route}\")\n", - " route_list.append(Route(name=route[\"name\"], utterances=route[\"utterances\"]))\n", - " else:\n", - " logger.warning(f\"Misconfigured route: {route}\")\n", - "\n", - " return RouteLayer(encoder=encoder, routes=route_list)" + "route_config.to_file(\"route_config.json\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Set up calling functions" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Callable\n", - "from semantic_router.layer import RouteLayer\n", - "\n", - "\n", - "def call_function(function: Callable, parameters: dict[str, str]):\n", - " try:\n", - " return function(**parameters)\n", - " except TypeError as e:\n", - " logger.error(f\"Error calling function: {e}\")\n", - "\n", - "\n", - "def call_llm(query: str) -> str:\n", - " try:\n", - " ai_message = llm_mistral(query)\n", - " except Exception as e:\n", - " logger.error(f\"Mistral failed with error {e}, falling back to OpenAI\")\n", - " ai_message = llm_openai(query)\n", - "\n", - " return ai_message\n", - "\n", - "\n", - "def call(query: str, functions: list[Callable], router: RouteLayer):\n", - " function_name = 
router(query)\n", - " if not function_name:\n", - " logger.warning(\"No function found\")\n", - " return call_llm(query)\n", - "\n", - " for function in functions:\n", - " if function.__name__ == function_name:\n", - " parameters = extract_parameters(query, function)\n", - " print(f\"parameters: {parameters}\")\n", - " return call_function(function, parameters)" + "### Define routing layer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Workflow" + "Load from local file" ] }, { @@ -405,149 +210,32 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-18 14:47:47 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 14:47:47 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 14:47:50 INFO semantic_router.utils.logger AI message: \n", - " Example output:\n", - " {\n", - " \"name\": \"get_time\",\n", - " \"utterances\": [\n", - " \"What's the time in New York?\",\n", - " \"Tell me the time in Tokyo.\",\n", - " \"Can you give me the time in London?\",\n", - " \"What's the current time in Sydney?\",\n", - " \"Can you tell me the time in Berlin?\"\n", - " ]\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-18 14:47:50 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\u001b[0m\n", - "\u001b[32m2023-12-18 14:47:50 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 14:47:50 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 14:47:54 INFO semantic_router.utils.logger AI message: \n", - " Example output:\n", - " {\n", - " \"name\": \"get_news\",\n", - " \"utterances\": [\n", - " \"Tell me the latest news from the US\",\n", - " \"What's happening in India 
today?\",\n", - " \"Get me the top stories from Japan\",\n", - " \"Can you give me the breaking news from Brazil?\",\n", - " \"What's the latest news from Germany?\"\n", - " ]\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-18 14:47:54 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\u001b[0m\n", - "\u001b[32m2023-12-18 14:47:54 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Route: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\n", - "Route: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\n" + "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Loading route config from route_config.json\u001b[0m\n" ] } ], "source": [ - "def get_time(location: str) -> str:\n", - " \"\"\"Useful to get the time in a specific location\"\"\"\n", - " print(f\"Calling `get_time` function with location: {location}\")\n", - " return \"get_time\"\n", + "from semantic_router.route import RouteConfig\n", "\n", - "\n", - "def get_news(category: str, country: str) -> str:\n", - " \"\"\"Useful to get the news in a specific country\"\"\"\n", - " print(\n", - " f\"Calling `get_news` function with category: {category} and country: {country}\"\n", - " )\n", - " return \"get_news\"\n", - "\n", - "\n", - "# Registering functions to the router\n", - "route_get_time = generate_route(get_time)\n", 
- "route_get_news = generate_route(get_news)\n", - "\n", - "routes = [route_get_time, route_get_news]\n", - "router = create_router(routes)\n", - "\n", - "# Tools\n", - "tools = [get_time, get_news]" + "route_config = RouteConfig.from_file(\"route_config.json\")" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-12-18 14:47:55 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 14:47:55 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 14:47:58 INFO semantic_router.utils.logger AI message: \n", - " Example output:\n", - " {\n", - " \"name\": \"get_time\",\n", - " \"utterances\": [\n", - " \"What's the time in New York?\",\n", - " \"Tell me the time in Tokyo.\",\n", - " \"Can you give me the time in London?\",\n", - " \"What's the current time in Sydney?\",\n", - " \"Can you tell me the time in Berlin?\"\n", - " ]\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-18 14:47:58 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\u001b[0m\n", - "\u001b[32m2023-12-18 14:47:58 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 14:47:58 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 14:48:02 INFO semantic_router.utils.logger AI message: \n", - " Example output:\n", - " {\n", - " \"name\": \"get_news\",\n", - " \"utterances\": [\n", - " \"Tell me the latest news from the US\",\n", - " \"What's happening in India today?\",\n", - " \"Get me the top stories from Japan\",\n", - " \"Can you give me the breaking news from Brazil?\",\n", - " \"What's the latest news from Germany?\"\n", - " 
]\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-18 14:48:02 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\u001b[0m\n", - "\u001b[32m2023-12-18 14:48:02 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Route: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\n", - "Route: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\n" - ] - } - ], + "outputs": [], "source": [ - "def get_time(location: str) -> str:\n", - " \"\"\"Useful to get the time in a specific location\"\"\"\n", - " print(f\"Calling `get_time` function with location: {location}\")\n", - " return \"get_time\"\n", - "\n", - "\n", - "def get_news(category: str, country: str) -> str:\n", - " \"\"\"Useful to get the news in a specific country\"\"\"\n", - " print(\n", - " f\"Calling `get_news` function with category: {category} and country: {country}\"\n", - " )\n", - " return \"get_news\"\n", - "\n", - "\n", - "# Registering functions to the router\n", - "route_get_time = generate_route(get_time)\n", - "route_get_news = generate_route(get_news)\n", + "from semantic_router import RouteLayer\n", "\n", - "routes = [route_get_time, route_get_news]\n", - "router = create_router(routes)\n", - "\n", - "# Tools\n", - "tools = [get_time, get_news]" + "route_layer = RouteLayer(routes=route_config.routes)" + ] + }, + { + 
"cell_type": "markdown", + "metadata": {}, + "source": [ + "Do a function call with functions as tool" ] }, { @@ -559,58 +247,43 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-18 14:48:02 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-18 14:48:02 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 14:48:04 INFO semantic_router.utils.logger AI message: \n", - " {\n", - " \"location\": \"Stockholm\"\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-18 14:48:04 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n" + "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "parameters: {'location': 'Stockholm'}\n", - "Calling `get_time` function with location: Stockholm\n" + "Calling function: get_time\n", + "Result from: `get_time` function with location: `Stockholm`\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-18 14:48:04 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-18 14:48:04 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 14:48:05 INFO semantic_router.utils.logger AI message: \n", - " {\n", - " \"category\": \"tech\",\n", - " \"country\": \"Lithuania\"\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-18 14:48:05 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n" + "\u001b[32m2023-12-19 17:46:49 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "parameters: {'category': 'tech', 'country': 'Lithuania'}\n", - "Calling `get_news` function with category: tech and country: Lithuania\n" + "Calling function: get_news\n", + "Result from: 
`get_news` function with category: `tech` and country: `Lithuania`\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[33m2023-12-18 14:48:05 WARNING semantic_router.utils.logger No function found\u001b[0m\n", - "\u001b[32m2023-12-18 14:48:05 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 14:48:06 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n" + "\u001b[33m2023-12-19 17:46:52 WARNING semantic_router.utils.logger No function found, calling LLM...\u001b[0m\n" ] }, { "data": { "text/plain": [ - "' How can I help you today?'" + "'Hello! How can I assist you today?'" ] }, "execution_count": 12, @@ -619,17 +292,20 @@ } ], "source": [ - "call(query=\"What is the time in Stockholm?\", functions=tools, router=router)\n", - "call(query=\"What is the tech news in the Lithuania?\", functions=tools, router=router)\n", - "call(query=\"Hi!\", functions=tools, router=router)" + "from semantic_router.utils.function_call import route_and_execute\n", + "\n", + "tools = [get_time, get_news]\n", + "\n", + "await route_and_execute(\n", + " query=\"What is the time in Stockholm?\", functions=tools, route_layer=route_layer\n", + ")\n", + "await route_and_execute(\n", + " query=\"What is the tech news in the Lithuania?\",\n", + " functions=tools,\n", + " route_layer=route_layer,\n", + ")\n", + "await route_and_execute(query=\"Hi!\", functions=tools, route_layer=route_layer)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -648,7 +324,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.11.3" } }, "nbformat": 4, diff --git a/docs/examples/route_config.json b/docs/examples/route_config.json new file mode 100644 index 0000000000000000000000000000000000000000..f76a73859e4c3534f37bc99c0aec17627e0d2ee6 --- /dev/null +++ 
b/docs/examples/route_config.json @@ -0,0 +1 @@ +[{"name": "get_time", "utterances": ["What's the time in New York?", "Can you tell me the time in Tokyo?", "What's the current time in London?", "Can you give me the time in Sydney?", "What's the time in Paris?"], "description": null}, {"name": "get_news", "utterances": ["Tell me the latest news from the United States", "What's happening in India today?", "Can you give me the top stories from Japan", "Get me the breaking news from the UK", "What's the latest in Germany?"], "description": null}] diff --git a/poetry.lock b/poetry.lock index 216d298ddcb9dbc839733929399a6990cbd8584b..7efeda7e28a550d04e2e108fa1901478a7483589 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1594,6 +1594,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.23.2" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-asyncio-0.23.2.tar.gz", hash = "sha256:c16052382554c7b22d48782ab3438d5b10f8cf7a4bdcae7f0f67f097d95beecc"}, + {file = "pytest_asyncio-0.23.2-py3-none-any.whl", hash = "sha256:ea9021364e32d58f0be43b91c6233fb8d2224ccef2398d6837559e587682808f"}, +] + +[package.dependencies] +pytest = ">=7.0.0" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "pytest-cov" version = "4.1.0" @@ -1686,6 +1704,55 @@ files = [ {file = "pywin32-306-cp39-cp39-win_amd64.whl", hash = "sha256:39b61c15272833b5c329a2989999dcae836b1eed650252ab1b7bfbe1d59f30f4"}, ] +[[package]] +name = "pyyaml" +version = "6.0.1" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = 
"PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, + {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, + {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, + {file = 
"PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, + {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, + {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", 
hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, + {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, + {file = "PyYAML-6.0.1.tar.gz", hash = 
"sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, +] + [[package]] name = "pyzmq" version = "25.1.2" @@ -2053,6 +2120,17 @@ files = [ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<7.5)", "pytest-mock", "pytest-mypy-testing"] +[[package]] +name = "types-pyyaml" +version = "6.0.12.12" +description = "Typing stubs for PyYAML" +optional = false +python-versions = "*" +files = [ + {file = "types-PyYAML-6.0.12.12.tar.gz", hash = "sha256:334373d392fde0fdf95af5c3f1661885fa10c52167b14593eb856289e1855062"}, + {file = "types_PyYAML-6.0.12.12-py3-none-any.whl", hash = "sha256:c05bc6c158facb0676674b7f11fe3960db4f389718e19e62bd2b84d6205cfd24"}, +] + [[package]] name = "typing-extensions" version = "4.9.0" @@ -2222,4 +2300,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "f2735c243faa3d788c0f6268d6cb550648ed0d1fffec27a084344dafa4590a80" +content-hash = "afd687626ef87dc72424414d7c2333caf360bccb01fab087cfd78b97ea62e04f" diff --git a/pyproject.toml b/pyproject.toml index e45e5f17d0356cce8a2cfe5a33d9fa0529c170c5..5e430824737bf6e097d3cbad30a633f3a73aeadb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ authors = [ "Bogdan Buduroiu <bogdan@aurelio.ai>" ] readme = "README.md" +packages = [{include = "semantic_router"}] [tool.poetry.dependencies] python = "^3.9" @@ -19,6 +20,8 @@ cohere = "^4.32" numpy = "^1.25.2" pinecone-text = "^0.7.0" colorlog = "^6.8.0" +pyyaml = "^6.0.1" +pytest-asyncio = "^0.23.2" [tool.poetry.group.dev.dependencies] @@ -30,6 +33,7 @@ pytest-mock = "^3.12.0" pytest-cov = "^4.1.0" pytest-xdist = "^3.5.0" mypy = "^1.7.1" +types-pyyaml = "^6.0.12.12" [build-system] requires = ["poetry-core"] @@ -38,5 +42,8 @@ build-backend = "poetry.core.masonry.api" [tool.ruff.per-file-ignores] "*.ipynb" = ["ALL"] +[tool.ruff] +line-length = 
88 + [tool.mypy] ignore_missing_imports = true diff --git a/semantic_router/__init__.py b/semantic_router/__init__.py index 0c445bea3ff4efd8f3aa8950e2c772277d93b20c..2659bfe3bf4cebe2b022c01ec7139658aeb43eb1 100644 --- a/semantic_router/__init__.py +++ b/semantic_router/__init__.py @@ -1,4 +1,5 @@ from .hybrid_layer import HybridRouteLayer from .layer import RouteLayer +from .route import Route, RouteConfig -__all__ = ["RouteLayer", "HybridRouteLayer"] +__all__ = ["RouteLayer", "HybridRouteLayer", "Route", "RouteConfig"] diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index 9f353e941cc4b56262a92d55ff65906aa49e1892..475a12f09b4bcac340d2ceabf77199ceee9cb071 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -7,9 +7,10 @@ from semantic_router.encoders import ( CohereEncoder, OpenAIEncoder, ) -from semantic_router.schema import Route from semantic_router.utils.logger import logger +from .route import Route + class HybridRouteLayer: index = None diff --git a/semantic_router/layer.py b/semantic_router/layer.py index cb408c5c5f452b78e9750976d7c669a01028450a..2fa3b8634de794fa59592c20af5180a4a5dee851 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -1,4 +1,7 @@ +import json + import numpy as np +import yaml from semantic_router.encoders import ( BaseEncoder, @@ -6,17 +9,19 @@ from semantic_router.encoders import ( OpenAIEncoder, ) from semantic_router.linear import similarity_matrix, top_scores -from semantic_router.schema import Route from semantic_router.utils.logger import logger +from .route import Route + class RouteLayer: index = None categories = None score_threshold = 0.82 - def __init__(self, encoder: BaseEncoder, routes: list[Route] = []): - self.encoder = encoder + def __init__(self, encoder: BaseEncoder | None = None, routes: list[Route] = []): + self.encoder = encoder if encoder is not None else CohereEncoder() + self.routes: list[Route] = routes # decide on default 
threshold based on encoder if isinstance(encoder, OpenAIEncoder): self.score_threshold = 0.82 @@ -27,7 +32,7 @@ class RouteLayer: # if routes list has been passed, we initialize index now if routes: # initialize index now - self.add_routes(routes=routes) + self._add_routes(routes=routes) def __call__(self, text: str) -> str | None: results = self._query(text) @@ -38,7 +43,21 @@ class RouteLayer: else: return None - def add_route(self, route: Route): + @classmethod + def from_json(cls, file_path: str): + with open(file_path, "r") as f: + routes_data = json.load(f) + routes = [Route.from_dict(route_data) for route_data in routes_data] + return cls(routes=routes) + + @classmethod + def from_yaml(cls, file_path: str): + with open(file_path, "r") as f: + routes_data = yaml.load(f, Loader=yaml.FullLoader) + routes = [Route.from_dict(route_data) for route_data in routes_data] + return cls(routes=routes) + + def add(self, route: Route): # create embeddings embeds = self.encoder(route.utterances) @@ -55,7 +74,7 @@ class RouteLayer: embed_arr = np.array(embeds) self.index = np.concatenate([self.index, embed_arr]) - def add_routes(self, routes: list[Route]): + def _add_routes(self, routes: list[Route]): # create embeddings for all routes all_utterances = [ utterance for route in routes for utterance in route.utterances @@ -124,3 +143,8 @@ class RouteLayer: return max(scores) > threshold else: return False + + def to_json(self, file_path: str): + routes = [route.to_dict() for route in self.routes] + with open(file_path, "w") as f: + json.dump(routes, f, indent=4) diff --git a/semantic_router/route.py b/semantic_router/route.py new file mode 100644 index 0000000000000000000000000000000000000000..99a7945bf35563941ddd2d49685bc06b421e734f --- /dev/null +++ b/semantic_router/route.py @@ -0,0 +1,182 @@ +import json +import os +import re +from typing import Any, Callable, Union + +import yaml +from pydantic import BaseModel + +from semantic_router.utils import function_call +from 
semantic_router.utils.llm import llm +from semantic_router.utils.logger import logger + + +def is_valid(route_config: str) -> bool: + try: + output_json = json.loads(route_config) + required_keys = ["name", "utterances"] + + if isinstance(output_json, list): + for item in output_json: + missing_keys = [key for key in required_keys if key not in item] + if missing_keys: + logger.warning( + f"Missing keys in route config: {', '.join(missing_keys)}" + ) + return False + return True + else: + missing_keys = [key for key in required_keys if key not in output_json] + if missing_keys: + logger.warning( + f"Missing keys in route config: {', '.join(missing_keys)}" + ) + return False + else: + return True + except json.JSONDecodeError as e: + logger.error(e) + return False + + +class Route(BaseModel): + name: str + utterances: list[str] + description: str | None = None + + def to_dict(self): + return self.dict() + + @classmethod + def from_dict(cls, data: dict): + return cls(**data) + + @classmethod + async def from_dynamic_route(cls, entity: Union[BaseModel, Callable]): + """ + Generate a dynamic Route object from a function or Pydantic model using LLM + """ + schema = function_call.get_schema(item=entity) + dynamic_route = await cls._generate_dynamic_route(function_schema=schema) + return dynamic_route + + @classmethod + def _parse_route_config(cls, config: str) -> str: + # Regular expression to match content inside <config></config> + config_pattern = r"<config>(.*?)</config>" + match = re.search(config_pattern, config, re.DOTALL) + + if match: + config_content = match.group(1).strip() # Get the matched content + return config_content + else: + raise ValueError("No <config></config> tags found in the output.") + + @classmethod + async def _generate_dynamic_route(cls, function_schema: dict[str, Any]): + logger.info("Generating dynamic route...") + + prompt = f""" + You are tasked to generate a JSON configuration based on the provided + function schema. 
Please follow the template below, no other tokens allowed: + + <config> + {{ + "name": "<function_name>", + "utterances": [ + "<example_utterance_1>", + "<example_utterance_2>", + "<example_utterance_3>", + "<example_utterance_4>", + "<example_utterance_5>"] + }} + </config> + + Only include the "name" and "utterances" keys in your answer. + The "name" should match the function name and the "utterances" + should comprise a list of 5 example phrases that could be used to invoke + the function. Use real values instead of placeholders. + + Input schema: + {function_schema} + """ + + output = await llm(prompt) + if not output: + raise Exception("No output generated for dynamic route") + + route_config = cls._parse_route_config(config=output) + + logger.info(f"Generated route config:\n{route_config}") + + if is_valid(route_config): + return Route.from_dict(json.loads(route_config)) + raise Exception("No config generated") + + +class RouteConfig: + """ + Generates a RouteConfig object from a list of Route objects + """ + + routes: list[Route] = [] + + def __init__(self, routes: list[Route] = []): + self.routes = routes + + @classmethod + def from_file(cls, path: str): + """Load the routes from a file in JSON or YAML format""" + logger.info(f"Loading route config from {path}") + _, ext = os.path.splitext(path) + with open(path, "r") as f: + if ext == ".json": + routes = json.load(f) + elif ext in [".yaml", ".yml"]: + routes = yaml.safe_load(f) + else: + raise ValueError( + "Unsupported file type. 
Only .json and .yaml are supported" + ) + + route_config_str = json.dumps(routes) + if is_valid(route_config_str): + routes = [Route.from_dict(route) for route in routes] + return cls(routes=routes) + else: + raise Exception("Invalid config JSON or YAML") + + def to_dict(self): + return [route.to_dict() for route in self.routes] + + def to_file(self, path: str): + """Save the routes to a file in JSON or YAML format""" + logger.info(f"Saving route config to {path}") + _, ext = os.path.splitext(path) + with open(path, "w") as f: + if ext == ".json": + json.dump(self.to_dict(), f) + elif ext in [".yaml", ".yml"]: + yaml.safe_dump(self.to_dict(), f) + else: + raise ValueError( + "Unsupported file type. Only .json and .yaml are supported" + ) + + def add(self, route: Route): + self.routes.append(route) + logger.info(f"Added route `{route.name}`") + + def get(self, name: str) -> Route | None: + for route in self.routes: + if route.name == name: + return route + logger.error(f"Route `{name}` not found") + return None + + def remove(self, name: str): + if name not in [route.name for route in self.routes]: + logger.error(f"Route `{name}` not found") + else: + self.routes = [route for route in self.routes if route.name != name] + logger.info(f"Removed route `{name}`") diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 007cddcbeb2c9e464e02a6c7f6cd12d2e9769cbc..4646a637dbffd4ed7ad1b8e2d4dd23cef6df22de 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -1,8 +1,8 @@ from enum import Enum -from pydantic import BaseModel from pydantic.dataclasses import dataclass +from semantic_router import Route from semantic_router.encoders import ( BaseEncoder, CohereEncoder, @@ -10,12 +10,6 @@ from semantic_router.encoders import ( ) -class Route(BaseModel): - name: str - utterances: list[str] - description: str | None = None - - class EncoderType(Enum): HUGGINGFACE = "huggingface" OPENAI = "openai" diff --git 
a/semantic_router/utils/function_call.py b/semantic_router/utils/function_call.py new file mode 100644 index 0000000000000000000000000000000000000000..c1b4fceed0492124fc98401f0f1789e89dfddb21 --- /dev/null +++ b/semantic_router/utils/function_call.py @@ -0,0 +1,127 @@ +import inspect +import json +from typing import Any, Callable, Union + +from pydantic import BaseModel + +from semantic_router.utils.llm import llm +from semantic_router.utils.logger import logger + + +def get_schema(item: Union[BaseModel, Callable]) -> dict[str, Any]: + if isinstance(item, BaseModel): + signature_parts = [] + for field_name, field_model in item.__annotations__.items(): + field_info = item.__fields__[field_name] + default_value = field_info.default + + if default_value: + default_repr = repr(default_value) + signature_part = ( + f"{field_name}: {field_model.__name__} = {default_repr}" + ) + else: + signature_part = f"{field_name}: {field_model.__name__}" + + signature_parts.append(signature_part) + signature = f"({', '.join(signature_parts)}) -> str" + schema = { + "name": item.__class__.__name__, + "description": item.__doc__, + "signature": signature, + } + else: + schema = { + "name": item.__name__, + "description": str(inspect.getdoc(item)), + "signature": str(inspect.signature(item)), + "output": str(inspect.signature(item).return_annotation), + } + return schema + + +async def extract_function_inputs(query: str, function_schema: dict[str, Any]) -> dict: + logger.info("Extracting function input...") + + prompt = f""" + You are a helpful assistant designed to output JSON. + Given the following function schema + << {function_schema} >> + and query + << {query} >> + extract the parameters values from the query, in a valid JSON format. + Example: + Input: + query: "How is the weather in Hawaii right now in International units?" 
+ schema: + {{ + "name": "get_weather", + "description": "Useful to get the weather in a specific location", + "signature": "(location: str, degree: str) -> str", + "output": "<class 'str'>" + }} + + Result: {{ + "location": "Hawaii", + "degree": "Celsius" + }} + + Input: + query: {query} + schema: {function_schema} + Result: + """ + + output = await llm(prompt) + if not output: + raise Exception("No output generated for extract function input") + + output = output.replace("'", '"').strip().rstrip(",") + + function_inputs = json.loads(output) + if not is_valid_inputs(function_inputs, function_schema): + raise ValueError("Invalid inputs") + return function_inputs + + + def is_valid_inputs(inputs: dict[str, Any], function_schema: dict[str, Any]) -> bool: + """Validate the extracted inputs against the function schema""" + try: + # Extract parameter names and types from the signature string + signature = function_schema["signature"] + param_info = [param.strip() for param in signature[1:-1].split(",")] + param_names = [info.split(":")[0].strip() for info in param_info] + param_types = [ + info.split(":")[1].strip().split("=")[0].strip() for info in param_info + ] + + for name, type_str in zip(param_names, param_types): + if name not in inputs: + logger.error(f"Input {name} missing from query") + return False + return True + except Exception as e: + logger.error(f"Input validation error: {str(e)}") + return False + + + def call_function(function: Callable, inputs: dict[str, str]): + try: + return function(**inputs) + except TypeError as e: + logger.error(f"Error calling function: {e}") + + + # TODO: Add route layer object to the input, solve circular import issue + async def route_and_execute(query: str, functions: list[Callable], route_layer): + function_name = route_layer(query) + if not function_name: + logger.warning("No function found, calling LLM...") + return await llm(query) + + for function in functions: + if function.__name__ == function_name: + print(f"Calling
function: {function.__name__}") + schema = get_schema(function) + inputs = await extract_function_inputs(query, schema) + call_function(function, inputs) diff --git a/semantic_router/utils/llm.py b/semantic_router/utils/llm.py new file mode 100644 index 0000000000000000000000000000000000000000..e912ee1f8ea53cdeaa69a83669384a5d6d165c1c --- /dev/null +++ b/semantic_router/utils/llm.py @@ -0,0 +1,34 @@ +import os + +import openai + +from semantic_router.utils.logger import logger + + +async def llm(prompt: str) -> str | None: + try: + client = openai.AsyncOpenAI( + base_url="https://openrouter.ai/api/v1", + api_key=os.getenv("OPENROUTER_API_KEY"), + ) + + completion = await client.chat.completions.create( + model="mistralai/mistral-7b-instruct", + messages=[ + { + "role": "user", + "content": prompt, + }, + ], + temperature=0.01, + max_tokens=200, + ) + + output = completion.choices[0].message.content + + if not output: + raise Exception("No output generated") + return output + except Exception as e: + logger.error(f"LLM error: {e}") + raise Exception(f"LLM error: {e}") diff --git a/test_output.json b/test_output.json new file mode 100644 index 0000000000000000000000000000000000000000..1f93008593dc770f1f001a47b819d652c14af179 --- /dev/null +++ b/test_output.json @@ -0,0 +1 @@ +[{"name": "test", "utterances": ["utterance"], "description": null}] diff --git a/test_output.txt b/test_output.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test_output.yaml b/test_output.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b71676477f7a48fff6174221c48d3c0595dbf14d --- /dev/null +++ b/test_output.yaml @@ -0,0 +1,4 @@ +- description: null + name: test + utterances: + - utterance diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py index c00f887d686f771580237dbb0be96508e1bacf61..f87cb1d281b2884a10a9817b9a838c21e64a9881 100644 --- 
a/tests/unit/test_hybrid_layer.py +++ b/tests/unit/test_hybrid_layer.py @@ -2,7 +2,7 @@ import pytest from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder from semantic_router.hybrid_layer import HybridRouteLayer -from semantic_router.schema import Route +from semantic_router.route import Route def mock_encoder_call(utterances): diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py index 66e0d53bb9350c77578682f9ea0742b1d3dfe0b2..21b489172b943ca09bb1bf81e31f88e92aa18a08 100644 --- a/tests/unit/test_layer.py +++ b/tests/unit/test_layer.py @@ -2,7 +2,7 @@ import pytest from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder from semantic_router.layer import RouteLayer -from semantic_router.schema import Route +from semantic_router.route import Route def mock_encoder_call(utterances): @@ -65,20 +65,20 @@ class TestRouteLayer: route1 = Route(name="Route 1", utterances=["Yes", "No"]) route2 = Route(name="Route 2", utterances=["Maybe", "Sure"]) - route_layer.add_route(route=route1) + route_layer.add(route=route1) assert route_layer.index is not None and route_layer.categories is not None assert len(route_layer.index) == 2 assert len(set(route_layer.categories)) == 1 assert set(route_layer.categories) == {"Route 1"} - route_layer.add_route(route=route2) + route_layer.add(route=route2) assert len(route_layer.index) == 4 assert len(set(route_layer.categories)) == 2 assert set(route_layer.categories) == {"Route 1", "Route 2"} def test_add_multiple_routes(self, openai_encoder, routes): route_layer = RouteLayer(encoder=openai_encoder) - route_layer.add_routes(routes=routes) + route_layer._add_routes(routes=routes) assert route_layer.index is not None and route_layer.categories is not None assert len(route_layer.index) == 5 assert len(set(route_layer.categories)) == 2 diff --git a/tests/unit/test_route.py b/tests/unit/test_route.py new file mode 100644 index 
0000000000000000000000000000000000000000..1de3f0e5174faf9e5bc7a7e66cf069a0170025fd --- /dev/null +++ b/tests/unit/test_route.py @@ -0,0 +1,222 @@ +import os +from unittest.mock import AsyncMock, mock_open, patch + +import pytest + +from semantic_router.route import Route, RouteConfig, is_valid + + +# Is valid test: +def test_is_valid_with_valid_json(): + valid_json = '{"name": "test_route", "utterances": ["hello", "hi"]}' + assert is_valid(valid_json) is True + + +def test_is_valid_with_missing_keys(): + invalid_json = '{"name": "test_route"}' # Missing 'utterances' + with patch("semantic_router.route.logger") as mock_logger: + assert is_valid(invalid_json) is False + mock_logger.warning.assert_called_once() + + +def test_is_valid_with_valid_json_list(): + valid_json_list = ( + '[{"name": "test_route1", "utterances": ["hello"]}, ' + '{"name": "test_route2", "utterances": ["hi"]}]' + ) + assert is_valid(valid_json_list) is True + + +def test_is_valid_with_invalid_json_list(): + invalid_json_list = ( + '[{"name": "test_route1"}, {"name": "test_route2", "utterances": ["hi"]}]' + ) + with patch("semantic_router.route.logger") as mock_logger: + assert is_valid(invalid_json_list) is False + mock_logger.warning.assert_called_once() + + +def test_is_valid_with_invalid_json(): + invalid_json = '{"name": "test_route", "utterances": ["hello", "hi" invalid json}' + with patch("semantic_router.route.logger") as mock_logger: + assert is_valid(invalid_json) is False + mock_logger.error.assert_called_once() + + +class TestRoute: + @pytest.mark.asyncio + @patch("semantic_router.route.llm", new_callable=AsyncMock) + async def test_generate_dynamic_route(self, mock_llm): + print(f"mock_llm: {mock_llm}") + mock_llm.return_value = """ + <config> + { + "name": "test_function", + "utterances": [ + "example_utterance_1", + "example_utterance_2", + "example_utterance_3", + "example_utterance_4", + "example_utterance_5"] + } + </config> + """ + function_schema = {"name": "test_function", 
"type": "function"} + route = await Route._generate_dynamic_route(function_schema) + assert route.name == "test_function" + assert route.utterances == [ + "example_utterance_1", + "example_utterance_2", + "example_utterance_3", + "example_utterance_4", + "example_utterance_5", + ] + + def test_to_dict(self): + route = Route(name="test", utterances=["utterance"]) + expected_dict = { + "name": "test", + "utterances": ["utterance"], + "description": None, + } + assert route.to_dict() == expected_dict + + def test_from_dict(self): + route_dict = {"name": "test", "utterances": ["utterance"]} + route = Route.from_dict(route_dict) + assert route.name == "test" + assert route.utterances == ["utterance"] + + @pytest.mark.asyncio + @patch("semantic_router.route.llm", new_callable=AsyncMock) + async def test_from_dynamic_route(self, mock_llm): + # Mock the llm function + mock_llm.return_value = """ + <config> + { + "name": "test_function", + "utterances": [ + "example_utterance_1", + "example_utterance_2", + "example_utterance_3", + "example_utterance_4", + "example_utterance_5"] + } + </config> + """ + + def test_function(input: str): + """Test function docstring""" + pass + + dynamic_route = await Route.from_dynamic_route(test_function) + + assert dynamic_route.name == "test_function" + assert dynamic_route.utterances == [ + "example_utterance_1", + "example_utterance_2", + "example_utterance_3", + "example_utterance_4", + "example_utterance_5", + ] + + def test_parse_route_config(self): + config = """ + <config> + { + "name": "test_function", + "utterances": [ + "example_utterance_1", + "example_utterance_2", + "example_utterance_3", + "example_utterance_4", + "example_utterance_5"] + } + </config> + """ + expected_config = """ + { + "name": "test_function", + "utterances": [ + "example_utterance_1", + "example_utterance_2", + "example_utterance_3", + "example_utterance_4", + "example_utterance_5"] + } + """ + assert Route._parse_route_config(config).strip() == 
expected_config.strip() + + +class TestRouteConfig: + def test_init(self): + route_config = RouteConfig() + assert route_config.routes == [] + + def test_to_file_json(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + with patch("builtins.open", mock_open()) as mocked_open: + route_config.to_file("data/test_output.json") + mocked_open.assert_called_once_with("data/test_output.json", "w") + + def test_to_file_yaml(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + with patch("builtins.open", mock_open()) as mocked_open: + route_config.to_file("data/test_output.yaml") + mocked_open.assert_called_once_with("data/test_output.yaml", "w") + + def test_to_file_invalid(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + with pytest.raises(ValueError): + route_config.to_file("test_output.txt") + + def test_from_file_json(self): + mock_json_data = '[{"name": "test", "utterances": ["utterance"]}]' + with patch("builtins.open", mock_open(read_data=mock_json_data)) as mocked_open: + route_config = RouteConfig.from_file("data/test.json") + mocked_open.assert_called_once_with("data/test.json", "r") + assert isinstance(route_config, RouteConfig) + + def test_from_file_yaml(self): + mock_yaml_data = "- name: test\n utterances:\n - utterance" + with patch("builtins.open", mock_open(read_data=mock_yaml_data)) as mocked_open: + route_config = RouteConfig.from_file("data/test.yaml") + mocked_open.assert_called_once_with("data/test.yaml", "r") + assert isinstance(route_config, RouteConfig) + + def test_from_file_invalid(self): + with open("test.txt", "w") as f: + f.write("dummy content") + with pytest.raises(ValueError): + RouteConfig.from_file("test.txt") + os.remove("test.txt") + + def test_to_dict(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + assert 
route_config.to_dict() == [route.to_dict()] + + def test_add(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig() + route_config.add(route) + assert route_config.routes == [route] + + def test_get(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + assert route_config.get("test") == route + + def test_get_not_found(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + assert route_config.get("not_found") is None + + def test_remove(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + route_config.remove("test") + assert route_config.routes == [] diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py index f471755c35796d33ac9329a2ddb3a20816230cda..27c73c9fc781850011d2f1732fe4c6958a5bcd3b 100644 --- a/tests/unit/test_schema.py +++ b/tests/unit/test_schema.py @@ -1,11 +1,11 @@ import pytest +from semantic_router.route import Route from semantic_router.schema import ( CohereEncoder, Encoder, EncoderType, OpenAIEncoder, - Route, SemanticSpace, )