from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional

from pydantic import PrivateAttr

from semantic_router.llms.base import BaseLLM
from semantic_router.schema import Message
from semantic_router.utils.logger import logger


class LlamaCppLLM(BaseLLM):
    """LLM for LlamaCPP. Enables fully local LLM use, helpful for local implementation of
    dynamic routes.
    """
    llm: Any
    grammar: Optional[Any] = None
    _llama_cpp: Any = PrivateAttr()

    def __init__(
        self,
        llm: Any,
        name: str = "llama.cpp",
        temperature: float = 0.2,
        max_tokens: Optional[int] = 200,
        grammar: Optional[Any] = None,
    ):
        """Initialize the LlamaCPPLLM.

        :param llm: The LLM to use.
        :type llm: Any
        :param name: The name of the LLM.
        :type name: str
        :param temperature: The temperature of the LLM.
        :type temperature: float
        :param max_tokens: The maximum number of tokens to generate.
        :type max_tokens: Optional[int]
        :param grammar: The grammar to use.
        :type grammar: Optional[Any]
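
        Example (illustrative sketch; the GGUF model path below is a
        placeholder, not something shipped with semantic-router)::

            from llama_cpp import Llama

            _model = Llama(
                model_path="./path/to/model.gguf",  # placeholder local GGUF file
                n_ctx=2048,
            )
            local_llm = LlamaCppLLM(llm=_model, max_tokens=200)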
        """
        super().__init__(
            name=name,
            llm=llm,
            temperature=temperature,
            max_tokens=max_tokens,
            grammar=grammar,
        )

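        # llama_cpp is an optional dependency; import it lazily so the package
        # imports without it, and raise with an install hint if it is missing.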
        try:
            import llama_cpp
        except ImportError:
            raise ImportError(
                "Please install llama-cpp-python to use LlamaCppLLM. "
                "You can install it with: "
                "`pip install 'semantic-router[local]'`"
            )
        self._llama_cpp = llama_cpp
        self.llm = llm
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.grammar = grammar

    def __call__(
        self,
        messages: List[Message],
    ) -> str:
        """Call the LlamaCPPLLM.

        :param messages: The messages to pass to the LlamaCPPLLM.
        :type messages: List[Message]
        :return: The response from the LlamaCPPLLM.
        :rtype: str
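
        Example (illustrative; assumes ``local_llm`` is a configured
        ``LlamaCppLLM`` instance)::

            reply = local_llm([Message(role="user", content="Hello!")])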
        """
        try:
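            # llama.cpp chat completion; self.grammar (when set) constrains
            # sampling, e.g. to valid JSON during function-input extraction.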
            completion = self.llm.create_chat_completion(
                messages=[m.to_llamacpp() for m in messages],
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                grammar=self.grammar,
                stream=False,
            )
            assert isinstance(completion, dict)  # keep mypy happy
            output = completion["choices"][0]["message"]["content"]

            if not output:
                raise Exception("No output generated")
            return output
        except Exception as e:
            logger.error(f"LLM error: {e}")
            raise

    @contextmanager
    def _grammar(self):
        """Context manager for the grammar.

        :return: The grammar.
        :rtype: Any
        """
        grammar_path = Path(__file__).parent.joinpath("grammars", "json.gbnf")
        assert grammar_path.exists(), f"{grammar_path} does not exist"
        try:
            self.grammar = self._llama_cpp.LlamaGrammar.from_file(grammar_path)
            yield
        finally:
            self.grammar = None

    def extract_function_inputs(
        self, query: str, function_schemas: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Extract the function inputs from the query.

        :param query: The query to extract the function inputs from.
        :type query: str
        :param function_schemas: The schemas of the functions whose inputs should be extracted.
        :type function_schemas: List[Dict[str, Any]]
        :return: The function inputs.
        :rtype: List[Dict[str, Any]]
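
        Example (illustrative sketch; assumes ``schemas`` was produced by
        semantic-router's function-schema utilities and ``local_llm`` is a
        configured ``LlamaCppLLM`` instance)::

            inputs = local_llm.extract_function_inputs(
                query="What is the time in Rome?", function_schemas=schemas
            )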
        """
        with self._grammar():
            return super().extract_function_inputs(
                query=query, function_schemas=function_schemas
            )