From 44e8d001181a4fd0ee94570af72d35a6f2c97387 Mon Sep 17 00:00:00 2001
From: Bogdan Buduroiu <bogdan@buduroiu.com>
Date: Mon, 19 Feb 2024 12:27:34 +0800
Subject: [PATCH] Updates Splitters and Video Splitting example

---
 docs/examples/video-splitter.ipynb           | 37 ++++++++++++++------
 semantic_router/encoders/vit.py              | 13 +++++--
 semantic_router/index/base.py                |  5 +--
 semantic_router/index/local.py               |  6 ++--
 semantic_router/index/pinecone.py            | 12 ++++---
 semantic_router/layer.py                     |  6 ++--
 semantic_router/schema.py                    |  4 +--
 semantic_router/splitters/base.py            |  4 +--
 semantic_router/splitters/consecutive_sim.py |  4 +--
 tests/unit/encoders/test_vit.py              |  3 +-
 tests/unit/test_layer.py                     |  2 +-
 11 files changed, 61 insertions(+), 35 deletions(-)

diff --git a/docs/examples/video-splitter.ipynb b/docs/examples/video-splitter.ipynb
index 967512b8..0f06a5e9 100644
--- a/docs/examples/video-splitter.ipynb
+++ b/docs/examples/video-splitter.ipynb
@@ -73,7 +73,7 @@
    "id": "e054e963-2f02-4f8e-9fb9-6ec8172adf67",
    "metadata": {},
    "source": [
-    "Now that we have the frames loaded, we can go ahead and use the `semantic_splitter` functionality to create splits based on frame similarity\n",
+    "Now that we have the frames loaded, we can go ahead and use the `Splitter` functionality to create splits based on frame similarity.\n",
     "\n",
     "First, lets initialise our ViT Encoder"
    ]
@@ -99,9 +99,19 @@
     "encoder = VitEncoder(device=\"mps\")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "cfe1da2d-71ef-4e78-a56c-d96d1f55607e",
+   "metadata": {},
+   "source": [
+    "Now let's initialise our Splitter.\n",
+    "\n",
+    "> Note: currently, we can only use `semantic_router.splitters.ConsecutiveSimSplitter` for image content"
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 5,
    "id": "12bd43fc-83a8-44bd-8394-4e31555fa05d",
    "metadata": {},
    "outputs": [
@@ -111,20 +121,23 @@
        "2"
       ]
      },
-     "execution_count": 4,
+     "execution_count": 5,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
-    "from semantic_router.utils.splitters import semantic_splitter\n",
-    "splits = semantic_splitter(encoder=encoder, docs=image_frames, threshold=0.5)\n",
+    "from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter\n",
+    "\n",
+    "splitter = ConsecutiveSimSplitter(encoder=encoder, score_threshold=0.5)\n",
+    "\n",
+    "splits = splitter(docs=image_frames)\n",
     "len(splits)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 6,
    "id": "8eda19e2-55dd-4091-bdcc-d2bde6f9cc5e",
    "metadata": {},
    "outputs": [
@@ -180,7 +193,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 7,
    "id": "0789d2d9-0e6c-45a6-99ef-67552364fb8f",
    "metadata": {},
    "outputs": [
@@ -190,7 +203,7 @@
        "1139"
       ]
      },
