{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/docs/08-multi-modal.ipynb) [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/aurelio-labs/semantic-router/blob/main/docs/08-multi-modal.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multi-Modal Routes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Semantic Router library can also be used for detection of specific images or videos, for example the detection of **N**ot **S**hrek **F**or **W**ork (NSFW) and **S**hrek **F**or **W**ork (SFW) images as we will demonstrate in this walkthrough."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Getting Started"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We start by installing the library:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -qU \\\n",
    "    \"semantic-router[local]==0.0.25\" \\\n",
    "    datasets==2.17.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We start by downloading a multi-modal dataset, we'll be using the `aurelio-ai/shrek-detection` dataset from Hugging Face."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "data = load_dataset(\"aurelio-ai/shrek-detection\", split=\"train\", trust_remote_code=True)\n",
    "data[3][\"image\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will grab the images that are labeled with `is_shrek`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "shrek_pics = [d[\"image\"] for d in data if d[\"is_shrek\"]]\n",
    "not_shrek_pics = [d[\"image\"] for d in data if not d[\"is_shrek\"]]\n",
    "print(f\"We have {len(shrek_pics)} shrek pics, and {len(not_shrek_pics)} not shrek pics\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We start by defining a dictionary mapping routes to example phrases that should trigger those routes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from semantic_router import Route\n",
    "\n",
    "shrek = Route(\n",
    "    name=\"shrek\",\n",
    "    utterances=shrek_pics,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's define another for good measure:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "not_shrek = Route(\n",
    "    name=\"not_shrek\",\n",
    "    utterances=not_shrek_pics,\n",
    ")\n",
    "\n",
    "routes = [shrek, not_shrek]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we initialize our embedding model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from semantic_router.encoders.clip import CLIPEncoder\n",
    "\n",
    "encoder = CLIPEncoder()"
   ]
  },
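  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check (a minimal sketch, assuming the `CLIPEncoder` accepts PIL images as well as strings, as it will need to for our image utterances), we can embed a text string and one of our images and confirm that both land in the same vector space with the same dimensionality:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# embed one text string and one PIL image with the same encoder\n",
    "# (assumes the encoder handles both input types)\n",
    "text_emb = encoder([\"a swamp-dwelling ogre\"])\n",
    "image_emb = encoder([shrek_pics[0]])\n",
    "# both should be vectors of the same length in CLIP's shared space\n",
    "print(len(text_emb[0]), len(image_emb[0]))"
   ]
  },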
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we define the `RouteLayer`. When called, the route layer will consume text (a query) and output the category (`Route`) it belongs to — to initialize a `RouteLayer` we need our `encoder` model and a list of `routes`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from semantic_router.layer import RouteLayer\n",
    "\n",
    "rl = RouteLayer(encoder=encoder, routes=routes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can test it with _text_ to see if we hit the routes that we defined with images:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rl(\"don't you love politics?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rl(\"shrek\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rl(\"dwayne the rock johnson\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Everything is being classified accurately, let's pull in some images that we haven't seen before and see if we can classify them as NSFW or SFW."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_data = load_dataset(\n",
    "    \"aurelio-ai/shrek-detection\", split=\"test\", trust_remote_code=True\n",
    ")\n",
    "test_data"
   ]
  },
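  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can pass each of the unseen test images through the route layer, just as we did with text. This is a minimal sketch that assumes the route layer accepts PIL images directly, since our routes were defined with image utterances and the CLIP encoder embeds images as well as text:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# classify each unseen test image\n",
    "# (assumes rl() accepts PIL images directly, like the image utterances above)\n",
    "for d in test_data:\n",
    "    print(rl(d[\"image\"]))"
   ]
  },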
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this case, we return `None` because no matches were identified."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "decision-layer",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}