From 9575e7fe150747ac469163777de16eb320b02559 Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Sat, 13 Jan 2024 16:36:29 +0000
Subject: [PATCH] lint

---
 docs/05-local-execution.ipynb        | 48 ++++++++++++++--------------
 semantic_router/llms/base.py         |  3 ++
 semantic_router/llms/llamacpp.py     | 26 ++++++++++++---
 tests/unit/llms/test_llm_llamacpp.py |  3 +-
 4 files changed, 49 insertions(+), 31 deletions(-)

diff --git a/docs/05-local-execution.ipynb b/docs/05-local-execution.ipynb
index f03f77ae..4e3fbe6f 100644
--- a/docs/05-local-execution.ipynb
+++ b/docs/05-local-execution.ipynb
@@ -44,12 +44,12 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
    "id": "4f9b5729",
    "metadata": {},
    "outputs": [],
    "source": [
-    "#!CMAKE_ARGS=\"-DLLAMA_METAL=on\""
+    "!CMAKE_ARGS=\"-DLLAMA_METAL=on\""
    ]
   },
   {
@@ -85,7 +85,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 2,
    "id": "e26db664-9dff-476a-84ef-edd7a8cdf1ba",
    "metadata": {},
    "outputs": [],
@@ -149,7 +149,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 3,
    "id": "fac95b0c-c61f-4158-b7d9-0221f7d0b65e",
    "metadata": {},
    "outputs": [
@@ -162,7 +162,7 @@
       " 'output': \"<class 'str'>\"}"
      ]
     },
-    "execution_count": 2,
+    "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -183,7 +183,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 4,
    "id": "5253c141-141b-4fda-b07c-a313393902ed",
    "metadata": {},
    "outputs": [
@@ -222,7 +222,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 5,
    "id": "772cec0d-7a0c-4c7e-9b7a-4a1864b0a8ec",
    "metadata": {
     "scrolled": true
@@ -324,7 +324,7 @@
     "llama_build_graph: non-view tensors processed: 676/676\n",
     "llama_new_context_with_model: compute buffer total size = 159.19 MiB\n",
     "ggml_backend_metal_buffer_type_alloc_buffer: allocated buffer, size = 156.02 MiB, ( 4332.22 / 21845.34)\n",
-    "\u001b[32m2024-01-13 14:48:23 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n"
+    "\u001b[32m2024-01-13 16:29:46 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n"
    ]
   }
  ],
@@ -349,7 +349,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 6,
    "id": "a8bd1da4-8ff7-4cd3-a5e3-fd79a938cc67",
    "metadata": {},
    "outputs": [
@@ -359,7 +359,7 @@
      "RouteChoice(name='chitchat', function_call=None, similarity_score=None, trigger=None)"
     ]
    },
