From 922b2df5df8b9b7251141496770e4be8c9fc8fc6 Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Sun, 25 Feb 2024 14:14:52 +0400 Subject: [PATCH] add mps check --- docs/08-multi-modal.ipynb | 4 ++-- semantic_router/encoders/clip.py | 13 +++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/docs/08-multi-modal.ipynb b/docs/08-multi-modal.ipynb index ae5d40d2..e5e42399 100644 --- a/docs/08-multi-modal.ipynb +++ b/docs/08-multi-modal.ipynb @@ -18,7 +18,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The Semantic Router library can also be used for detection of specific images or videos, for example the detection of NSFW (no shrek for work) images as we will demonstrate in this walkthrough." + "The Semantic Router library can also be used for detection of specific images or videos, for example the detection of **N**ot **S**hrek **F**or **W**ork (NSFW) and **S**hrek **F**or **W**ork (SFW) images as we will demonstrate in this walkthrough." ] }, { @@ -42,7 +42,7 @@ "outputs": [], "source": [ "!pip install -qU \\\n", - " \"semantic-router[local]==0.0.21\" \\\n", + " \"semantic-router[local]==0.0.25\" \\\n", " datasets==2.17.0" ] }, diff --git a/semantic_router/encoders/clip.py b/semantic_router/encoders/clip.py index c7210707..4f4342f1 100644 --- a/semantic_router/encoders/clip.py +++ b/semantic_router/encoders/clip.py @@ -87,13 +87,14 @@ class CLIPEncoder(BaseEncoder): model = CLIPModel.from_pretrained(self.name, **self.model_kwargs) if self.device: - model.to(self.device) - + pass + elif self._torch.cuda.is_available(): + self.device = "cuda" + elif self._torch.backends.mps.is_available(): + self.device = "mps" else: - device = "cuda" if self._torch.cuda.is_available() else "cpu" - model.to(device) - self.device = device - + self.device = "cpu" + model.to(self.device) return tokenizer, processor, model def _encode_text(self, docs: List[str]) -> Any: -- GitLab