From f906565002e5828434dbdb9df316ae2abdbac4be Mon Sep 17 00:00:00 2001 From: BeatrixCohere <128378696+BeatrixCohere@users.noreply.github.com> Date: Wed, 31 Jan 2024 20:24:30 -0500 Subject: [PATCH] [Cohere] Add client name to Cohere calls (#10384) Add client name to cohere calls --- llama_index/embeddings/cohereai.py | 2 +- llama_index/finetuning/rerankers/cohere_reranker.py | 2 +- llama_index/llms/cohere.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/llama_index/embeddings/cohereai.py b/llama_index/embeddings/cohereai.py index 1fd4f19edd..91df867f0d 100644 --- a/llama_index/embeddings/cohereai.py +++ b/llama_index/embeddings/cohereai.py @@ -111,7 +111,7 @@ class CohereEmbedding(BaseEmbedding): raise ValueError(f"truncate must be one of {VALID_TRUNCATE_OPTIONS}") super().__init__( - cohere_client=cohere.Client(cohere_api_key), + cohere_client=cohere.Client(cohere_api_key, client_name="llama_index"), cohere_api_key=cohere_api_key, model_name=model_name, truncate=truncate, diff --git a/llama_index/finetuning/rerankers/cohere_reranker.py b/llama_index/finetuning/rerankers/cohere_reranker.py index 220de621dd..756d3da039 100644 --- a/llama_index/finetuning/rerankers/cohere_reranker.py +++ b/llama_index/finetuning/rerankers/cohere_reranker.py @@ -38,7 +38,7 @@ class CohereRerankerFinetuneEngine(BaseCohereRerankerFinetuningEngine): "Must pass in cohere api key or " "specify via COHERE_API_KEY environment variable " ) - self._model = cohere.Client(self.api_key) + self._model = cohere.Client(self.api_key, client_name="llama_index") self._train_file_name = train_file_name self._val_file_name = val_file_name self._model_name = model_name diff --git a/llama_index/llms/cohere.py b/llama_index/llms/cohere.py index 94951cf9c4..fe19c2755b 100644 --- a/llama_index/llms/cohere.py +++ b/llama_index/llms/cohere.py @@ -69,8 +69,8 @@ class Cohere(LLM): additional_kwargs = additional_kwargs or {} callback_manager = callback_manager or CallbackManager([]) - self._client = cohere.Client(api_key) - self._aclient = cohere.AsyncClient(api_key) + self._client = cohere.Client(api_key, client_name="llama_index") + self._aclient = cohere.AsyncClient(api_key, client_name="llama_index") super().__init__( temperature=temperature, -- GitLab