From 48e2330e26962ec3b0cd7b28655901209d547698 Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Wed, 8 Nov 2023 15:32:59 +0100 Subject: [PATCH] add cohere encoder --- decision_layer/encoders/cohere.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/decision_layer/encoders/cohere.py b/decision_layer/encoders/cohere.py index 4700f52b..ec8837a1 100644 --- a/decision_layer/encoders/cohere.py +++ b/decision_layer/encoders/cohere.py @@ -1,8 +1,22 @@ +import os +import cohere from decision_layer.encoders import BaseEncoder class CohereEncoder(BaseEncoder): - def __init__(self, name: str): - super().__init__(name) + client: cohere.Client | None + def __init__(self, name: str, cohere_api_key: str | None = None): + super().__init__(name=name, client=None) + cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY") + if cohere_api_key is None: + raise ValueError("Cohere API key cannot be 'None'.") + self.client = cohere.Client(cohere_api_key) def __call__(self, texts: list[str]) -> list[float]: - raise NotImplementedError \ No newline at end of file + if len(texts) == 1: + input_type = "search_query" + else: + input_type = "search_document" + embeds = self.client.embed( + texts, input_type=input_type, model="embed-english-v3.0" + ) + return embeds.embeddings \ No newline at end of file -- GitLab