Commit 922b2df5 authored by James Briggs

add mps check

parent 6be31413
%% Cell type:markdown id: tags:
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/docs/08-multi-modal.ipynb) [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/aurelio-labs/semantic-router/blob/main/docs/08-multi-modal.ipynb)
%% Cell type:markdown id: tags:
# Multi-Modal Routes
%% Cell type:markdown id: tags:
-The Semantic Router library can also be used for detection of specific images or videos, for example the detection of NSFW (no shrek for work) images as we will demonstrate in this walkthrough.
+The Semantic Router library can also be used for detection of specific images or videos, for example the detection of **N**ot **S**hrek **F**or **W**ork (NSFW) and **S**hrek **F**or **W**ork (SFW) images as we will demonstrate in this walkthrough.
%% Cell type:markdown id: tags:
## Getting Started
%% Cell type:markdown id: tags:
We start by installing the library:
%% Cell type:code id: tags:
``` python
!pip install -qU \
-"semantic-router[local]==0.0.21" \
+"semantic-router[local]==0.0.25" \
datasets==2.17.0
```
%% Cell type:markdown id: tags:
Next, we download a multi-modal dataset. We'll be using the `aurelio-ai/shrek-detection` dataset from Hugging Face.
%% Cell type:code id: tags:
``` python
from datasets import load_dataset
data = load_dataset("aurelio-ai/shrek-detection", split="train", trust_remote_code=True)
data[3]["image"]
```
%% Cell type:markdown id: tags:
We split the images into two lists according to their `is_shrek` label:
%% Cell type:code id: tags:
``` python
shrek_pics = [d["image"] for d in data if d["is_shrek"]]
not_shrek_pics = [d["image"] for d in data if not d["is_shrek"]]
print(f"We have {len(shrek_pics)} shrek pics, and {len(not_shrek_pics)} not shrek pics")
```
%% Cell type:markdown id: tags:
Next, we define a `Route` whose utterances are the example images that should trigger it.
%% Cell type:code id: tags:
``` python
from semantic_router import Route
shrek = Route(
name="shrek",
utterances=shrek_pics,
)
```
%% Cell type:markdown id: tags:
Let's define another for good measure:
%% Cell type:code id: tags:
``` python
not_shrek = Route(
name="not_shrek",
utterances=not_shrek_pics,
)
routes = [shrek, not_shrek]
```
%% Cell type:markdown id: tags:
Now we initialize our embedding model:
%% Cell type:code id: tags:
``` python
from semantic_router.encoders.clip import CLIPEncoder
encoder = CLIPEncoder()
```
%% Cell type:markdown id: tags:
Now we define the `RouteLayer`. When called, the route layer consumes a query and outputs the category (`Route`) it belongs to. To initialize a `RouteLayer` we need our `encoder` model and a list of `routes`.
%% Cell type:code id: tags:
``` python
from semantic_router.layer import RouteLayer
rl = RouteLayer(encoder=encoder, routes=routes)
```
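To make the routing step concrete, here is a minimal conceptual sketch of similarity-based routing. It is not the library's actual implementation. It assumes each route holds a few embedding vectors, scores a query by its best cosine similarity against each route, and returns `None` when no score clears a threshold. The names `cosine` and `route_query` and the `threshold` value of 0.8 are hypothetical, and the toy 2-D vectors stand in for real CLIP embeddings.

```python
import math
from typing import Dict, List, Optional, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def route_query(
    query: Sequence[float],
    routes: Dict[str, List[Sequence[float]]],
    threshold: float = 0.8,  # hypothetical cut-off, not a library default
) -> Optional[str]:
    """Return the best-matching route name, or None if nothing clears the threshold."""
    best_name, best_score = None, threshold
    for name, vectors in routes.items():
        # A route is scored by its single most similar example vector.
        score = max(cosine(query, v) for v in vectors)
        if score >= best_score:
            best_name, best_score = name, score
    return best_name


# Toy 2-D "embeddings" standing in for encoded images (illustrative values only).
toy_routes = {
    "shrek": [(1.0, 0.0), (0.9, 0.1)],
    "not_shrek": [(0.0, 1.0), (0.1, 0.9)],
}
```

In this sketch, a query vector close to one route's examples returns that route's name, while a query roughly equidistant from both routes falls below the threshold and returns `None`.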
%% Cell type:markdown id: tags:
Now we can test it with _text_ to see if we hit the routes that we defined with images:
%% Cell type:code id: tags:
``` python
rl("don't you love politics?")
```
%% Cell type:code id: tags:
``` python
rl("shrek")
```
%% Cell type:code id: tags:
``` python
rl("dwayne the rock johnson")
```
%% Cell type:markdown id: tags:
Everything is being classified accurately. Let's pull in some images that we haven't seen before and see if we can classify them as NSFW or SFW.
%% Cell type:code id: tags:
``` python
test_data = load_dataset(
"aurelio-ai/shrek-detection", split="test", trust_remote_code=True
)
test_data
```
%% Cell type:markdown id: tags:
In this case, we return `None` because no matches were identified.
......
@@ -87,13 +87,14 @@ class CLIPEncoder(BaseEncoder):
         model = CLIPModel.from_pretrained(self.name, **self.model_kwargs)
         if self.device:
-            model.to(self.device)
+            pass
+        elif self._torch.cuda.is_available():
+            self.device = "cuda"
+        elif self._torch.backends.mps.is_available():
+            self.device = "mps"
         else:
-            device = "cuda" if self._torch.cuda.is_available() else "cpu"
-            model.to(device)
-            self.device = device
+            self.device = "cpu"
+        model.to(self.device)
         return tokenizer, processor, model
     def _encode_text(self, docs: List[str]) -> Any:
......
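The device precedence in this hunk can be sketched in isolation. The helper below is a hypothetical, dependency-free illustration: `pick_device` is not part of the library, and in real use the two boolean arguments are assumed to come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`.

```python
from typing import Optional


def pick_device(
    requested: Optional[str],
    cuda_available: bool,
    mps_available: bool,
) -> str:
    """Choose a device: explicit request first, then CUDA, then MPS, then CPU."""
    if requested:
        # An explicitly configured device always wins.
        return requested
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"
```

With PyTorch installed, a call might look like `pick_device(None, torch.cuda.is_available(), torch.backends.mps.is_available())`. Passing the availability flags in as arguments keeps the precedence logic testable on machines without a GPU.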