From f906565002e5828434dbdb9df316ae2abdbac4be Mon Sep 17 00:00:00 2001
From: BeatrixCohere <128378696+BeatrixCohere@users.noreply.github.com>
Date: Wed, 31 Jan 2024 20:24:30 -0500
Subject: [PATCH] [Cohere] Add client name to Cohere calls (#10384)

Add client name to cohere calls
---
 llama_index/embeddings/cohereai.py                  | 2 +-
 llama_index/finetuning/rerankers/cohere_reranker.py | 2 +-
 llama_index/llms/cohere.py                          | 4 ++--
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/llama_index/embeddings/cohereai.py b/llama_index/embeddings/cohereai.py
index 1fd4f19edd..91df867f0d 100644
--- a/llama_index/embeddings/cohereai.py
+++ b/llama_index/embeddings/cohereai.py
@@ -111,7 +111,7 @@ class CohereEmbedding(BaseEmbedding):
             raise ValueError(f"truncate must be one of {VALID_TRUNCATE_OPTIONS}")
 
         super().__init__(
-            cohere_client=cohere.Client(cohere_api_key),
+            cohere_client=cohere.Client(cohere_api_key, client_name="llama_index"),
             cohere_api_key=cohere_api_key,
             model_name=model_name,
             truncate=truncate,
diff --git a/llama_index/finetuning/rerankers/cohere_reranker.py b/llama_index/finetuning/rerankers/cohere_reranker.py
index 220de621dd..756d3da039 100644
--- a/llama_index/finetuning/rerankers/cohere_reranker.py
+++ b/llama_index/finetuning/rerankers/cohere_reranker.py
@@ -38,7 +38,7 @@ class CohereRerankerFinetuneEngine(BaseCohereRerankerFinetuningEngine):
                 "Must pass in cohere api key or "
                 "specify via COHERE_API_KEY environment variable "
             )
-        self._model = cohere.Client(self.api_key)
+        self._model = cohere.Client(self.api_key, client_name="llama_index")
         self._train_file_name = train_file_name
         self._val_file_name = val_file_name
         self._model_name = model_name
diff --git a/llama_index/llms/cohere.py b/llama_index/llms/cohere.py
index 94951cf9c4..fe19c2755b 100644
--- a/llama_index/llms/cohere.py
+++ b/llama_index/llms/cohere.py
@@ -69,8 +69,8 @@ class Cohere(LLM):
         additional_kwargs = additional_kwargs or {}
         callback_manager = callback_manager or CallbackManager([])
 
-        self._client = cohere.Client(api_key)
-        self._aclient = cohere.AsyncClient(api_key)
+        self._client = cohere.Client(api_key, client_name="llama_index")
+        self._aclient = cohere.AsyncClient(api_key, client_name="llama_index")
 
         super().__init__(
             temperature=temperature,
-- 
GitLab