{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[Open in Colab](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/docs/06-threshold-optimization.ipynb) [Open in nbviewer](https://nbviewer.org/github/aurelio-labs/semantic-router/blob/main/docs/06-threshold-optimization)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Route Threshold Optimization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Route score thresholds define whether a route should be chosen. If the score we compute for a given route is higher than its `Route.score_threshold`, the route passes; otherwise it does not, and _either_ another route is chosen or we return _no_ route.\n", "\n", "Because this one `score_threshold` parameter can decide whether a route is chosen, it's important to get it right, but tuning it manually is very inefficient. Instead, we can use the `fit` and `evaluate` methods of our `RouteLayer`. All we must do is pass a small number of _(utterance, target route)_ examples to these methods. With `fit` we will often see dramatically improved performance within seconds, and with `evaluate` we can measure that performance gain." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install -qU \"semantic-router[local]==0.0.20\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define RouteLayer\n", "\n", "As usual we will define our `RouteLayer`. The `RouteLayer` requires just `routes` and an `encoder`. If using dynamic routes you must also define an `llm` (or use the OpenAI default).\n", "\n", "We will start by defining four routes: _politics_, _chitchat_, _mathematics_, and _biology_." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/jamesbriggs/opt/anaconda3/envs/decision-layer/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from semantic_router import Route\n", "\n", "# we could use this as a guide for our chatbot to avoid political conversations\n", "politics = Route(\n", " name=\"politics\",\n", " utterances=[\n", " \"isn't politics the best thing ever\",\n", " \"why don't you tell me about your political opinions\",\n", " \"don't you just love the president\",\n", " \"don't you just hate the president\",\n", " \"they're going to destroy this country!\",\n", " \"they will save the country!\",\n", " ],\n", ")\n", "\n", "# this could be used as an indicator to our chatbot to switch to a more\n", "# conversational prompt\n", "chitchat = Route(\n", " name=\"chitchat\",\n", " utterances=[\n", " \"Did you watch the game last night?\",\n", " \"what's your favorite type of music?\",\n", " \"Have you read any good books lately?\",\n", " \"nice weather we're having\",\n", " \"Do you have any plans for the weekend?\",\n", " ],\n", ")\n", "\n", "# we can use this to switch to an agent with more math tools, prompting, and LLMs\n", "mathematics = Route(\n", " name=\"mathematics\",\n", " utterances=[\n", " \"can you explain the concept of a derivative?\",\n", " \"What is the formula for the area of a triangle?\",\n", " \"how do you solve a system of linear equations?\",\n", " \"What is the concept of a prime number?\",\n", " \"Can you explain the Pythagorean theorem?\",\n", " ],\n", ")\n", "\n", "# we can use this to switch to an agent with more biology knowledge\n", "biology = Route(\n", " name=\"biology\",\n", " utterances=[\n", " \"what is the process of 
osmosis?\",\n", " \"can you explain the structure of a cell?\",\n", " \"What is the role of RNA?\",\n", " \"What is genetic mutation?\",\n", " \"Can you explain the process of photosynthesis?\",\n", " ],\n", ")\n", "\n", "# we place all of our routes together into a single list\n", "routes = [politics, chitchat, mathematics, biology]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For our encoder we will use the local `HuggingFaceEncoder`. Other popular encoders include `CohereEncoder`, `FastEmbedEncoder`, `OpenAIEncoder`, and `AzureOpenAIEncoder`." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/jamesbriggs/opt/anaconda3/envs/decision-layer/lib/python3.11/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", " return self.fget.__get__(instance, owner)()\n" ] } ], "source": [ "from semantic_router.encoders import HuggingFaceEncoder\n", "\n", "encoder = HuggingFaceEncoder()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we initialize our `RouteLayer`." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2024-01-28 11:08:05 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n" ] } ], "source": [ "from semantic_router.layer import RouteLayer\n", "\n", "rl = RouteLayer(encoder=encoder, routes=routes)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By default, we should get reasonable performance:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "don't you love politics? 
-> politics\n", "how's the weather today? -> chitchat\n", "What's DNA? -> biology\n", "I'm interested in learning about llama 2 -> None\n" ] } ], "source": [ "for utterance in [\n", " \"don't you love politics?\",\n", " \"how's the weather today?\",\n", " \"What's DNA?\",\n", " \"I'm interested in learning about llama 2\",\n", "]:\n", " print(f\"{utterance} -> {rl(utterance).name}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can evaluate the performance of our route layer using the `evaluate` method. All we need is to pass a list of utterances and target route labels:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 100.00%\n" ] } ], "source": [ "test_data = [\n", " (\"don't you love politics?\", \"politics\"),\n", " (\"how's the weather today?\", \"chitchat\"),\n", " (\"What's DNA?\", \"biology\"),\n", " (\"I'm interested in learning about llama 2\", None),\n", "]\n", "\n", "# unpack the test data\n", "X, y = zip(*test_data)\n", "\n", "# evaluate using the default thresholds\n", "accuracy = rl.evaluate(X=X, y=y)\n", "print(f\"Accuracy: {accuracy*100:.2f}%\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "On this small subset we get perfect accuracy, but what if we try a larger, more robust dataset?\n", "\n", "_Hint: try using GPT-4 or another LLM to generate some examples for your own use-cases. 
The more accurate examples you provide, the better you can expect the routes to perform on your actual use-case._" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "test_data = [\n", " # politics\n", " (\"What's your opinion on the current government?\", \"politics\"),\n", " (\"Who do you think will win the next election?\", \"politics\"),\n", " (\"What are your thoughts on the new policy?\", \"politics\"),\n", " (\"How do you feel about the political situation?\", \"politics\"),\n", " (\"Do you agree with the president's actions?\", \"politics\"),\n", " (\"What's your stance on the political debate?\", \"politics\"),\n", " (\"How do you see the future of our country?\", \"politics\"),\n", " (\"What do you think about the opposition party?\", \"politics\"),\n", " (\"Do you believe the government is doing enough?\", \"politics\"),\n", " (\"What's your opinion on the political scandal?\", \"politics\"),\n", " (\"Do you think the new law will make a difference?\", \"politics\"),\n", " (\"What are your thoughts on the political reform?\", \"politics\"),\n", " (\"Do you agree with the government's foreign policy?\", \"politics\"),\n", " # chitchat\n", " (\"What's the weather like?\", \"chitchat\"),\n", " (\"It's a beautiful day today.\", \"chitchat\"),\n", " (\"How's your day going?\", \"chitchat\"),\n", " (\"It's raining cats and dogs.\", \"chitchat\"),\n", " (\"Let's grab a coffee.\", \"chitchat\"),\n", " (\"What's up?\", \"chitchat\"),\n", " (\"It's a bit chilly today.\", \"chitchat\"),\n", " (\"How's it going?\", \"chitchat\"),\n", " (\"Nice weather we're having.\", \"chitchat\"),\n", " (\"It's a bit windy today.\", \"chitchat\"),\n", " (\"Let's go for a walk.\", \"chitchat\"),\n", " (\"How's your week been?\", \"chitchat\"),\n", " (\"It's quite sunny today.\", \"chitchat\"),\n", " (\"How are you feeling?\", \"chitchat\"),\n", " (\"It's a bit cloudy today.\", \"chitchat\"),\n", " # mathematics\n", " (\"What is the 
Pythagorean theorem?\", \"mathematics\"),\n", " (\"Can you solve this quadratic equation?\", \"mathematics\"),\n", " (\"What is the derivative of x squared?\", \"mathematics\"),\n", " (\"Explain the concept of integration.\", \"mathematics\"),\n", " (\"What is the area of a circle?\", \"mathematics\"),\n", " (\"How do you calculate the volume of a sphere?\", \"mathematics\"),\n", " (\"What is the difference between a vector and a scalar?\", \"mathematics\"),\n", " (\"Explain the concept of a matrix.\", \"mathematics\"),\n", " (\"What is the Fibonacci sequence?\", \"mathematics\"),\n", " (\"How do you calculate permutations?\", \"mathematics\"),\n", " (\"What is the concept of probability?\", \"mathematics\"),\n", " (\"Explain the binomial theorem.\", \"mathematics\"),\n", " (\"What is the difference between discrete and continuous data?\", \"mathematics\"),\n", " (\"What is a complex number?\", \"mathematics\"),\n", " (\"Explain the concept of limits.\", \"mathematics\"),\n", " # biology\n", " (\"What is photosynthesis?\", \"biology\"),\n", " (\"Explain the process of cell division.\", \"biology\"),\n", " (\"What is the function of mitochondria?\", \"biology\"),\n", " (\"What is DNA?\", \"biology\"),\n", " (\"What is the difference between prokaryotic and eukaryotic cells?\", \"biology\"),\n", " (\"What is an ecosystem?\", \"biology\"),\n", " (\"Explain the theory of evolution.\", \"biology\"),\n", " (\"What is a species?\", \"biology\"),\n", " (\"What is the role of enzymes?\", \"biology\"),\n", " (\"What is the circulatory system?\", \"biology\"),\n", " (\"Explain the process of respiration.\", \"biology\"),\n", " (\"What is a gene?\", \"biology\"),\n", " (\"What is the function of the nervous system?\", \"biology\"),\n", " (\"What is homeostasis?\", \"biology\"),\n", " (\"What is the difference between a virus and a bacteria?\", \"biology\"),\n", " (\"What is the role of the immune system?\", \"biology\"),\n", " # add some None routes to prevent excessively 
small thresholds\n", " (\"What is the capital of France?\", None),\n", " (\"how many people live in the US?\", None),\n", " (\"when is the best time to visit Bali?\", None),\n", " (\"how do I learn a language\", None),\n", " (\"tell me an interesting fact\", None),\n", " (\"what is the best programming language?\", None),\n", " (\"I'm interested in learning about llama 2\", None),\n", "]" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 34.85%\n" ] } ], "source": [ "# unpack the test data\n", "X, y = zip(*test_data)\n", "\n", "# evaluate using the default thresholds\n", "accuracy = rl.evaluate(X=X, y=y)\n", "print(f\"Accuracy: {accuracy*100:.2f}%\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Ouch, that's not so good! Fortunately, we can easily improve our performance here." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Route Layer Optimization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our optimization works by finding the best route thresholds for each `Route` in our `RouteLayer`. We can see the current, default thresholds by calling the `get_thresholds` method:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Default route thresholds: {'politics': 0.5, 'chitchat': 0.5, 'mathematics': 0.5, 'biology': 0.5}\n" ] } ], "source": [ "route_thresholds = rl.get_thresholds()\n", "print(\"Default route thresholds:\", route_thresholds)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These are all preset route threshold values. 
Fortunately, it's very easy to optimize these — we simply call the `fit` method and provide our training utterances `X`, and target route labels `y`:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", "To disable this warning, you can either:\n", "\t- Avoid using `tokenizers` before the fork if possible\n", "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", "100%|██████████| 500/500 [00:01<00:00, 379.54it/s, acc=0.91]\n" ] } ], "source": [ "# Call the fit method\n", "rl.fit(X=X, y=y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see what our new thresholds look like:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Updated route thresholds: {'politics': 0.20202020202020204, 'chitchat': 0.5, 'mathematics': 0.14141414141414144, 'biology': 0.1806754412815019}\n" ] } ], "source": [ "route_thresholds = rl.get_thresholds()\n", "print(\"Updated route thresholds:\", route_thresholds)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These are vastly different thresholds to what we were seeing before — it's worth noting that _optimal_ values for different encoders can vary greatly. 
For example, OpenAI's Ada 002 model, when used with our encoders, tends to produce much larger values, in the `0.5` to `0.8` range.\n", "\n", "After training we have a final performance of:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 90.91%\n" ] } ], "source": [ "accuracy = rl.evaluate(X=X, y=y)\n", "print(f\"Accuracy: {accuracy*100:.2f}%\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That is _much_ better. Note that we measured this on the same examples we passed to `fit`, so the result may be optimistic; a held-out set would give a fairer estimate. If we want to optimize further, we can focus on adding more utterances to our existing routes, analyzing _where_ exactly our failures are, and modifying our routes around those. This extended optimization process is much more manual, but with it we can continue optimizing routes to get even better performance." ] } ], "metadata": { "kernelspec": { "display_name": "semantic_router_1", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 2 }