-   "execution_count": 5,
+   "execution_count": 6,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -370,7 +370,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 7,
    "id": "c6ccbea2-376b-4b28-9b79-d2e9c71e99f4",
    "metadata": {
     "scrolled": true
@@ -414,7 +414,7 @@
     "ws_30 ::= [ <U+0009><U+000A>] ws \n",
     "ws_31 ::= ws_30 | \n",
     "\n",
-    "\u001b[32m2024-01-13 14:51:39 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n"
+    "\u001b[32m2024-01-13 16:29:47 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n"
    ]
   },
   {
    "data": {
     "text/plain": [
-     "'09:51'"
+     "'11:29'"
     ]
    },
-   "execution_count": 6,
+   "execution_count": 7,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -443,7 +443,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 8,
    "id": "720f976a",
    "metadata": {},
    "outputs": [
@@ -485,7 +485,7 @@
     "ws_30 ::= [ <U+0009><U+000A>] ws \n",
     "ws_31 ::= ws_30 | \n",
     "\n",
-    "\u001b[32m2024-01-13 15:00:56 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n"
+    "\u001b[32m2024-01-13 16:29:50 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n"
    ]
   },
   {
    "data": {
     "text/plain": [
-     "'16:00'"
+     "'17:29'"
     ]
    },
-   "execution_count": 7,
+   "execution_count": 8,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -514,7 +514,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 9,
    "id": "c9d9dbbb",
    "metadata": {},
    "outputs": [
@@ -556,7 +556,7 @@
     "ws_30 ::= [ <U+0009><U+000A>] ws \n",
     "ws_31 ::= ws_30 | \n",
     "\n",
-    "\u001b[32m2024-01-13 15:01:59 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n"
+    "\u001b[32m2024-01-13 16:29:51 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n"
    ]
   },
   {
    "data": {
     "text/plain": [
-     "'22:02'"
+     "'23:29'"
     ]
    },
-   "execution_count": 8,
+   "execution_count": 9,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -627,7 +627,7 @@
     "ws_30 ::= [ <U+0009><U+000A>] ws \n",
     "ws_31 ::= ws_30 | \n",
     "\n",
-    "\u001b[32m2024-01-13 15:02:49 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n"
+    "\u001b[32m2024-01-13 16:29:53 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n"
    ]
   },
   {
    "data": {
     "text/plain": [
-     "'22:03'"
+     "'23:30'"
     ]
    },
    "execution_count": 10,
diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py
index 7c7a6816..4fd2c389 100644
--- a/semantic_router/llms/base.py
+++ b/semantic_router/llms/base.py
@@ -13,6 +13,9 @@ class BaseLLM(BaseModel):
     class Config:
         arbitrary_types_allowed = True
 
+    def __init__(self, name: str, **kwargs):
+        super().__init__(name=name, **kwargs)
+
     def __call__(self, messages: List[Message]) -> Optional[str]:
         raise NotImplementedError("Subclasses must implement this method")
 
diff --git a/semantic_router/llms/llamacpp.py b/semantic_router/llms/llamacpp.py
index c0bc0a2a..e6787953 100644
--- a/semantic_router/llms/llamacpp.py
+++ b/semantic_router/llms/llamacpp.py
@@ -2,17 +2,25 @@ from contextlib import contextmanager
 from pathlib import Path
 from typing import Any, Optional
 
-from llama_cpp import Llama, LlamaGrammar
+from llama_cpp import Llama, LlamaGrammar, CreateChatCompletionResponse
 
 from semantic_router.llms.base import BaseLLM
 from semantic_router.schema import Message
 from semantic_router.utils.logger import logger
 
 
+class LlamaCppBaseLLM(BaseLLM):
+    def __init__(self, name: str, llm: Llama, temperature: float, max_tokens: int):
+        super().__init__(name)
+        self.llm = llm
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+
+
 class LlamaCppLLM(BaseLLM):
     llm: Llama
     temperature: float
-    max_tokens: int
+    max_tokens: Optional[int] = 200
     grammar: Optional[LlamaGrammar] = None
 
     def __init__(
@@ -20,14 +28,22 @@ class LlamaCppLLM(BaseLLM):
         llm: Llama,
         name: str = "llama.cpp",
         temperature: float = 0.2,
-        max_tokens: int = 200,
+        max_tokens: Optional[int] = 200,
+        grammar: Optional[LlamaGrammar] = None,
     ):
         if not llm:
             raise ValueError("`llama_cpp.Llama` llm is required")
-        super().__init__(name=name, llm=llm, temperature=temperature, max_tokens=max_tokens)
+        super().__init__(
+            name=name,
+            llm=llm,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            grammar=grammar,
+        )
         self.llm = llm
         self.temperature = temperature
         self.max_tokens = max_tokens
+        self.grammar = grammar
 
     def __call__(
         self,
@@ -41,7 +57,7 @@
             grammar=self.grammar,
             stream=False,
         )
-
+        assert type(completion) is CreateChatCompletionResponse
         output = completion["choices"][0]["message"]["content"]
 
         if not output:
diff --git a/tests/unit/llms/test_llm_llamacpp.py b/tests/unit/llms/test_llm_llamacpp.py
index db47b154..4bcf2a8e 100644
--- a/tests/unit/llms/test_llm_llamacpp.py
+++ b/tests/unit/llms/test_llm_llamacpp.py
@@ -1,10 +1,9 @@
 import pytest
+from llama_cpp import Llama
 
 from semantic_router.llms import LlamaCppLLM
 from semantic_router.schema import Message
 
-from llama_cpp import Llama
-
 
 @pytest.fixture
 def llamacpp_llm(mocker):
-- 
GitLab