From 866915670186448e112623c3b1570a6817810be6 Mon Sep 17 00:00:00 2001
From: jamescalam <james.briggs@hotmail.com>
Date: Sun, 1 Dec 2024 13:19:04 +0100
Subject: [PATCH] fix: types

---
 semantic_router/llms/base.py                  |  2 ++
 semantic_router/llms/llamacpp.py              |  2 --
 semantic_router/llms/mistral.py               |  2 --
 semantic_router/llms/ollama.py                | 15 +++++----------
 semantic_router/llms/openai.py                |  2 --
 semantic_router/llms/openrouter.py            |  2 --
 semantic_router/llms/zure.py                  | 11 +++++------
 tests/unit/encoders/test_fastembed.py         |  3 +++
 tests/unit/encoders/test_hfendpointencoder.py |  1 +
 tests/unit/encoders/test_huggingface.py       |  4 +++-
 tests/unit/encoders/test_mistral.py           |  1 +
 tests/unit/llms/test_llm_ollama.py            |  3 +--
 12 files changed, 21 insertions(+), 27 deletions(-)

diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py
index 02ea2777..02df2764 100644
--- a/semantic_router/llms/base.py
+++ b/semantic_router/llms/base.py
@@ -9,6 +9,8 @@ from semantic_router.utils.logger import logger
 
 class BaseLLM(BaseModel):
     name: str
+    temperature: Optional[float] = 0.0
+    max_tokens: Optional[int] = None
 
     class Config:
         arbitrary_types_allowed = True
diff --git a/semantic_router/llms/llamacpp.py b/semantic_router/llms/llamacpp.py
index 8431be14..8257593c 100644
--- a/semantic_router/llms/llamacpp.py
+++ b/semantic_router/llms/llamacpp.py
@@ -11,8 +11,6 @@ from semantic_router.utils.logger import logger
 
 class LlamaCppLLM(BaseLLM):
     llm: Any
-    temperature: float
-    max_tokens: Optional[int] = 200
     grammar: Optional[Any] = None
     _llama_cpp: Any = PrivateAttr()
 
diff --git a/semantic_router/llms/mistral.py b/semantic_router/llms/mistral.py
index 732cd7b1..370fba65 100644
--- a/semantic_router/llms/mistral.py
+++ b/semantic_router/llms/mistral.py
@@ -11,8 +11,6 @@ from semantic_router.utils.logger import logger
 
 class MistralAILLM(BaseLLM):
     _client: Any = PrivateAttr()
-    temperature: Optional[float]
-    max_tokens: Optional[int]
     _mistralai: Any = PrivateAttr()
 
     def __init__(
diff --git a/semantic_router/llms/ollama.py b/semantic_router/llms/ollama.py
index df35ac06..f6e9779e 100644
--- a/semantic_router/llms/ollama.py
+++ b/semantic_router/llms/ollama.py
@@ -8,22 +8,17 @@ from semantic_router.utils.logger import logger
 
 
 class OllamaLLM(BaseLLM):
-    temperature: Optional[float]
-    llm_name: Optional[str]
-    max_tokens: Optional[int]
-    stream: Optional[bool]
+    stream: bool = False
 
     def __init__(
         self,
-        name: str = "ollama",
+        name: str = "openhermes",
         temperature: float = 0.2,
-        llm_name: str = "openhermes",
         max_tokens: Optional[int] = 200,
         stream: bool = False,
     ):
         super().__init__(name=name)
         self.temperature = temperature
-        self.llm_name = llm_name
         self.max_tokens = max_tokens
         self.stream = stream
 
@@ -31,19 +26,19 @@ class OllamaLLM(BaseLLM):
         self,
         messages: List[Message],
         temperature: Optional[float] = None,
-        llm_name: Optional[str] = None,
+        name: Optional[str] = None,
         max_tokens: Optional[int] = None,
         stream: Optional[bool] = None,
     ) -> str:
         # Use instance defaults if not overridden
         temperature = temperature if temperature is not None else self.temperature
-        llm_name = llm_name if llm_name is not None else self.llm_name
+        name = name if name is not None else self.name
         max_tokens = max_tokens if max_tokens is not None else self.max_tokens
         stream = stream if stream is not None else self.stream
 
         try:
             payload = {
-                "model": llm_name,
+                "model": name,
                 "messages": [m.to_openai() for m in messages],
                 "options": {"temperature": temperature, "num_predict": max_tokens},
                 "format": "json",
diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py
index dfff8096..3465eec9 100644
--- a/semantic_router/llms/openai.py
+++ b/semantic_router/llms/openai.py
@@ -23,8 +23,6 @@ from openai.types.chat.chat_completion_message_tool_call import (
 class OpenAILLM(BaseLLM):
     client: Optional[openai.OpenAI]
     async_client: Optional[openai.AsyncOpenAI]
-    temperature: Optional[float]
-    max_tokens: Optional[int]
 
     def __init__(
         self,
diff --git a/semantic_router/llms/openrouter.py b/semantic_router/llms/openrouter.py
index b00d68a4..608834dd 100644
--- a/semantic_router/llms/openrouter.py
+++ b/semantic_router/llms/openrouter.py
@@ -11,8 +11,6 @@ from semantic_router.utils.logger import logger
 class OpenRouterLLM(BaseLLM):
     client: Optional[openai.OpenAI]
     base_url: Optional[str]
-    temperature: Optional[float]
-    max_tokens: Optional[int]
 
     def __init__(
         self,
diff --git a/semantic_router/llms/zure.py b/semantic_router/llms/zure.py
index 26b7901f..ba833d04 100644
--- a/semantic_router/llms/zure.py
+++ b/semantic_router/llms/zure.py
@@ -1,5 +1,6 @@
 import os
 from typing import List, Optional
+from pydantic import PrivateAttr
 
 import openai
 
@@ -10,9 +11,7 @@ from semantic_router.utils.logger import logger
 
 
 class AzureOpenAILLM(BaseLLM):
-    client: Optional[openai.AzureOpenAI]
-    temperature: Optional[float]
-    max_tokens: Optional[int]
+    _client: Optional[openai.AzureOpenAI] = PrivateAttr(default=None)
 
     def __init__(
         self,
@@ -33,7 +32,7 @@ class AzureOpenAILLM(BaseLLM):
         if azure_endpoint is None:
             raise ValueError("Azure endpoint API key cannot be 'None'.")
         try:
-            self.client = openai.AzureOpenAI(
+            self._client = openai.AzureOpenAI(
                 api_key=api_key, azure_endpoint=azure_endpoint, api_version=api_version
             )
         except Exception as e:
@@ -42,10 +41,10 @@ class AzureOpenAILLM(BaseLLM):
         self.max_tokens = max_tokens
 
     def __call__(self, messages: List[Message]) -> str:
-        if self.client is None:
+        if self._client is None:
             raise ValueError("AzureOpenAI client is not initialized.")
         try:
-            completion = self.client.chat.completions.create(
+            completion = self._client.chat.completions.create(
                 model=self.name,
                 messages=[m.to_openai() for m in messages],
                 temperature=self.temperature,
diff --git a/tests/unit/encoders/test_fastembed.py b/tests/unit/encoders/test_fastembed.py
index 9b0f3229..35c05111 100644
--- a/tests/unit/encoders/test_fastembed.py
+++ b/tests/unit/encoders/test_fastembed.py
@@ -1,5 +1,8 @@
 from semantic_router.encoders import FastEmbedEncoder
 
+import pytest
+
+_ = pytest.importorskip("fastembed")
 
 class TestFastEmbedEncoder:
     def test_fastembed_encoder(self):
diff --git a/tests/unit/encoders/test_hfendpointencoder.py b/tests/unit/encoders/test_hfendpointencoder.py
index cb8dd16a..840df2a9 100644
--- a/tests/unit/encoders/test_hfendpointencoder.py
+++ b/tests/unit/encoders/test_hfendpointencoder.py
@@ -1,4 +1,5 @@
 import pytest
+
 from semantic_router.encoders.huggingface import HFEndpointEncoder
 
 
diff --git a/tests/unit/encoders/test_huggingface.py b/tests/unit/encoders/test_huggingface.py
index b615a87d..7f496e39 100644
--- a/tests/unit/encoders/test_huggingface.py
+++ b/tests/unit/encoders/test_huggingface.py
@@ -4,7 +4,9 @@ import os
 import numpy as np
 import pytest
 
-from semantic_router.encoders.huggingface import HuggingFaceEncoder
+_ = pytest.importorskip("transformers")
+
+from semantic_router.encoders.huggingface import HuggingFaceEncoder  # noqa: E402
 
 test_model_name = "aurelio-ai/sr-test-huggingface"
 
diff --git a/tests/unit/encoders/test_mistral.py b/tests/unit/encoders/test_mistral.py
index 25dba6b7..a236ad16 100644
--- a/tests/unit/encoders/test_mistral.py
+++ b/tests/unit/encoders/test_mistral.py
@@ -1,6 +1,7 @@
 from unittest.mock import patch
 
 import pytest
+
 from mistralai.exceptions import MistralException
 from mistralai.models.embeddings import EmbeddingObject, EmbeddingResponse, UsageInfo
 
diff --git a/tests/unit/llms/test_llm_ollama.py b/tests/unit/llms/test_llm_ollama.py
index 369e5f44..91af8a3c 100644
--- a/tests/unit/llms/test_llm_ollama.py
+++ b/tests/unit/llms/test_llm_ollama.py
@@ -11,9 +11,8 @@ def ollama_llm():
 
 class TestOllamaLLM:
     def test_ollama_llm_init_success(self, ollama_llm):
-        assert ollama_llm.name == "ollama"
         assert ollama_llm.temperature == 0.2
-        assert ollama_llm.llm_name == "openhermes"
+        assert ollama_llm.name == "openhermes"
         assert ollama_llm.max_tokens == 200
         assert ollama_llm.stream is False
-- 
GitLab
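For reviewers, a minimal usage sketch of the refactored OllamaLLM (not part of the patch; it assumes the package as modified above is importable, that Message lives in semantic_router.schema, and that an Ollama server is running locally on its default port):

    # Sketch only. After this patch, `name` doubles as the Ollama model name
    # (the old `llm_name` field is gone), and `temperature` / `max_tokens`
    # are inherited from BaseLLM instead of being redeclared per subclass.
    from semantic_router.llms.ollama import OllamaLLM
    from semantic_router.schema import Message  # assumed location of Message

    llm = OllamaLLM()  # defaults: name="openhermes", temperature=0.2, max_tokens=200
    assert llm.name == "openhermes"

    # Per-call overrides now use `name` rather than the removed `llm_name`;
    # this issues a real request to the local Ollama server.
    reply = llm([Message(role="user", content="Hello")], temperature=0.1)
    print(reply)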