diff --git a/decision_layer/encoders/cohere.py b/decision_layer/encoders/cohere.py index 4700f52bcb0560766dfdfc504fbdbe1d1cfa4047..ec8837a17dd485278779005d6156c010048d72f7 100644 --- a/decision_layer/encoders/cohere.py +++ b/decision_layer/encoders/cohere.py @@ -1,8 +1,22 @@ +import os +import cohere from decision_layer.encoders import BaseEncoder class CohereEncoder(BaseEncoder): - def __init__(self, name: str): - super().__init__(name) + client: cohere.Client | None + def __init__(self, name: str, cohere_api_key: str | None = None): + super().__init__(name=name, client=None) + cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY") + if cohere_api_key is None: + raise ValueError("Cohere API key cannot be 'None'.") + self.client = cohere.Client(cohere_api_key) def __call__(self, texts: list[str]) -> list[float]: - raise NotImplementedError \ No newline at end of file + if len(texts) == 1: + input_type = "search_query" + else: + input_type = "search_document" + embeds = self.client.embed( + texts, input_type=input_type, model="embed-english-v3.0" + ) + return embeds.embeddings \ No newline at end of file