-     "execution_count": 6,
+     "execution_count": 7,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -225,17 +238,19 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 8,
    "id": "3e7a3b3e-1b52-4231-bf32-d3ff2dd0184c",
    "metadata": {},
    "outputs": [],
    "source": [
-    "splits = semantic_splitter(encoder=encoder, docs=image_frames, threshold=0.65)"
+    "splitter = ConsecutiveSimSplitter(encoder=encoder, score_threshold=0.65)\n",
+    "\n",
+    "splits = splitter(docs=image_frames)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 9,
    "id": "59a63e5a-7093-4ef5-857f-3afe2163bf82",
    "metadata": {},
    "outputs": [
diff --git a/semantic_router/encoders/vit.py b/semantic_router/encoders/vit.py
index b9c7be70..d2f28dda 100644
--- a/semantic_router/encoders/vit.py
+++ b/semantic_router/encoders/vit.py
@@ -1,8 +1,8 @@
 from typing import Any, List, Optional
 
-from pydantic.v1 import PrivateAttr
 from PIL import Image
 from PIL.Image import Image as _Image
+from pydantic.v1 import PrivateAttr
 
 from semantic_router.encoders import BaseEncoder
 
@@ -46,7 +46,9 @@ class VitEncoder(BaseEncoder):
         self._torch = torch
         self._T = T
 
-        processor = ViTImageProcessor.from_pretrained(self.name, **self.processor_kwargs)
+        processor = ViTImageProcessor.from_pretrained(
+            self.name, **self.processor_kwargs
+        )
 
         model = ViTModel.from_pretrained(self.name, **self.model_kwargs)
 
@@ -81,6 +83,11 @@ class VitEncoder(BaseEncoder):
             batch_imgs = imgs[i : i + batch_size]
             batch_imgs_transform = self._process_images(batch_imgs)
             with self._torch.no_grad():
-                embeddings = self._model(**batch_imgs_transform).last_hidden_state[:, 0].cpu().tolist()
+                embeddings = (
+                    self._model(**batch_imgs_transform)
+                    .last_hidden_state[:, 0]
+                    .cpu()
+                    .tolist()
+                )
             all_embeddings.extend(embeddings)
         return all_embeddings
diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py
index 77053ef0..351692df 100644
--- a/semantic_router/index/base.py
+++ b/semantic_router/index/base.py
@@ -1,6 +1,7 @@
-from pydantic.v1 import BaseModel
-from typing import Any, List, Tuple, Optional, Union
+from typing import Any, List, Optional, Tuple, Union
+
 import numpy as np
+from pydantic.v1 import BaseModel
 
 
 class BaseIndex(BaseModel):
diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py
index 058ee2bc..0ca67c87 100644
--- a/semantic_router/index/local.py
+++ b/semantic_router/index/local.py
@@ -1,7 +1,9 @@
+from typing import List, Optional, Tuple
+
 import numpy as np
-from typing import List, Tuple, Optional
-from semantic_router.linear import similarity_matrix, top_scores
+
 from semantic_router.index.base import BaseIndex
+from semantic_router.linear import similarity_matrix, top_scores
 
 
 class LocalIndex(BaseIndex):
diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index ae210a4f..35f7f879 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -1,12 +1,14 @@
-from pydantic.v1 import BaseModel, Field
-import requests
-import time
 import hashlib
 import os
-from typing import Any, Dict, List, Tuple, Optional, Union
+import time
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import requests
+from pydantic.v1 import BaseModel, Field
+
 from semantic_router.index.base import BaseIndex
 from semantic_router.utils.logger import logger
-import numpy as np
 
 
 def clean_route_name(route_name: str) -> str:
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 1e8a8130..640a68a7 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -1,3 +1,4 @@
+import importlib
 import json
 import os
 import random
@@ -6,15 +7,14 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 import yaml
 from tqdm.auto import tqdm
-import importlib
 
 from semantic_router.encoders import BaseEncoder, OpenAIEncoder
+from semantic_router.index.base import BaseIndex
+from semantic_router.index.local import LocalIndex
 from semantic_router.llms import BaseLLM, OpenAILLM
 from semantic_router.route import Route
 from semantic_router.schema import Encoder, EncoderType, RouteChoice
 from semantic_router.utils.logger import logger
-from semantic_router.index.base import BaseIndex
-from semantic_router.index.local import LocalIndex
 
 
 def is_valid(layer_config: str) -> bool:
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index 6a5a0637..8bb5a65f 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import List, Optional
+from typing import Any, List, Optional
 
 from pydantic.v1 import BaseModel
 from pydantic.v1.dataclasses import dataclass
@@ -77,6 +77,6 @@ class Message(BaseModel):
 
 
 class DocumentSplit(BaseModel):
-    docs: List[str]
+    docs: List[Any]
     is_triggered: bool = False
     triggered_score: Optional[float] = None
diff --git a/semantic_router/splitters/base.py b/semantic_router/splitters/base.py
index edeba73b..7f66c8bc 100644
--- a/semantic_router/splitters/base.py
+++ b/semantic_router/splitters/base.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Any, List
 
 from pydantic.v1 import BaseModel
 
@@ -10,5 +10,5 @@ class BaseSplitter(BaseModel):
     encoder: BaseEncoder
     score_threshold: float
 
-    def __call__(self, docs: List[str]) -> List[List[float]]:
+    def __call__(self, docs: List[Any]) -> List[List[float]]:
         raise NotImplementedError("Subclasses must implement this method")
diff --git a/semantic_router/splitters/consecutive_sim.py b/semantic_router/splitters/consecutive_sim.py
index 55a29a5c..f30bbc75 100644
--- a/semantic_router/splitters/consecutive_sim.py
+++ b/semantic_router/splitters/consecutive_sim.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Any, List
 
 import numpy as np
 
@@ -22,7 +22,7 @@ class ConsecutiveSimSplitter(BaseSplitter):
         super().__init__(name=name, score_threshold=score_threshold, encoder=encoder)
         encoder.score_threshold = score_threshold
 
-    def __call__(self, docs: List[str]):
+    def __call__(self, docs: List[Any]):
         # Check if there's only a single document
         if len(docs) == 1:
             raise ValueError(
diff --git a/tests/unit/encoders/test_vit.py b/tests/unit/encoders/test_vit.py
index cd7ced93..0435058c 100644
--- a/tests/unit/encoders/test_vit.py
+++ b/tests/unit/encoders/test_vit.py
@@ -1,6 +1,5 @@
-import pytest
-
 import numpy as np
+import pytest
 from PIL import Image
 
 from semantic_router.encoders import VitEncoder
diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py
index 88a4679a..00bad4ff 100644
--- a/tests/unit/test_layer.py
+++ b/tests/unit/test_layer.py
@@ -6,8 +6,8 @@ import pytest
 
 from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder
 from semantic_router.layer import LayerConfig, RouteLayer
-from semantic_router.route import Route
 from semantic_router.llms.base import BaseLLM
+from semantic_router.route import Route
 
 
 def mock_encoder_call(utterances):
-- 
GitLab