From 922b2df5df8b9b7251141496770e4be8c9fc8fc6 Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Sun, 25 Feb 2024 14:14:52 +0400
Subject: [PATCH] add mps check

---
 docs/08-multi-modal.ipynb        |  4 ++--
 semantic_router/encoders/clip.py | 13 +++++++------
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/docs/08-multi-modal.ipynb b/docs/08-multi-modal.ipynb
index ae5d40d2..e5e42399 100644
--- a/docs/08-multi-modal.ipynb
+++ b/docs/08-multi-modal.ipynb
@@ -18,7 +18,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "The Semantic Router library can also be used for detection of specific images or videos, for example the detection of NSFW (no shrek for work) images as we will demonstrate in this walkthrough."
+    "The Semantic Router library can also be used for detection of specific images or videos, for example the detection of **N**ot **S**hrek **F**or **W**ork (NSFW) and **S**hrek **F**or **W**ork (SFW) images as we will demonstrate in this walkthrough."
    ]
   },
   {
@@ -42,7 +42,7 @@
    "outputs": [],
    "source": [
     "!pip install -qU \\\n",
-    "    \"semantic-router[local]==0.0.21\" \\\n",
+    "    \"semantic-router[local]==0.0.25\" \\\n",
     "    datasets==2.17.0"
    ]
   },
diff --git a/semantic_router/encoders/clip.py b/semantic_router/encoders/clip.py
index c7210707..4f4342f1 100644
--- a/semantic_router/encoders/clip.py
+++ b/semantic_router/encoders/clip.py
@@ -87,13 +87,14 @@ class CLIPEncoder(BaseEncoder):
         model = CLIPModel.from_pretrained(self.name, **self.model_kwargs)
 
         if self.device:
-            model.to(self.device)
-
+            pass
+        elif self._torch.cuda.is_available():
+            self.device = "cuda"
+        elif self._torch.backends.mps.is_available():
+            self.device = "mps"
         else:
-            device = "cuda" if self._torch.cuda.is_available() else "cpu"
-            model.to(device)
-            self.device = device
-
+            self.device = "cpu"
+        model.to(self.device)
         return tokenizer, processor, model
 
     def _encode_text(self, docs: List[str]) -> Any:
-- 
GitLab