{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[Open in Colab](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/docs/08-multi-modal.ipynb) [Open in nbviewer](https://nbviewer.org/github/aurelio-labs/semantic-router/blob/main/docs/08-multi-modal.ipynb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Multi-Modal Routes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Semantic Router library can also be used to detect specific images or videos. In this walkthrough we demonstrate the detection of **N**ot **S**hrek **F**or **W**ork (NSFW) and **S**hrek **F**or **W**ork (SFW) images." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Getting Started" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We start by installing the library:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install -qU \\\n", "    \"semantic-router[local]==0.0.25\" \\\n", "    datasets==2.17.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we download a multi-modal dataset. We'll be using the `aurelio-ai/shrek-detection` dataset from Hugging Face."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "data = load_dataset(\"aurelio-ai/shrek-detection\", split=\"train\", trust_remote_code=True)\n", "data[3][\"image\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We split the images into two lists using the `is_shrek` label:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "shrek_pics = [d[\"image\"] for d in data if d[\"is_shrek\"]]\n", "not_shrek_pics = [d[\"image\"] for d in data if not d[\"is_shrek\"]]\n", "print(f\"We have {len(shrek_pics)} shrek pics, and {len(not_shrek_pics)} not shrek pics\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we define a `Route`. Instead of example phrases, we pass the Shrek images as the `utterances` that should trigger the route." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from semantic_router import Route\n", "\n", "shrek = Route(\n", "    name=\"shrek\",\n", "    utterances=shrek_pics,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We define a second route for the images that do not contain Shrek:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "not_shrek = Route(\n", "    name=\"not_shrek\",\n", "    utterances=not_shrek_pics,\n", ")\n", "\n", "routes = [shrek, not_shrek]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we initialize our embedding model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from semantic_router.encoders.clip import CLIPEncoder\n", "\n", "encoder = CLIPEncoder()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we define the `RouteLayer`. When called, the route layer consumes text (a query) and outputs the category (`Route`) it belongs to. To initialize a `RouteLayer` we need our `encoder` model and a list of `routes`."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from semantic_router.layer import RouteLayer\n", "\n", "rl = RouteLayer(encoder=encoder, routes=routes)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can test it with _text_ to see if we hit the routes that we defined with images:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "rl(\"don't you love politics?\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "rl(\"shrek\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "rl(\"dwayne the rock johnson\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Everything is being classified accurately. Let's pull in some images that we haven't seen before and see if we can classify them as NSFW or SFW." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_data = load_dataset(\n", "    \"aurelio-ai/shrek-detection\", split=\"test\", trust_remote_code=True\n", ")\n", "test_data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this case, we return `None` because no matches were identified." ] } ], "metadata": { "kernelspec": { "display_name": "decision-layer", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 2 }