import os
from typing import Any, List, Optional

from pydantic import PrivateAttr

from semantic_router.llms import BaseLLM
from semantic_router.schema import Message


class CohereLLM(BaseLLM):
    """LLM for Cohere. Requires a Cohere API key from https://dashboard.cohere.com/api-keys.

    This class provides functionality to interact with the Cohere API for generating text responses.
    It extends the BaseLLM class and implements the __call__ method to generate text responses.
    """

    # Private (non-field) attribute holding the Cohere client; assigned in __init__.
    _client: Any = PrivateAttr()

    def __init__(
        self,
        name: Optional[str] = None,
        cohere_api_key: Optional[str] = None,
    ):
        """Initialize the CohereLLM.

        :param name: The name of the Cohere model to use. Can also be set via the
            COHERE_CHAT_MODEL_NAME environment variable; defaults to "command"
            when neither is provided.
        :type name: Optional[str]
        :param cohere_api_key: The API key for the Cohere client. Can also be set via the
            COHERE_API_KEY environment variable.
        :type cohere_api_key: Optional[str]
        :raises ImportError: If the `cohere` package is not installed.
        :raises ValueError: If no API key is available or the client fails to initialize.
        """
        if name is None:
            name = os.getenv("COHERE_CHAT_MODEL_NAME", "command")
        super().__init__(name=name)
        self._client = self._initialize_client(cohere_api_key)

    def _initialize_client(self, cohere_api_key: Optional[str] = None) -> Any:
        """Initialize the Cohere client.

        :param cohere_api_key: The API key for the Cohere client. Can also be set via the
            COHERE_API_KEY environment variable.
        :type cohere_api_key: Optional[str]
        :return: The initialized Cohere client.
        :rtype: Any
        :raises ImportError: If the `cohere` package is not installed.
        :raises ValueError: If no API key is available or the client fails to initialize.
        """
        # Imported lazily so that `cohere` is only required when this class is used.
        try:
            import cohere
        except ImportError as e:
            raise ImportError(
                "Please install Cohere to use CohereLLM. "
                "You can install it with: "
                "`pip install 'semantic-router[cohere]'`"
            ) from e
        # An explicitly passed key takes precedence over the environment variable.
        cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY")
        if cohere_api_key is None:
            raise ValueError("Cohere API key cannot be 'None'.")
        try:
            client = cohere.Client(cohere_api_key)
        except Exception as e:
            raise ValueError(
                f"Cohere API client failed to initialize. Error: {e}"
            ) from e
        return client

    def __call__(self, messages: List[Message]) -> str:
        """Call the Cohere client.

        The last message is sent as the current message; all preceding messages
        are sent as the chat history.

        :param messages: The messages to pass to the Cohere client. Must not be empty.
        :type messages: List[Message]
        :return: The response from the Cohere client.
        :rtype: str
        :raises ValueError: If the client is not initialized, no messages are given,
            the API call fails, or the response contains no text.
        """
        if self._client is None:
            raise ValueError("Cohere client is not initialized.")
        # Reject empty input up front: otherwise `messages[-1]` raises an IndexError
        # that would be misreported as an API failure.
        if not messages:
            raise ValueError("No messages provided to the Cohere LLM.")

        *history, latest = messages
        try:
            completion = self._client.chat(
                model=self.name,
                chat_history=[m.to_cohere() for m in history],
                message=latest.content,
            )
            output = completion.text
        except Exception as e:
            raise ValueError(f"Cohere API call failed. Error: {e}") from e

        if not output:
            # Same message the caller has always seen for an empty response.
            raise ValueError("Cohere API call failed. Error: No output generated")
        return output