Skip to content
Snippets Groups Projects
Unverified Commit b6d5c4f7 authored by James Briggs's avatar James Briggs
Browse files

fix for Encoder class

parent 4c49baf5
No related branches found
No related tags found
No related merge requests found
...@@ -23,12 +23,12 @@ class RouteChoice(BaseModel): ...@@ -23,12 +23,12 @@ class RouteChoice(BaseModel):
@dataclass @dataclass
class Encoder: class Encoder:
type: str type: EncoderType
name: str | None name: str | None
model: BaseEncoder model: BaseEncoder
def __init__(self, type: str, name: str | None): def __init__(self, type: str, name: str | None):
self.type = type self.type = EncoderType(type)
self.name = name self.name = name
if self.type == EncoderType.HUGGINGFACE: if self.type == EncoderType.HUGGINGFACE:
raise NotImplementedError raise NotImplementedError
...@@ -37,7 +37,7 @@ class Encoder: ...@@ -37,7 +37,7 @@ class Encoder:
elif self.type == EncoderType.COHERE: elif self.type == EncoderType.COHERE:
self.model = CohereEncoder(name) self.model = CohereEncoder(name)
else: else:
raise NotImplementedError raise ValueError
def __call__(self, texts: list[str]) -> list[list[float]]: def __call__(self, texts: list[str]) -> list[list[float]]:
return self.model(texts) return self.model(texts)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment