Skip to content
Snippets Groups Projects
Commit f2f6e5fc authored by Kenny's avatar Kenny
Browse files

fastembed upd & test added

parent be312b04
No related branches found
No related tags found
No related merge requests found
......@@ -2,5 +2,12 @@ from semantic_router.encoders.base import BaseEncoder
from semantic_router.encoders.bm25 import BM25Encoder
from semantic_router.encoders.cohere import CohereEncoder
from semantic_router.encoders.openai import OpenAIEncoder
from semantic_router.encoders.fastembed import FastEmbedEncoder
__all__ = ["BaseEncoder", "CohereEncoder", "OpenAIEncoder", "BM25Encoder"]
__all__ = [
"BaseEncoder",
"CohereEncoder",
"OpenAIEncoder",
"BM25Encoder",
"FastEmbedEncoder",
]
from typing import List, Optional
from typing import Any, List, Optional
import numpy as np
from semantic_router.encoders.base import BaseEncoder
from pydantic import BaseModel, PrivateAttr
class FastEmbedEncoder(BaseEncoder):
class FastEmbedEncoder(BaseModel):
type: str = "fastembed"
model_name: str = "BAAI/bge-small-en-v1.5"
max_length: int = 512
cache_dir: Optional[str] = None
threads: Optional[int] = None
type: str = "fastembed"
_client: Any = PrivateAttr()
def __init__(self, **data):
super().__init__(**data)
self._client = self._initialize_client()
def init(self):
def _initialize_client(self):
try:
from fastembed.embedding import FlagEmbedding as Embedding
except ImportError:
......@@ -23,20 +27,19 @@ class FastEmbedEncoder(BaseEncoder):
embedding_args = {
"model_name": self.model_name,
"max_length": self.max_length,
"cache_dir": self.cache_dir,
"threads": self.threads,
}
if self.cache_dir is not None:
embedding_args["cache_dir"] = self.cache_dir
if self.threads is not None:
embedding_args["threads"] = self.threads
self.client = Embedding(**embedding_args)
embedding_args = {k: v for k, v in embedding_args.items() if v is not None}
embedding = Embedding(**embedding_args)
return embedding
def __call__(self, docs: list[str]) -> list[list[float]]:
try:
embeds: List[np.ndarray] = list(self.client.embed(docs))
embeds: List[np.ndarray] = list(self._client.embed(docs))
embeddings: List[List[float]] = [e.tolist() for e in embeds]
return embeddings
except Exception as e:
raise ValueError(f"FastEmbed embed failed. Error: {e}")
from semantic_router.encoders import FastEmbedEncoder
class TestFastEmbedEncoder:
def test_fastembed_encoder(self):
encode = FastEmbedEncoder()
test_docs = ["This is a test", "This is another test"]
embeddings = encode(test_docs)
assert isinstance(embeddings, list)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment