From 30a50b4bea30cc5d3252997994a64ea24bb80482 Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Sun, 12 Nov 2023 07:10:15 -0800
Subject: [PATCH] restructure and test different decision layer structures

---
 semantic_router/encoders/__init__.py          |   5 -
 semantic_router/encoders/huggingface.py       |   9 -
 semantic_router/layer.py                      | 246 +++++++++++++++++-
 semantic_router/matchers/__init__.py          |   0
 semantic_router/matchers/base.py              |  18 ++
 semantic_router/matchers/ranker_only.py       |   1 +
 semantic_router/matchers/two_stage.py         |  59 +++++
 semantic_router/rankers/__init__.py           |   0
 semantic_router/rankers/base.py               |  12 +
 semantic_router/rankers/cohere.py             |  31 +++
 semantic_router/retrievers/__init__.py        |   5 +
 .../{encoders => retrievers}/base.py          |   4 +-
 semantic_router/retrievers/bm25.py            |  21 ++
 .../{encoders => retrievers}/cohere.py        |  10 +-
 semantic_router/retrievers/huggingface.py     |   9 +
 .../{encoders => retrievers}/openai.py        |   8 +-
 semantic_router/schema.py                     |  32 +--
 17 files changed, 420 insertions(+), 50 deletions(-)
 delete mode 100644 semantic_router/encoders/__init__.py
 delete mode 100644 semantic_router/encoders/huggingface.py
 create mode 100644 semantic_router/matchers/__init__.py
 create mode 100644 semantic_router/matchers/base.py
 create mode 100644 semantic_router/matchers/ranker_only.py
 create mode 100644 semantic_router/matchers/two_stage.py
 create mode 100644 semantic_router/rankers/__init__.py
 create mode 100644 semantic_router/rankers/base.py
 create mode 100644 semantic_router/rankers/cohere.py
 create mode 100644 semantic_router/retrievers/__init__.py
 rename semantic_router/{encoders => retrievers}/base.py (67%)
 create mode 100644 semantic_router/retrievers/bm25.py
 rename semantic_router/{encoders => retrievers}/cohere.py (71%)
 create mode 100644 semantic_router/retrievers/huggingface.py
 rename semantic_router/{encoders => retrievers}/openai.py (81%)
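
A minimal usage sketch of the restructured API introduced below, assuming the class and
constructor names from this patch (CohereRetriever, DecisionLayer, Decision); the model
name and utterances are illustrative and COHERE_API_KEY is assumed to be set:

    from semantic_router.layer import DecisionLayer
    from semantic_router.retrievers import CohereRetriever
    from semantic_router.schema import Decision

    # dense retriever used to embed the decision utterances
    retriever = CohereRetriever(name="embed-english-v3.0")
    decisions = [
        Decision(name="weather", utterances=["how is the weather", "is it raining"]),
        Decision(name="smalltalk", utterances=["how are you", "tell me a joke"]),
    ]
    layer = DecisionLayer(encoder=retriever, decisions=decisions)
    print(layer("will it rain today"))  # -> "weather", or None if below the threshold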

diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py
deleted file mode 100644
index 3fc94815..00000000
--- a/semantic_router/encoders/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .base import BaseEncoder
-from .cohere import CohereEncoder
-from .openai import OpenAIEncoder
-
-__all__ = ["BaseEncoder", "CohereEncoder", "OpenAIEncoder"]
diff --git a/semantic_router/encoders/huggingface.py b/semantic_router/encoders/huggingface.py
deleted file mode 100644
index 258c5037..00000000
--- a/semantic_router/encoders/huggingface.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from semantic_router.encoders import BaseEncoder
-
-
-class HuggingFaceEncoder(BaseEncoder):
-    def __init__(self, name: str):
-        self.name = name
-
-    def __call__(self, texts: list[str]) -> list[float]:
-        raise NotImplementedError
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 089f2793..12b6e80a 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -1,24 +1,124 @@
 import numpy as np
 from numpy.linalg import norm
 
-from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder
+from semantic_router.retrievers import (
+    BaseRetriever,
+    CohereRetriever,
+    OpenAIRetriever,
+    BM25Retriever
+)
+from semantic_router.rankers import BaseRanker
+from semantic_router.matchers import BaseMatcher
 from semantic_router.schema import Decision
 
 
+class MatcherDecisionLayer:
+    index = None
+    decision_arr = None
+    score_threshold: float = 0.82  # default threshold, matching the other layers
+
+    def __init__(self, matcher: BaseMatcher, decisions: list[Decision] = []):
+        self.matcher = matcher
+        # if a decisions list has been passed and we have a retriever,
+        # we initialize the index now
+        if matcher.retriever and decisions:
+            for decision in decisions:
+                self._add_decision(decision=decision)
+
+    def __call__(self, text: str) -> str | None:
+        raise NotImplementedError
+
+
+class RankDecisionLayer:
+    categories = None
+    utterances = None
+    # default score threshold, matching DecisionLayer
+    score_threshold = 0.82
+
+    def __init__(self, ranker: BaseRanker, decisions: list[Decision] = []):
+        self.ranker = ranker
+        # if decisions list has been passed, we initialize decision array
+        if decisions:
+            for decision in decisions:
+                self._add_decision(decision=decision)
+
+    def __call__(self, text: str) -> str | None:
+        results = self._query(text)
+        top_class, top_class_scores = self._semantic_classify(results)
+        passed = self._pass_threshold(top_class_scores, self.score_threshold)
+        if passed:
+            return top_class
+        else:
+            return None
+
+    def add(self, decision: Decision):
+        self._add_decision(decision=decision)
+
+    def _add_decision(self, decision: Decision):
+        # create decision categories array
+        if self.categories is None:
+            self.categories = np.array([decision.name] * len(decision.utterances))
+            self.utterances = np.array(decision.utterances)
+        else:
+            str_arr = np.array([decision.name] * len(decision.utterances))
+            self.categories = np.concatenate([self.categories, str_arr])
+            self.utterances = np.concatenate([
+                self.utterances,
+                np.array(decision.utterances)
+            ])
+
+    def _query(self, text: str, top_k: int = 5):
+        """Given some text, reranks the stored utterances and returns the
+        top_k most relevant records.
+        """
+        if self.categories is not None:
+            self.ranker.top_n = top_k
+            idx, docs = self.ranker(query=text, docs=list(self.utterances))
+            # create scores based on rank (1.0, 0.5, 0.33, ...)
+            scores = [1 / (i + 1) for i in range(len(docs))]
+            # get the utterance categories (decision names)
+            decisions = self.categories[idx]
+            return [
+                {"decision": d, "score": s} for d, s in zip(decisions, scores)
+            ]
+        else:
+            return []
+
+    def _semantic_classify(self, query_results: list[dict]) -> tuple[str, list[float]]:
+        scores_by_class = {}
+        for result in query_results:
+            score = result["score"]
+            decision = result["decision"]
+            if decision in scores_by_class:
+                scores_by_class[decision].append(score)
+            else:
+                scores_by_class[decision] = [score]
+
+        # Calculate total score for each class
+        total_scores = {
+            decision: sum(scores) for decision, scores in scores_by_class.items()
+        }
+        top_class = max(total_scores, key=lambda x: total_scores[x], default=None)
+
+        # Return the top class and its associated scores
+        return str(top_class), scores_by_class.get(top_class, [])
+
+    def _pass_threshold(self, scores: list[float], threshold: float) -> bool:
+        if scores:
+            return max(scores) > threshold
+        else:
+            return False
+
+
 class DecisionLayer:
     index = None
     categories = None
-    similarity_threshold = 0.82
+    score_threshold = 0.82
 
-    def __init__(self, encoder: BaseEncoder, decisions: list[Decision] = []):
+    def __init__(self, encoder: BaseRetriever, decisions: list[Decision] = []):
         self.encoder = encoder
         # decide on default threshold based on encoder
-        if isinstance(encoder, OpenAIEncoder):
-            self.similarity_threshold = 0.82
-        elif isinstance(encoder, CohereEncoder):
-            self.similarity_threshold = 0.3
+        if isinstance(encoder, OpenAIRetriever):
+            self.score_threshold = 0.82
+        elif isinstance(encoder, CohereRetriever):
+            self.score_threshold = 0.3
         else:
-            self.similarity_threshold = 0.82
+            self.score_threshold = 0.82
         # if decisions list has been passed, we initialize index now
         if decisions:
             # initialize index now
@@ -28,7 +128,7 @@ class DecisionLayer:
     def __call__(self, text: str) -> str | None:
         results = self._query(text)
         top_class, top_class_scores = self._semantic_classify(results)
-        passed = self._pass_threshold(top_class_scores, self.similarity_threshold)
+        passed = self._pass_threshold(top_class_scores, self.score_threshold)
         if passed:
             return top_class
         else:
@@ -102,3 +202,131 @@ class DecisionLayer:
             return max(scores) > threshold
         else:
             return False
+
+
+class HybridDecisionLayer:
+    index = None
+    categories = None
+    score_threshold = 0.82
+
+    def __init__(
+        self,
+        encoder: BaseRetriever,
+        decisions: list[Decision] = [],
+        alpha: float = 0.3
+    ):
+        self.encoder = encoder
+        self.sparse_encoder = BM25Retriever()
+        self.alpha = alpha
+        # decide on default threshold based on encoder
+        if isinstance(encoder, OpenAIRetriever):
+            self.score_threshold = 0.82
+        elif isinstance(encoder, CohereRetriever):
+            self.score_threshold = 0.3
+        else:
+            self.score_threshold = 0.82
+        # if decisions list has been passed, we initialize index now
+        if decisions:
+            # initialize index now
+            for decision in decisions:
+                self._add_decision(decision=decision)
+
+    def __call__(self, text: str) -> str | None:
+        results = self._query(text)
+        top_class, top_class_scores = self._semantic_classify(results)
+        passed = self._pass_threshold(top_class_scores, self.score_threshold)
+        if passed:
+            return top_class
+        else:
+            return None
+
+    def add(self, decision: Decision):
+        self._add_decision(decision=decision)
+
+    def _add_decision(self, decision: Decision):
+        # create embeddings
+        dense_embeds = self.encoder(decision.utterances)
+        sparse_embeds = self.sparse_encoder(decision.utterances)
+        # concatenate vectors to create hybrid vecs
+        embeds = np.concatenate([
+            dense_embeds, sparse_embeds
+        ], axis=1)
+
+        # create decision array
+        if self.categories is None:
+            self.categories = np.array([decision.name] * len(embeds))
+            self.utterances = np.array(decision.utterances)
+        else:
+            str_arr = np.array([decision.name] * len(embeds))
+            self.categories = np.concatenate([self.categories, str_arr])
+            self.utterances = np.concatenate([
+                self.utterances,
+                np.array(decision.utterances)
+            ])
+        # create utterance array (the index)
+        if self.index is None:
+            self.index = np.array(embeds)
+        else:
+            embed_arr = np.array(embeds)
+            self.index = np.concatenate([self.index, embed_arr])
+
+    def _query(self, text: str, top_k: int = 5):
+        """Given some text, encodes and searches the index vector space to
+        retrieve the top_k most similar records.
+        """
+        # create dense query vector
+        xq_d = np.array(self.encoder([text]))
+        xq_d = np.squeeze(xq_d)  # Reduce to 1d array.
+        # create sparse query vector
+        xq_s = np.array(self.sparse_encoder([text]))
+        xq_s = np.squeeze(xq_s)
+        # convex scaling
+        xq_d, xq_s = self._convex_scaling(xq_d, xq_s)
+        # concatenate to create single hybrid vec (both are 1d after squeezing)
+        xq = np.concatenate([xq_d, xq_s])
+
+        if self.index is not None:
+            index_norm = norm(self.index, axis=1)
+            xq_norm = norm(xq)
+            sim = np.dot(self.index, xq) / (index_norm * xq_norm)
+            # get indices of top_k records
+            top_k = min(top_k, sim.shape[0])
+            idx = np.argpartition(sim, -top_k)[-top_k:]
+            scores = sim[idx]
+            # get the utterance categories (decision names)
+            decisions = self.categories[idx] if self.categories is not None else []
+            return [
+                {"decision": d, "score": s.item()} for d, s in zip(decisions, scores)
+            ]
+        else:
+            return []
+
+    def _convex_scaling(self, dense: np.ndarray, sparse: np.ndarray):
+        # scale dense and sparse vecs by alpha and (1 - alpha)
+        dense = dense * self.alpha
+        sparse = sparse * (1 - self.alpha)
+        return dense, sparse
+
+    def _semantic_classify(self, query_results: list[dict]) -> tuple[str, list[float]]:
+        scores_by_class = {}
+        for result in query_results:
+            score = result["score"]
+            decision = result["decision"]
+            if decision in scores_by_class:
+                scores_by_class[decision].append(score)
+            else:
+                scores_by_class[decision] = [score]
+
+        # Calculate total score for each class
+        total_scores = {
+            decision: sum(scores) for decision, scores in scores_by_class.items()
+        }
+        top_class = max(total_scores, key=lambda x: total_scores[x], default=None)
+
+        # Return the top class and its associated scores
+        return str(top_class), scores_by_class.get(top_class, [])
+
+    def _pass_threshold(self, scores: list[float], threshold: float) -> bool:
+        if scores:
+            return max(scores) > threshold
+        else:
+            return False
\ No newline at end of file
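
For reference, a standalone sketch of the hybrid scoring performed in
HybridDecisionLayer._query above: the dense and sparse query vectors are convex-scaled by
alpha and (1 - alpha), concatenated, and compared against the index with cosine
similarity. Dimensions and values below are illustrative:

    import numpy as np
    from numpy.linalg import norm

    alpha = 0.3
    xq_d = np.random.rand(256)         # dense query vector
    xq_s = np.random.rand(100)         # sparse query vector (dense representation)
    index = np.random.rand(10, 356)    # 10 stored hybrid vectors (256 + 100 dims)

    xq = np.concatenate([xq_d * alpha, xq_s * (1 - alpha)])
    sim = np.dot(index, xq) / (norm(index, axis=1) * norm(xq))
    top_k = 5
    idx = np.argpartition(sim, -top_k)[-top_k:]   # indices of the top_k records
    print(idx, sim[idx])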
diff --git a/semantic_router/matchers/__init__.py b/semantic_router/matchers/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/semantic_router/matchers/__init__.py
@@ -0,0 +1,4 @@
+from .base import BaseMatcher
+from .two_stage import TwoStageMatcher
+
+__all__ = ["BaseMatcher", "TwoStageMatcher"]
diff --git a/semantic_router/matchers/base.py b/semantic_router/matchers/base.py
new file mode 100644
index 00000000..fc42cbe8
--- /dev/null
+++ b/semantic_router/matchers/base.py
@@ -0,0 +1,18 @@
+from pydantic import BaseModel
+
+from semantic_router.retrievers import BaseRetriever
+from semantic_router.rankers import BaseRanker
+from semantic_router.schema import Decision
+
+
+class BaseMatcher(BaseModel):
+    retriever: BaseRetriever | None
+    ranker: BaseRanker | None
+    top_k: int | None
+    top_n: int | None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def __call__(self, query: str, decisions: list[Decision]) -> str:
+        raise NotImplementedError("Subclasses must implement this method")
\ No newline at end of file
diff --git a/semantic_router/matchers/ranker_only.py b/semantic_router/matchers/ranker_only.py
new file mode 100644
index 00000000..08b7fe2e
--- /dev/null
+++ b/semantic_router/matchers/ranker_only.py
@@ -0,0 +1 @@
+from semantic_router import rankers
\ No newline at end of file
diff --git a/semantic_router/matchers/two_stage.py b/semantic_router/matchers/two_stage.py
new file mode 100644
index 00000000..6b570030
--- /dev/null
+++ b/semantic_router/matchers/two_stage.py
@@ -0,0 +1,59 @@
+import numpy as np
+
+from semantic_router.rankers import (
+    BaseRanker,
+    CohereRanker
+)
+from semantic_router.retrievers import (
+    BaseRetriever,
+    CohereRetriever
+)
+from semantic_router.matchers import BaseMatcher
+from semantic_router.schema import Decision
+
+
+class TwoStageMatcher(BaseMatcher):
+    categories: np.ndarray | None = None
+    index: np.ndarray | None = None
+
+    def __init__(
+        self,
+        retriever: BaseRetriever | None,
+        ranker: BaseRanker | None,
+        top_k: int = 25,
+        top_n: int = 5
+    ):
+        # fall back to a default retriever/ranker before passing them to the
+        # model init, so the resolved objects are actually stored
+        if retriever is None:
+            retriever = CohereRetriever(name="embed-english-v3.0")
+        if ranker is None:
+            ranker = CohereRanker(name="rerank-english-v2.0", top_n=top_n)
+        super().__init__(
+            retriever=retriever, ranker=ranker, top_k=top_k, top_n=top_n
+        )
+
+    def __call__(self, query: str, decisions: list[Decision]) -> str:
+        pass
+
+    def add(self, decision: Decision):
+        self._add_decision(decision=decision)
+
+    def _add_decision(self, decision: Decision):
+        # create embeddings for first stage
+        embeds = self.retriever(decision.utterances)
+        # create a decision array for decision categories
+        if self.categories is None:
+            self.categories = np.array([decision.name] * len(embeds))
+        else:
+            str_arr = np.array([decision.name] * len(embeds))
+            self.categories = np.concatenate([self.categories, str_arr])
+        # create utterance array (the index)
+        if self.index is None:
+            self.index = np.array(embeds)
+        else:
+            embed_arr = np.array(embeds)
+            self.index = np.concatenate([self.index, embed_arr])
\ No newline at end of file
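
TwoStageMatcher.__call__ is left as a stub above; the following is a hedged sketch of the
intended retrieve-then-rerank flow, reusing the module's numpy import and assuming the
matcher also stores the raw utterances alongside self.index and self.categories (it
currently does not), with a plain dot product standing in for a proper similarity search:

    def __call__(self, query: str, decisions: list[Decision]) -> str:
        # stage 1: embed the query and take the top_k nearest stored utterances
        xq = np.array(self.retriever([query]))            # shape (1, dim)
        sim = np.dot(self.index, xq.T).flatten()
        top_k = min(self.top_k, sim.shape[0])
        top_k_idx = np.argpartition(sim, -top_k)[-top_k:]
        candidates = [self.utterances[i] for i in top_k_idx]
        # stage 2: rerank the candidates and keep the best one
        rank_idx, _ = self.ranker(query=query, docs=candidates)
        return str(self.categories[top_k_idx[rank_idx[0]]])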
diff --git a/semantic_router/rankers/__init__.py b/semantic_router/rankers/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/semantic_router/rankers/__init__.py
@@ -0,0 +1,4 @@
+from .base import BaseRanker
+from .cohere import CohereRanker
+
+__all__ = ["BaseRanker", "CohereRanker"]
diff --git a/semantic_router/rankers/base.py b/semantic_router/rankers/base.py
new file mode 100644
index 00000000..5d326f33
--- /dev/null
+++ b/semantic_router/rankers/base.py
@@ -0,0 +1,12 @@
+from pydantic import BaseModel
+
+
+class BaseRanker(BaseModel):
+    name: str
+    top_n: int = 5
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def __call__(self, query: str, docs: list[str]) -> tuple[list[int], list[str]]:
+        raise NotImplementedError("Subclasses must implement this method")
diff --git a/semantic_router/rankers/cohere.py b/semantic_router/rankers/cohere.py
new file mode 100644
index 00000000..b703a960
--- /dev/null
+++ b/semantic_router/rankers/cohere.py
@@ -0,0 +1,31 @@
+import os
+
+import cohere
+
+from semantic_router.rankers import BaseRanker
+
+
+class CohereRanker(BaseRanker):
+    client: cohere.Client | None
+
+    def __init__(
+        self, name: str = "rerank-english-v2.0",
+        top_n: int = 5,
+        cohere_api_key: str | None = None
+    ):
+        super().__init__(name=name, top_n=top_n)
+        cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY")
+        if cohere_api_key is None:
+            raise ValueError("Cohere API key cannot be 'None'.")
+        self.client = cohere.Client(cohere_api_key)
+
+    def __call__(self, query: str, docs: list[str]) -> tuple[list[int], list[str]]:
+        # get top_n results
+        results = self.client.rerank(
+            query=query, documents=docs, top_n=self.top_n,
+            model=self.name
+        )
+        # get indices of entries that are ranked highest by cohere
+        top_idx = [r.index for r in results]
+        top_docs = [docs[i] for i in top_idx]
+        return top_idx, top_docs
\ No newline at end of file
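
A short usage sketch of CohereRanker as defined above (query and documents are
illustrative; COHERE_API_KEY is assumed to be set, and the printed indices are only an
example of what the rerank endpoint might return):

    from semantic_router.rankers.cohere import CohereRanker

    ranker = CohereRanker(top_n=2)
    idx, docs = ranker(
        query="will it rain today",
        docs=["how is the weather", "tell me a joke", "is it raining"],
    )
    print(idx, docs)  # e.g. [2, 0] and the two most relevant documents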
diff --git a/semantic_router/retrievers/__init__.py b/semantic_router/retrievers/__init__.py
new file mode 100644
index 00000000..0fcaa6d2
--- /dev/null
+++ b/semantic_router/retrievers/__init__.py
@@ -0,0 +1,5 @@
+from .base import BaseRetriever
+from .bm25 import BM25Retriever
+from .cohere import CohereRetriever
+from .huggingface import HuggingFaceRetriever
+from .openai import OpenAIRetriever
+
+__all__ = [
+    "BaseRetriever",
+    "BM25Retriever",
+    "CohereRetriever",
+    "HuggingFaceRetriever",
+    "OpenAIRetriever",
+]
diff --git a/semantic_router/encoders/base.py b/semantic_router/retrievers/base.py
similarity index 67%
rename from semantic_router/encoders/base.py
rename to semantic_router/retrievers/base.py
index 4b5ca40d..4274e074 100644
--- a/semantic_router/encoders/base.py
+++ b/semantic_router/retrievers/base.py
@@ -1,11 +1,11 @@
 from pydantic import BaseModel
 
 
-class BaseEncoder(BaseModel):
+class BaseRetriever(BaseModel):
     name: str
 
     class Config:
         arbitrary_types_allowed = True
 
-    def __call__(self, texts: list[str]) -> list[float]:
+    def __call__(self, docs: list[str]) -> list[float]:
         raise NotImplementedError("Subclasses must implement this method")
diff --git a/semantic_router/retrievers/bm25.py b/semantic_router/retrievers/bm25.py
new file mode 100644
index 00000000..2a68a3ff
--- /dev/null
+++ b/semantic_router/retrievers/bm25.py
@@ -0,0 +1,21 @@
+from typing import Any
+
+from pinecone_text.sparse import BM25Encoder
+
+from semantic_router.retrievers import BaseRetriever
+
+
+class BM25Retriever(BaseRetriever):
+    model: Any = None
+    params: Any = None
+
+    def __init__(self, name: str = "bm25"):
+        super().__init__(name=name)
+        self.model = BM25Encoder()
+
+    def __call__(self, docs: list[str]) -> list[list[float]]:
+        if self.params is None:
+            raise ValueError("BM25 model not trained, must call `.fit` first.")
+        embeds = self.model.encode_documents(docs)
+        return embeds
+
+    def fit(self, docs: list[str]):
+        # fit the BM25 term statistics on the docs and keep the returned params
+        # so __call__ can check that the model has been trained
+        self.params = self.model.fit(docs)
\ No newline at end of file
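
The sparse model has to be fitted on an utterance corpus before it can encode; an assumed
usage sketch (the corpus is illustrative):

    from semantic_router.retrievers import BM25Retriever

    corpus = ["how is the weather", "is it raining", "tell me a joke"]
    sparse = BM25Retriever()
    sparse.fit(corpus)              # train the BM25 term statistics first
    sparse_embeds = sparse(corpus)  # then encode the documents with the fitted model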
diff --git a/semantic_router/encoders/cohere.py b/semantic_router/retrievers/cohere.py
similarity index 71%
rename from semantic_router/encoders/cohere.py
rename to semantic_router/retrievers/cohere.py
index 0ed2ecc0..d2334f91 100644
--- a/semantic_router/encoders/cohere.py
+++ b/semantic_router/retrievers/cohere.py
@@ -2,10 +2,10 @@ import os
 
 import cohere
 
-from semantic_router.encoders import BaseEncoder
+from semantic_router.retrievers import BaseRetriever
 
 
-class CohereEncoder(BaseEncoder):
+class CohereRetriever(BaseRetriever):
     client: cohere.Client | None
 
     def __init__(
@@ -17,12 +17,12 @@ class CohereEncoder(BaseEncoder):
             raise ValueError("Cohere API key cannot be 'None'.")
         self.client = cohere.Client(cohere_api_key)
 
-    def __call__(self, texts: list[str]) -> list[list[float]]:
+    def __call__(self, docs: list[str]) -> list[list[float]]:
         if self.client is None:
             raise ValueError("Cohere client is not initialized.")
-        if len(texts) == 1:
+        if len(docs) == 1:
             input_type = "search_query"
         else:
             input_type = "search_document"
-        embeds = self.client.embed(texts, input_type=input_type, model=self.name)
+        embeds = self.client.embed(docs, input_type=input_type, model=self.name)
         return embeds.embeddings
diff --git a/semantic_router/retrievers/huggingface.py b/semantic_router/retrievers/huggingface.py
new file mode 100644
index 00000000..9c8f2f05
--- /dev/null
+++ b/semantic_router/retrievers/huggingface.py
@@ -0,0 +1,9 @@
+from semantic_router.retrievers import BaseRetriever
+
+
+class HuggingFaceRetriever(BaseRetriever):
+    def __init__(self, name: str):
+        super().__init__(name=name)
+
+    def __call__(self, docs: list[str]) -> list[float]:
+        raise NotImplementedError
diff --git a/semantic_router/encoders/openai.py b/semantic_router/retrievers/openai.py
similarity index 81%
rename from semantic_router/encoders/openai.py
rename to semantic_router/retrievers/openai.py
index 87feec4c..2dbfd880 100644
--- a/semantic_router/encoders/openai.py
+++ b/semantic_router/retrievers/openai.py
@@ -4,17 +4,17 @@ from time import sleep
 import openai
 from openai.error import RateLimitError
 
-from semantic_router.encoders import BaseEncoder
+from semantic_router.retrievers import BaseRetriever
 
 
-class OpenAIEncoder(BaseEncoder):
+class OpenAIRetriever(BaseRetriever):
     def __init__(self, name: str, openai_api_key: str | None = None):
         super().__init__(name=name)
         openai.api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
         if openai.api_key is None:
             raise ValueError("OpenAI API key cannot be 'None'.")
 
-    def __call__(self, texts: list[str]) -> list[list[float]]:
+    def __call__(self, docs: list[str]) -> list[list[float]]:
         """Encode a list of texts using the OpenAI API. Returns a list of
         vector embeddings.
         """
@@ -22,7 +22,7 @@ class OpenAIEncoder(BaseEncoder):
         # exponential backoff in case of RateLimitError
         for j in range(5):
             try:
-                res = openai.Embedding.create(input=texts, engine=self.name)
+                res = openai.Embedding.create(input=docs, engine=self.name)
                 if isinstance(res, dict) and "data" in res:
                     break
             except RateLimitError:
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index 439f2322..ea0ad2cf 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -3,11 +3,11 @@ from enum import Enum
 from pydantic import BaseModel
 from pydantic.dataclasses import dataclass
 
-from semantic_router.encoders import (
-    BaseEncoder,
-    CohereEncoder,
-    HuggingFaceEncoder,
-    OpenAIEncoder,
+from semantic_router.retrievers import (
+    BaseRetriever,
+    CohereRetriever,
+    HuggingFaceRetriever,
+    OpenAIRetriever,
 )
 
 
@@ -17,27 +17,27 @@ class Decision(BaseModel):
     description: str | None = None
 
 
-class EncoderType(Enum):
+class RetrieverType(Enum):
     HUGGINGFACE = "huggingface"
     OPENAI = "openai"
     COHERE = "cohere"
 
 
 @dataclass
-class Encoder:
-    type: EncoderType
+class Retriever:
+    type: RetrieverType
     name: str
-    model: BaseEncoder
+    model: BaseRetriever
 
     def __init__(self, type: str, name: str):
-        self.type = EncoderType(type)
+        self.type = RetrieverType(type)
         self.name = name
-        if self.type == EncoderType.HUGGINGFACE:
-            self.model = HuggingFaceEncoder(name)
-        elif self.type == EncoderType.OPENAI:
-            self.model = OpenAIEncoder(name)
-        elif self.type == EncoderType.COHERE:
-            self.model = CohereEncoder(name)
+        if self.type == RetrieverType.HUGGINGFACE:
+            self.model = HuggingFaceRetriever(name)
+        elif self.type == RetrieverType.OPENAI:
+            self.model = OpenAIRetriever(name)
+        elif self.type == RetrieverType.COHERE:
+            self.model = CohereRetriever(name)
 
     def __call__(self, texts: list[str]) -> list[float]:
         return self.model(texts)
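
With the renamed schema above, a retriever can then be built from a plain string type, for
example (model name illustrative; the relevant API key is assumed to be set):

    from semantic_router.schema import Retriever

    retriever = Retriever(type="cohere", name="embed-english-v3.0")
    embeds = retriever(["how is the weather"])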
-- 
GitLab