from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional

from pydantic import PrivateAttr

from semantic_router.llms.base import BaseLLM
from semantic_router.schema import Message
from semantic_router.utils.logger import logger


class LlamaCppLLM(BaseLLM):
    """LLM for LlamaCPP. Enables fully local LLM use, helpful for local implementation of
    dynamic routes.

    The wrapped ``llm`` object must expose a ``create_chat_completion`` method;
    it is called with the converted messages, the configured temperature,
    token limit and grammar, and ``stream=False``.
    """

    # The underlying llama.cpp model object used to generate completions.
    llm: Any
    # Grammar passed to every completion call; ``None`` means unconstrained.
    grammar: Optional[Any] = None
    # Handle to the lazily imported ``llama_cpp`` module (set in ``__init__``).
    _llama_cpp: Any = PrivateAttr()

    def __init__(
        self,
        llm: Any,
        name: str = "llama.cpp",
        temperature: float = 0.2,
        max_tokens: Optional[int] = 200,
        grammar: Optional[Any] = None,
    ):
        """Initialize the LlamaCPPLLM.

        :param llm: The LLM to use.
        :type llm: Any
        :param name: The name of the LLM.
        :type name: str
        :param temperature: The temperature of the LLM.
        :type temperature: float
        :param max_tokens: The maximum number of tokens to generate.
        :type max_tokens: Optional[int]
        :param grammar: The grammar to use.
        :type grammar: Optional[Any]
        :raises ImportError: If the ``llama_cpp`` package is not installed.
        """
        super().__init__(
            name=name,
            llm=llm,
            temperature=temperature,
            max_tokens=max_tokens,
            grammar=grammar,
        )

        # Import lazily so the package stays importable without the optional
        # local-LLM dependency; only instantiating this class requires it.
        try:
            import llama_cpp
        except ImportError as e:
            raise ImportError(
                "Please install LlamaCPP to use Llama CPP llm. "
                "You can install it with: "
                "`pip install 'semantic-router[local]'`"
            ) from e
        self._llama_cpp = llama_cpp
        self.llm = llm
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.grammar = grammar

    def __call__(
        self,
        messages: List[Message],
    ) -> str:
        """Call the LlamaCPPLLM.

        Each message is converted with ``Message.to_llamacpp()`` and sent to
        ``llm.create_chat_completion`` as a single non-streaming request.

        :param messages: The messages to pass to the LlamaCPPLLM.
        :type messages: List[Message]
        :return: The response from the LlamaCPPLLM.
        :rtype: str
        :raises Exception: If the completion call fails or the model returns
            empty content. The error is logged before being re-raised.
        """
        try:
            completion = self.llm.create_chat_completion(
                messages=[m.to_llamacpp() for m in messages],
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                grammar=self.grammar,
                stream=False,
            )
            assert isinstance(completion, dict)  # keep mypy happy
            output = completion["choices"][0]["message"]["content"]

            # Empty or missing content is treated as a failed generation.
            if not output:
                raise Exception("No output generated")
            return output
        except Exception as e:
            logger.error(f"LLM error: {e}")
            raise

    @contextmanager
    def _grammar(self):
        """Context manager that temporarily applies the bundled grammar.

        While the context is active, ``self.grammar`` is set to the grammar
        loaded from ``grammars/json.gbnf`` (located next to this module). On
        exit, including on error, the grammar that was configured before
        entering the context is restored, so a user-supplied grammar is not
        lost.

        :return: A context manager; it yields no value.
        :raises AssertionError: If the grammar file does not exist.
        """
        grammar_path = Path(__file__).parent.joinpath("grammars", "json.gbnf")
        assert grammar_path.exists(), f"{grammar_path}\ndoes not exist"
        # Remember the caller-configured grammar so it can be reinstated
        # afterwards instead of being unconditionally reset to ``None``.
        previous_grammar = self.grammar
        try:
            self.grammar = self._llama_cpp.LlamaGrammar.from_file(grammar_path)
            yield
        finally:
            self.grammar = previous_grammar

    def extract_function_inputs(
        self, query: str, function_schemas: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Extract the function inputs from the query.

        Delegates to the base class implementation while the bundled grammar
        is active (see ``_grammar``).

        :param query: The query to extract the function inputs from.
        :type query: str
        :param function_schemas: The function schemas to extract the function inputs from.
        :type function_schemas: List[Dict[str, Any]]
        :return: The function inputs.
        :rtype: List[Dict[str, Any]]
        """
        with self._grammar():
            return super().extract_function_inputs(
                query=query, function_schemas=function_schemas
            )