From f2f6e5fc78a76033b23f89c9120891d5fc784b86 Mon Sep 17 00:00:00 2001
From: Kenny <kenny@kdco.llc>
Date: Wed, 3 Jan 2024 20:51:34 -0500
Subject: [PATCH] Export FastEmbedEncoder, build its client in __init__, add unit test

---
 semantic_router/encoders/__init__.py  |  9 +++++++-
 semantic_router/encoders/fastembed.py | 31 +++++++++++++++------------
 test_output.txt                       |  0
 tests/unit/encoders/test_fastembed.py | 10 +++++++++
 4 files changed, 35 insertions(+), 15 deletions(-)
 create mode 100644 test_output.txt
 create mode 100644 tests/unit/encoders/test_fastembed.py

diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py
index ac27ebb4..4bb1eb37 100644
--- a/semantic_router/encoders/__init__.py
+++ b/semantic_router/encoders/__init__.py
@@ -2,5 +2,12 @@ from semantic_router.encoders.base import BaseEncoder
 from semantic_router.encoders.bm25 import BM25Encoder
 from semantic_router.encoders.cohere import CohereEncoder
 from semantic_router.encoders.openai import OpenAIEncoder
+from semantic_router.encoders.fastembed import FastEmbedEncoder
 
-__all__ = ["BaseEncoder", "CohereEncoder", "OpenAIEncoder", "BM25Encoder"]
+__all__ = [
+    "BaseEncoder",
+    "CohereEncoder",
+    "OpenAIEncoder",
+    "BM25Encoder",
+    "FastEmbedEncoder",
+]
diff --git a/semantic_router/encoders/fastembed.py b/semantic_router/encoders/fastembed.py
index 6b700ade..d324058d 100644
--- a/semantic_router/encoders/fastembed.py
+++ b/semantic_router/encoders/fastembed.py
@@ -1,17 +1,21 @@
-from typing import List, Optional
-
+from typing import Any, List, Optional
 import numpy as np
-from semantic_router.encoders.base import BaseEncoder
+from pydantic import BaseModel, PrivateAttr
 
 
-class FastEmbedEncoder(BaseEncoder):
+class FastEmbedEncoder(BaseModel):
+    type: str = "fastembed"
     model_name: str = "BAAI/bge-small-en-v1.5"
     max_length: int = 512
     cache_dir: Optional[str] = None
     threads: Optional[int] = None
-    type: str = "fastembed"
+    _client: Any = PrivateAttr()
+
+    def __init__(self, **data):  # **data: field values, forwarded to BaseModel
+        super().__init__(**data)
+        self._client = self._initialize_client()  # client is created at construction
 
-    def init(self):
+    def _initialize_client(self):
         try:
             from fastembed.embedding import FlagEmbedding as Embedding
         except ImportError:
@@ -23,20 +27,19 @@ class FastEmbedEncoder(BaseEncoder):
         embedding_args = {
             "model_name": self.model_name,
             "max_length": self.max_length,
+            "cache_dir": self.cache_dir,
+            "threads": self.threads,
         }
-        if self.cache_dir is not None:
-            embedding_args["cache_dir"] = self.cache_dir
-        if self.threads is not None:
-            embedding_args["threads"] = self.threads
 
-        self.client = Embedding(**embedding_args)
+        embedding_args = {k: v for k, v in embedding_args.items() if v is not None}
+
+        embedding = Embedding(**embedding_args)
+        return embedding
 
     def __call__(self, docs: list[str]) -> list[list[float]]:
+        """Embed docs, one vector per document; raises ValueError on failure."""
         try:
-            embeds: List[np.ndarray] = list(self.client.embed(docs))
-
-            embeddings: List[List[float]] = [e.tolist() for e in embeds]
-
-            return embeddings
+            embeds: List[np.ndarray] = list(self._client.embed(docs))
+            return [e.tolist() for e in embeds]
         except Exception as e:
-            raise ValueError(f"FastEmbed embed failed. Error: {e}")
+            raise ValueError(f"FastEmbed embed failed. Error: {e}") from e
diff --git a/test_output.txt b/test_output.txt
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/encoders/test_fastembed.py b/tests/unit/encoders/test_fastembed.py
new file mode 100644
index 00000000..5efdbfbe
--- /dev/null
+++ b/tests/unit/encoders/test_fastembed.py
@@ -0,0 +1,10 @@
+from semantic_router.encoders import FastEmbedEncoder
+
+
+class TestFastEmbedEncoder:
+    def test_fastembed_encoder(self):
+        """The encoder must return one non-empty float vector per input document."""
+        test_docs = ["This is a test", "This is another test"]
+        embeddings = FastEmbedEncoder()(test_docs)
+        assert isinstance(embeddings, list) and len(embeddings) == len(test_docs)
+        assert all(vec and isinstance(vec[0], float) for vec in embeddings)
-- 
GitLab