# llamacpp.py
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional

from pydantic import PrivateAttr

from semantic_router.llms.base import BaseLLM
from semantic_router.schema import Message
from semantic_router.utils.logger import logger


class LlamaCppLLM(BaseLLM):
    """LLM for LlamaCPP. Enables fully local LLM use, helpful for local implementation of
    dynamic routes.
    """

    llm: Any
    grammar: Optional[Any] = None
    _llama_cpp: Any = PrivateAttr()

    def __init__(
        self,
        llm: Any,
        name: str = "llama.cpp",
        temperature: float = 0.2,
        max_tokens: Optional[int] = 200,
        grammar: Optional[Any] = None,
    ):
        """Initialize the LlamaCppLLM.

        :param llm: The LLM to use.
        :type llm: Any
        :param name: The name of the LLM.
        :type name: str
        :param temperature: The temperature of the LLM.
        :type temperature: float
        :param max_tokens: The maximum number of tokens to generate.
        :type max_tokens: Optional[int]
        :param grammar: The grammar to use.
        :type grammar: Optional[Any]
        """
        super().__init__(
            name=name,
            llm=llm,
            temperature=temperature,
            max_tokens=max_tokens,
            grammar=grammar,
        )
        try:
            import llama_cpp
        except ImportError:
            raise ImportError(
                "Please install LlamaCPP to use Llama CPP llm. "
                "You can install it with: "
                "`pip install 'semantic-router[local]'`"
            )
        self._llama_cpp = llama_cpp
        self.llm = llm
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.grammar = grammar
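
    # Example (a minimal sketch, not part of the original module): constructing this
    # class around a local llama-cpp-python model. The GGUF path is a placeholder.
    #
    #     from llama_cpp import Llama
    #
    #     _llm = Llama(model_path="./mistral-7b-instruct.Q4_K_M.gguf", n_ctx=2048)
    #     llamacpp_llm = LlamaCppLLM(llm=_llm, temperature=0.2, max_tokens=200)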

    def __call__(
        self,
        messages: List[Message],
    ) -> str:
        """Call the LlamaCppLLM.

        :param messages: The messages to pass to the LlamaCppLLM.
        :type messages: List[Message]
        :return: The response from the LlamaCppLLM.
        :rtype: str
        """
        try:
            completion = self.llm.create_chat_completion(
                messages=[m.to_llamacpp() for m in messages],
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                grammar=self.grammar,
                stream=False,
            )
            assert isinstance(completion, dict)  # keep mypy happy
            output = completion["choices"][0]["message"]["content"]
            if not output:
                raise Exception("No output generated")
            return output
        except Exception as e:
            logger.error(f"LLM error: {e}")
            raise
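
    # Example (sketch): calling the instance directly with semantic-router Message
    # objects. Assumes `llamacpp_llm` was built as in the construction sketch above.
    #
    #     out = llamacpp_llm([Message(role="user", content="Tell me about llama.cpp")])
    #     print(out)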

    @contextmanager
    def _grammar(self):
        """Context manager that loads the bundled JSON grammar (grammars/json.gbnf)
        into self.grammar for the duration of the block, then resets it to None.
        """
        grammar_path = Path(__file__).parent.joinpath("grammars", "json.gbnf")
        assert grammar_path.exists(), f"{grammar_path}\ndoes not exist"
        try:
            self.grammar = self._llama_cpp.LlamaGrammar.from_file(grammar_path)
            yield
        finally:
            self.grammar = None
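
    # Example (sketch): how the grammar context manager constrains a single call to
    # emit JSON. `_grammar()` is an internal helper; it is shown here only to
    # illustrate the mechanism used by extract_function_inputs below.
    #
    #     with llamacpp_llm._grammar():
    #         constrained = llamacpp_llm([Message(role="user", content="Reply in JSON")])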

    def extract_function_inputs(
        self, query: str, function_schemas: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Extract the function inputs from the query.

        :param query: The query to extract the function inputs from.
        :type query: str
        :param function_schemas: The function schemas to extract the function inputs for.
        :type function_schemas: List[Dict[str, Any]]
        :return: The extracted function inputs.
        :rtype: List[Dict[str, Any]]
        """
        with self._grammar():
            return super().extract_function_inputs(
                query=query, function_schemas=function_schemas
            )
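

# Example (a minimal sketch): constrained function-input extraction. The schema below
# is hypothetical and only illustrates the expected shape; real schemas are generated
# elsewhere in semantic-router.
#
#     function_schemas = [
#         {
#             "name": "get_time",
#             "description": "Get the current time for a timezone.",
#             "signature": "(timezone: str) -> str",
#             "output": "<class 'str'>",
#         }
#     ]
#     inputs = llamacpp_llm.extract_function_inputs(
#         query="What time is it in Rome?", function_schemas=function_schemas
#     )
#     # The JSON grammar (grammars/json.gbnf) is applied only for this call, then cleared.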