From 48e2330e26962ec3b0cd7b28655901209d547698 Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Wed, 8 Nov 2023 15:32:59 +0100
Subject: [PATCH] add cohere encoder

---
 decision_layer/encoders/cohere.py | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/decision_layer/encoders/cohere.py b/decision_layer/encoders/cohere.py
index 4700f52b..ec8837a1 100644
--- a/decision_layer/encoders/cohere.py
+++ b/decision_layer/encoders/cohere.py
@@ -1,8 +1,22 @@
+import os
+import cohere
 from decision_layer.encoders import BaseEncoder
 
 class CohereEncoder(BaseEncoder):
-    def __init__(self, name: str):
-        super().__init__(name)
+    client: cohere.Client | None
+    def __init__(self, name: str, cohere_api_key: str | None = None):
+        super().__init__(name=name, client=None)
+        cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY")
+        if cohere_api_key is None:
+            raise ValueError("Cohere API key cannot be 'None'.")
+        self.client = cohere.Client(cohere_api_key)
 
     def __call__(self, texts: list[str]) -> list[float]:
-        raise NotImplementedError
\ No newline at end of file
+        if len(texts) == 1:
+            input_type = "search_query"
+        else:
+            input_type = "search_document"
+        embeds = self.client.embed(
+            texts, input_type=input_type, model="embed-english-v3.0"
+        )
+        return embeds.embeddings
\ No newline at end of file
-- 
GitLab