From 9575e7fe150747ac469163777de16eb320b02559 Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Sat, 13 Jan 2024 16:36:29 +0000
Subject: [PATCH] lint

---
 docs/05-local-execution.ipynb        | 48 ++++++++++++++--------------
 semantic_router/llms/base.py         |  3 ++
 semantic_router/llms/llamacpp.py     | 26 ++++++++++++---
 tests/unit/llms/test_llm_llamacpp.py |  3 +-
 4 files changed, 49 insertions(+), 31 deletions(-)

diff --git a/docs/05-local-execution.ipynb b/docs/05-local-execution.ipynb
index f03f77ae..4e3fbe6f 100644
--- a/docs/05-local-execution.ipynb
+++ b/docs/05-local-execution.ipynb
@@ -44,12 +44,12 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
    "id": "4f9b5729",
    "metadata": {},
    "outputs": [],
    "source": [
-    "#!CMAKE_ARGS=\"-DLLAMA_METAL=on\""
+    "!CMAKE_ARGS=\"-DLLAMA_METAL=on\""
    ]
   },
   {
@@ -85,7 +85,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 2,
    "id": "e26db664-9dff-476a-84ef-edd7a8cdf1ba",
    "metadata": {},
    "outputs": [],
@@ -149,7 +149,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 3,
    "id": "fac95b0c-c61f-4158-b7d9-0221f7d0b65e",
    "metadata": {},
    "outputs": [
@@ -162,7 +162,7 @@
        " 'output': \"<class 'str'>\"}"
       ]
      },
-     "execution_count": 2,
+     "execution_count": 3,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -183,7 +183,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 4,
    "id": "5253c141-141b-4fda-b07c-a313393902ed",
    "metadata": {},
    "outputs": [
@@ -222,7 +222,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 5,
    "id": "772cec0d-7a0c-4c7e-9b7a-4a1864b0a8ec",
    "metadata": {
     "scrolled": true
@@ -324,7 +324,7 @@
       "llama_build_graph: non-view tensors processed: 676/676\n",
       "llama_new_context_with_model: compute buffer total size = 159.19 MiB\n",
       "ggml_backend_metal_buffer_type_alloc_buffer: allocated buffer, size =   156.02 MiB, ( 4332.22 / 21845.34)\n",
-      "\u001b[32m2024-01-13 14:48:23 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n"
+      "\u001b[32m2024-01-13 16:29:46 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n"
      ]
     }
    ],
@@ -349,7 +349,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 6,
    "id": "a8bd1da4-8ff7-4cd3-a5e3-fd79a938cc67",
    "metadata": {},
    "outputs": [
@@ -359,7 +359,7 @@
        "RouteChoice(name='chitchat', function_call=None, similarity_score=None, trigger=None)"
       ]
      },
-     "execution_count": 5,
+     "execution_count": 6,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -370,7 +370,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 7,
    "id": "c6ccbea2-376b-4b28-9b79-d2e9c71e99f4",
    "metadata": {
     "scrolled": true
@@ -414,7 +414,7 @@
       "ws_30 ::= [ <U+0009><U+000A>] ws \n",
       "ws_31 ::= ws_30 | \n",
       "\n",
-      "\u001b[32m2024-01-13 14:51:39 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n"
+      "\u001b[32m2024-01-13 16:29:47 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n"
      ]
     },
     {
@@ -427,10 +427,10 @@
     {
      "data": {
       "text/plain": [
-       "'09:51'"
+       "'11:29'"
       ]
      },
-     "execution_count": 6,
+     "execution_count": 7,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -443,7 +443,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 8,
    "id": "720f976a",
    "metadata": {},
    "outputs": [
@@ -485,7 +485,7 @@
       "ws_30 ::= [ <U+0009><U+000A>] ws \n",
       "ws_31 ::= ws_30 | \n",
       "\n",
-      "\u001b[32m2024-01-13 15:00:56 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n"
+      "\u001b[32m2024-01-13 16:29:50 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n"
      ]
     },
     {
@@ -498,10 +498,10 @@
     {
      "data": {
       "text/plain": [
-       "'16:00'"
+       "'17:29'"
       ]
      },
-     "execution_count": 7,
+     "execution_count": 8,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -514,7 +514,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 9,
    "id": "c9d9dbbb",
    "metadata": {},
    "outputs": [
@@ -556,7 +556,7 @@
       "ws_30 ::= [ <U+0009><U+000A>] ws \n",
       "ws_31 ::= ws_30 | \n",
       "\n",
-      "\u001b[32m2024-01-13 15:01:59 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n"
+      "\u001b[32m2024-01-13 16:29:51 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n"
      ]
     },
     {
@@ -569,10 +569,10 @@
     {
      "data": {
       "text/plain": [
-       "'22:02'"
+       "'23:29'"
       ]
      },
-     "execution_count": 8,
+     "execution_count": 9,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -627,7 +627,7 @@
       "ws_30 ::= [ <U+0009><U+000A>] ws \n",
       "ws_31 ::= ws_30 | \n",
       "\n",
-      "\u001b[32m2024-01-13 15:02:49 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n"
+      "\u001b[32m2024-01-13 16:29:53 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n"
      ]
     },
     {
@@ -640,7 +640,7 @@
     {
      "data": {
       "text/plain": [
-       "'22:03'"
+       "'23:30'"
       ]
      },
      "execution_count": 10,
diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py
index 7c7a6816..4fd2c389 100644
--- a/semantic_router/llms/base.py
+++ b/semantic_router/llms/base.py
@@ -13,6 +13,9 @@ class BaseLLM(BaseModel):
     class Config:
         arbitrary_types_allowed = True
 
+    def __init__(self, name: str, **kwargs):
+        super().__init__(name=name, **kwargs)
+
     def __call__(self, messages: List[Message]) -> Optional[str]:
         raise NotImplementedError("Subclasses must implement this method")
 
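Note: the new `BaseLLM.__init__` lets subclasses pass `name` positionally while any extra pydantic fields flow through as keyword arguments. A minimal sketch of a subclass relying on it; the `EchoLLM` class is hypothetical and exists only to illustrate the call pattern:

    from typing import List, Optional

    from semantic_router.llms.base import BaseLLM
    from semantic_router.schema import Message

    class EchoLLM(BaseLLM):  # hypothetical subclass, for illustration only
        def __call__(self, messages: List[Message]) -> Optional[str]:
            return messages[-1].content

    llm = EchoLLM("echo")  # `name` can now be passed positionally
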
diff --git a/semantic_router/llms/llamacpp.py b/semantic_router/llms/llamacpp.py
index c0bc0a2a..e6787953 100644
--- a/semantic_router/llms/llamacpp.py
+++ b/semantic_router/llms/llamacpp.py
@@ -2,17 +2,25 @@ from contextlib import contextmanager
 from pathlib import Path
 from typing import Any, Optional
 
-from llama_cpp import Llama, LlamaGrammar
+from llama_cpp import Llama, LlamaGrammar, CreateChatCompletionResponse
 
 from semantic_router.llms.base import BaseLLM
 from semantic_router.schema import Message
 from semantic_router.utils.logger import logger
 
 
+class LlamaCppBaseLLM(BaseLLM):
+    # Declare the extra attributes as pydantic fields: BaseLLM is a pydantic
+    # BaseModel, so assigning undeclared attributes inside __init__ raises at runtime.
+    llm: Llama
+    temperature: float
+    max_tokens: int
+
+
 class LlamaCppLLM(BaseLLM):
     llm: Llama
     temperature: float
-    max_tokens: int
+    max_tokens: Optional[int] = 200
     grammar: Optional[LlamaGrammar] = None
 
     def __init__(
@@ -20,14 +28,22 @@ class LlamaCppLLM(BaseLLM):
         llm: Llama,
         name: str = "llama.cpp",
         temperature: float = 0.2,
-        max_tokens: int = 200,
+        max_tokens: Optional[int] = 200,
+        grammar: Optional[LlamaGrammar] = None,
     ):
         if not llm:
             raise ValueError("`llama_cpp.Llama` llm is required")
-        super().__init__(name=name, llm=llm, temperature=temperature, max_tokens=max_tokens)
+        super().__init__(
+            name=name,
+            llm=llm,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            grammar=grammar,
+        )
         self.llm = llm
         self.temperature = temperature
         self.max_tokens = max_tokens
+        self.grammar = grammar
 
     def __call__(
         self,
@@ -41,7 +57,7 @@ class LlamaCppLLM(BaseLLM):
                 grammar=self.grammar,
                 stream=False,
             )
-
+            assert isinstance(completion, dict)  # CreateChatCompletionResponse is a TypedDict; the non-streaming response is a plain dict at runtime
             output = completion["choices"][0]["message"]["content"]
 
             if not output:
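Note: with `grammar` now accepted by the constructor and `max_tokens` optional, `LlamaCppLLM` wraps a local `Llama` instance directly. A usage sketch, assuming a locally downloaded GGUF file; the model path and generation settings are illustrative, not taken from this patch:

    from llama_cpp import Llama

    from semantic_router.llms import LlamaCppLLM
    from semantic_router.schema import Message

    _llm = Llama(
        model_path="./mistral-7b-instruct-v0.2.Q4_0.gguf",  # assumed local GGUF path
        n_gpu_layers=-1,  # offload all layers to Metal/GPU where available
        n_ctx=2048,
    )
    llm = LlamaCppLLM(llm=_llm, max_tokens=None)  # None lets generation run to EOS
    print(llm([Message(role="user", content="What is the capital of France?")]))
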
diff --git a/tests/unit/llms/test_llm_llamacpp.py b/tests/unit/llms/test_llm_llamacpp.py
index db47b154..4bcf2a8e 100644
--- a/tests/unit/llms/test_llm_llamacpp.py
+++ b/tests/unit/llms/test_llm_llamacpp.py
@@ -1,10 +1,9 @@
 import pytest
+from llama_cpp import Llama
 
 from semantic_router.llms import LlamaCppLLM
 from semantic_router.schema import Message
 
-from llama_cpp import Llama
-
 
 @pytest.fixture
 def llamacpp_llm(mocker):
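Note: the fixture needs a `Llama` stand-in that passes the isinstance check pydantic performs for arbitrary types. One way to sketch that with pytest-mock; the stubbed response shape mirrors what `__call__` reads from `create_chat_completion`, and is an assumption rather than a copy of the test file:

    import pytest
    from llama_cpp import Llama

    from semantic_router.llms import LlamaCppLLM

    @pytest.fixture
    def llamacpp_llm(mocker):
        llm = mocker.Mock(spec=Llama)  # spec=Llama so isinstance(llm, Llama) holds
        llm.create_chat_completion.return_value = {
            "choices": [{"message": {"content": "Paris"}}]
        }
        return LlamaCppLLM(llm=llm)
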
-- 
GitLab