diff --git a/docs/05-local-execution.ipynb b/docs/05-local-execution.ipynb
index 25e4f012528a1471dd088bf1e4d79c83cef18ec1..c818b9cc7a1e6ab278d5d7d4d84be68d1c105c4a 100644
--- a/docs/05-local-execution.ipynb
+++ b/docs/05-local-execution.ipynb
@@ -87,6 +87,7 @@
     "from semantic_router import Route\n",
     "from semantic_router.utils.function_call import get_schema\n",
     "\n",
+    "\n",
     "def get_time(timezone: str) -> str:\n",
     "    \"\"\"Finds the current time in a specific timezone.\n",
     "\n",
@@ -100,6 +101,7 @@
     "    now = datetime.now(ZoneInfo(timezone))\n",
     "    return now.strftime(\"%H:%M\")\n",
     "\n",
+    "\n",
     "time_schema = get_schema(get_time)\n",
     "time_schema\n",
     "time = Route(\n",
@@ -314,9 +316,14 @@
     "from llama_cpp import Llama\n",
     "from semantic_router.llms import LlamaCppLLM\n",
     "\n",
-    "enable_gpu = True # offload LLM layers to the GPU (must fit in memory)\n",
+    "enable_gpu = True  # offload LLM layers to the GPU (must fit in memory)\n",
     "\n",
-    "_llm = Llama(model_path=\"./mistral-7b-instruct-v0.2.Q4_0.gguf\", n_gpu_layers=-1 if enable_gpu else 0, n_ctx=2048, verbose=False)\n",
+    "_llm = Llama(\n",
+    "    model_path=\"./mistral-7b-instruct-v0.2.Q4_0.gguf\",\n",
+    "    n_gpu_layers=-1 if enable_gpu else 0,\n",
+    "    n_ctx=2048,\n",
+    "    verbose=False,\n",
+    ")\n",
     "llm = LlamaCppLLM(name=\"Mistral-7B-v0.2-Instruct\", llm=_llm, max_tokens=None)\n",
     "\n",
     "rl = RouteLayer(encoder=encoder, routes=routes, llm=llm)"
diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py
index 40a7f13aa9a42ac5518f136be06dbe8d1c8eecfe..2eaae7ae5f0fcd07b28d5e22fe9566c73614e445 100644
--- a/semantic_router/llms/base.py
+++ b/semantic_router/llms/base.py
@@ -16,14 +16,18 @@ class BaseLLM(BaseModel):
     def __call__(self, messages: List[Message]) -> Optional[str]:
         raise NotImplementedError("Subclasses must implement this method")
 
-    def _is_valid_inputs(self, inputs: dict[str, Any], function_schema: dict[str, Any]) -> bool:
+    def _is_valid_inputs(
+        self, inputs: dict[str, Any], function_schema: dict[str, Any]
+    ) -> bool:
         """Validate the extracted inputs against the function schema"""
         try:
             # Extract parameter names and types from the signature string
             signature = function_schema["signature"]
             param_info = [param.strip() for param in signature[1:-1].split(",")]
             param_names = [info.split(":")[0].strip() for info in param_info]
-            param_types = [info.split(":")[1].strip().split("=")[0].strip() for info in param_info]
+            param_types = [
+                info.split(":")[1].strip().split("=")[0].strip() for info in param_info
+            ]
 
             for name, type_str in zip(param_names, param_types):
                 if name not in inputs:
@@ -34,7 +38,9 @@ class BaseLLM(BaseModel):
             logger.error(f"Input validation error: {str(e)}")
             return False
 
-    def extract_function_inputs(self, query: str, function_schema: dict[str, Any]) -> dict:
+    def extract_function_inputs(
+        self, query: str, function_schema: dict[str, Any]
+    ) -> dict:
         logger.info("Extracting function input...")
 
         prompt = f"""
diff --git a/semantic_router/llms/llamacpp.py b/semantic_router/llms/llamacpp.py
index 5554b73704ef6768c231ce4abaf6b784655595cd..c70745f220bbcb9e3c587682ba7116bd8ac447ca 100644
--- a/semantic_router/llms/llamacpp.py
+++ b/semantic_router/llms/llamacpp.py
@@ -1,6 +1,6 @@
+from contextlib import contextmanager
 from pathlib import Path
 from typing import Any
-from contextlib import contextmanager
 
 from llama_cpp import Llama, LlamaGrammar
 
@@ -22,6 +22,8 @@ class LlamaCppLLM(BaseLLM):
         temperature: float = 0.2,
         max_tokens: int = 200,
     ):
+        if not llm:
+            raise ValueError("`llama_cpp.Llama` llm is required")
         super().__init__(name=name)
         self.llm = llm
         self.temperature = temperature
@@ -37,6 +39,7 @@ class LlamaCppLLM(BaseLLM):
             temperature=self.temperature,
             max_tokens=self.max_tokens,
             grammar=self.grammar,
+            stream=False,
         )
 
         output = completion["choices"][0]["message"]["content"]
@@ -58,6 +61,10 @@ class LlamaCppLLM(BaseLLM):
         finally:
             self.grammar = None
 
-    def extract_function_inputs(self, query: str, function_schema: dict[str, Any]) -> dict:
+    def extract_function_inputs(
+        self, query: str, function_schema: dict[str, Any]
+    ) -> dict:
         with self._grammar():
-            return super().extract_function_inputs(query=query, function_schema=function_schema)
+            return super().extract_function_inputs(
+                query=query, function_schema=function_schema
+            )
diff --git a/semantic_router/route.py b/semantic_router/route.py
index 8c797ccec7e3460669ff4e5830f399cd156df5c3..b3f36b8beb63d36a57e3ccc29d2daa0ecf98352b 100644
--- a/semantic_router/route.py
+++ b/semantic_router/route.py
@@ -19,13 +19,17 @@ def is_valid(route_config: str) -> bool:
             for item in output_json:
                 missing_keys = [key for key in required_keys if key not in item]
                 if missing_keys:
-                    logger.warning(f"Missing keys in route config: {', '.join(missing_keys)}")
+                    logger.warning(
+                        f"Missing keys in route config: {', '.join(missing_keys)}"
+                    )
                     return False
             return True
         else:
             missing_keys = [key for key in required_keys if key not in output_json]
             if missing_keys:
-                logger.warning(f"Missing keys in route config: {', '.join(missing_keys)}")
+                logger.warning(
+                    f"Missing keys in route config: {', '.join(missing_keys)}"
+                )
                 return False
             else:
                 return True
@@ -44,9 +48,14 @@ class Route(BaseModel):
     def __call__(self, query: str) -> RouteChoice:
         if self.function_schema:
             if not self.llm:
-                raise ValueError("LLM is required for dynamic routes. Please ensure the `llm` " "attribute is set.")
+                raise ValueError(
+                    "LLM is required for dynamic routes. Please ensure the `llm` "
+                    "attribute is set."
+                )
             # if a function schema is provided we generate the inputs
-            extracted_inputs = self.llm.extract_function_inputs(query=query, function_schema=self.function_schema)
+            extracted_inputs = self.llm.extract_function_inputs(
+                query=query, function_schema=self.function_schema
+            )
             func_call = extracted_inputs
         else:
             # otherwise we just pass None for the call
diff --git a/semantic_router/utils/function_call.py b/semantic_router/utils/function_call.py
index 2c0be647a095b15e28048a5972a37124619dacf7..16e8ffe50bf9defb94dae99c910c89606e2d176d 100644
--- a/semantic_router/utils/function_call.py
+++ b/semantic_router/utils/function_call.py
@@ -1,11 +1,9 @@
 import inspect
-import json
 from typing import Any, Callable, Dict, List, Union
 
 from pydantic import BaseModel
 
 from semantic_router.llms import BaseLLM
-from semantic_router.llms.llamacpp import LlamaCppLLM
 from semantic_router.schema import Message, RouteChoice
 from semantic_router.utils.logger import logger
 
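
For context, a minimal sketch of how the new `LlamaCppLLM` constructor guard behaves, reusing the model path from the notebook cell in the diff above; it assumes the GGUF file has already been downloaded locally.

```python
from llama_cpp import Llama

from semantic_router.llms.llamacpp import LlamaCppLLM

# The new guard fails fast at construction time instead of surfacing a
# confusing error later at call time.
try:
    LlamaCppLLM(name="Mistral-7B-v0.2-Instruct", llm=None)
except ValueError as err:
    print(err)  # `llama_cpp.Llama` llm is required

# Normal construction, mirroring the notebook cell above.
_llm = Llama(
    model_path="./mistral-7b-instruct-v0.2.Q4_0.gguf",  # assumed local download
    n_gpu_layers=-1,  # offload all layers to the GPU (must fit in memory)
    n_ctx=2048,
    verbose=False,
)
llm = LlamaCppLLM(name="Mistral-7B-v0.2-Instruct", llm=_llm, max_tokens=None)
```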