diff --git a/semantic_router/encoders/cohere.py b/semantic_router/encoders/cohere.py index 4dc02146920bd0cebf9d98a628d434220c968965..6be7784570c63645bb8f76cbd8de4d7545c2e609 100644 --- a/semantic_router/encoders/cohere.py +++ b/semantic_router/encoders/cohere.py @@ -8,7 +8,7 @@ from semantic_router.utils.defaults import EncoderDefault class CohereEncoder(BaseEncoder): - client: Any = PrivateAttr() + _client: Any = PrivateAttr() type: str = "cohere" input_type: Optional[str] = "search_query" @@ -27,7 +27,7 @@ class CohereEncoder(BaseEncoder): input_type=input_type, # type: ignore ) self.input_type = input_type - self.client = self._initialize_client(cohere_api_key) + self._client = self._initialize_client(cohere_api_key) def _initialize_client(self, cohere_api_key: Optional[str] = None): """Initializes the Cohere client. @@ -59,10 +59,10 @@ class CohereEncoder(BaseEncoder): return client def __call__(self, docs: List[str]) -> List[List[float]]: - if self.client is None: + if self._client is None: raise ValueError("Cohere client is not initialized.") try: - embeds = self.client.embed( + embeds = self._client.embed( texts=docs, input_type=self.input_type, model=self.name ) # Check for unsupported type. diff --git a/semantic_router/llms/cohere.py b/semantic_router/llms/cohere.py index 98dc445cbb9bd3057bd6e41ab5bcadac2087be1d..05a9b1bdb0f84f175103e741742777525469a3f1 100644 --- a/semantic_router/llms/cohere.py +++ b/semantic_router/llms/cohere.py @@ -8,7 +8,7 @@ from semantic_router.schema import Message class CohereLLM(BaseLLM): - client: Any = PrivateAttr() + _client: Any = PrivateAttr() def __init__( self, @@ -18,7 +18,7 @@ class CohereLLM(BaseLLM): if name is None: name = os.getenv("COHERE_CHAT_MODEL_NAME", "command") super().__init__(name=name) - self.client = self._initialize_client(cohere_api_key) + self._client = self._initialize_client(cohere_api_key) def _initialize_client(self, cohere_api_key: Optional[str] = None): try: @@ -41,10 +41,10 @@ class CohereLLM(BaseLLM): return client def __call__(self, messages: List[Message]) -> str: - if self.client is None: + if self._client is None: raise ValueError("Cohere client is not initialized.") try: - completion = self.client.chat( + completion = self._client.chat( model=self.name, chat_history=[m.to_cohere() for m in messages[:-1]], message=messages[-1].content,