diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py
index 02ea277765e7c9bcfe031e6c743f8bce3ffb0f9e..02df2764502d2d2bdd70135406c5e27c5a900fda 100644
--- a/semantic_router/llms/base.py
+++ b/semantic_router/llms/base.py
@@ -9,6 +9,8 @@ from semantic_router.utils.logger import logger
 
 class BaseLLM(BaseModel):
     name: str
+    temperature: Optional[float] = 0.0
+    max_tokens: Optional[int] = None
 
     class Config:
         arbitrary_types_allowed = True
diff --git a/semantic_router/llms/llamacpp.py b/semantic_router/llms/llamacpp.py
index 8431be14f93753abef286d66c36904c85ed8873e..8257593c6f250a8bc1ded5f3af0dbc30a8b44141 100644
--- a/semantic_router/llms/llamacpp.py
+++ b/semantic_router/llms/llamacpp.py
@@ -11,8 +11,6 @@ from semantic_router.utils.logger import logger
 
 class LlamaCppLLM(BaseLLM):
     llm: Any
-    temperature: float
-    max_tokens: Optional[int] = 200
     grammar: Optional[Any] = None
     _llama_cpp: Any = PrivateAttr()
 
diff --git a/semantic_router/llms/mistral.py b/semantic_router/llms/mistral.py
index 732cd7b17f9ca09ac4104c70bd7883c3ff8b970a..370fba65756224ded1af6841382f02f63578851d 100644
--- a/semantic_router/llms/mistral.py
+++ b/semantic_router/llms/mistral.py
@@ -11,8 +11,6 @@ from semantic_router.utils.logger import logger
 
 class MistralAILLM(BaseLLM):
     _client: Any = PrivateAttr()
-    temperature: Optional[float]
-    max_tokens: Optional[int]
     _mistralai: Any = PrivateAttr()
 
     def __init__(
diff --git a/semantic_router/llms/ollama.py b/semantic_router/llms/ollama.py
index df35ac06a97748ece6a2da86fad58bb7e4a0a52f..f6e9779ea65197e08f47135af8fec0c2dba8f295 100644
--- a/semantic_router/llms/ollama.py
+++ b/semantic_router/llms/ollama.py
@@ -8,22 +8,17 @@ from semantic_router.utils.logger import logger
 
 
 class OllamaLLM(BaseLLM):
-    temperature: Optional[float]
-    llm_name: Optional[str]
-    max_tokens: Optional[int]
-    stream: Optional[bool]
+    stream: bool = False
 
     def __init__(
         self,
-        name: str = "ollama",
+        name: str = "openhermes",
         temperature: float = 0.2,
-        llm_name: str = "openhermes",
         max_tokens: Optional[int] = 200,
         stream: bool = False,
     ):
         super().__init__(name=name)
         self.temperature = temperature
-        self.llm_name = llm_name
         self.max_tokens = max_tokens
         self.stream = stream
@@ -31,19 +26,19 @@
         self,
         messages: List[Message],
         temperature: Optional[float] = None,
-        llm_name: Optional[str] = None,
+        name: Optional[str] = None,
         max_tokens: Optional[int] = None,
         stream: Optional[bool] = None,
     ) -> str:
         # Use instance defaults if not overridden
         temperature = temperature if temperature is not None else self.temperature
-        llm_name = llm_name if llm_name is not None else self.llm_name
+        name = name if name is not None else self.name
         max_tokens = max_tokens if max_tokens is not None else self.max_tokens
         stream = stream if stream is not None else self.stream
 
         try:
             payload = {
-                "model": llm_name,
+                "model": name,
                 "messages": [m.to_openai() for m in messages],
                 "options": {"temperature": temperature, "num_predict": max_tokens},
                 "format": "json",
diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py
index dfff80968a588cc1dbf387fc049219b64718f9ec..3465eec9a32765d1c68ad0f53ccdbb59cef2235b 100644
--- a/semantic_router/llms/openai.py
+++ b/semantic_router/llms/openai.py
@@ -23,8 +23,6 @@ from openai.types.chat.chat_completion_message_tool_call import (
 class OpenAILLM(BaseLLM):
     client: Optional[openai.OpenAI]
     async_client: Optional[openai.AsyncOpenAI]
-    temperature: Optional[float]
-    max_tokens: Optional[int]
 
     def __init__(
         self,
diff --git a/semantic_router/llms/openrouter.py b/semantic_router/llms/openrouter.py
index b00d68a4730c6bf681c0ba4d90a5a79c7febe603..608834dd8679cba506e84665f559df9a22078619 100644
--- a/semantic_router/llms/openrouter.py
+++ b/semantic_router/llms/openrouter.py
@@ -11,8 +11,6 @@ from semantic_router.utils.logger import logger
 class OpenRouterLLM(BaseLLM):
     client: Optional[openai.OpenAI]
     base_url: Optional[str]
-    temperature: Optional[float]
-    max_tokens: Optional[int]
 
     def __init__(
         self,
diff --git a/semantic_router/llms/zure.py b/semantic_router/llms/zure.py
index 26b7901f7b595ed1705a81484598f4389a242d26..ba833d044785b775b0da123db554a79f55dd5be8 100644
--- a/semantic_router/llms/zure.py
+++ b/semantic_router/llms/zure.py
@@ -1,5 +1,6 @@
 import os
 from typing import List, Optional
+from pydantic import PrivateAttr
 
 import openai
 
@@ -10,9 +11,7 @@ from semantic_router.utils.logger import logger
 
 
 class AzureOpenAILLM(BaseLLM):
-    client: Optional[openai.AzureOpenAI]
-    temperature: Optional[float]
-    max_tokens: Optional[int]
+    _client: Optional[openai.AzureOpenAI] = PrivateAttr(default=None)
 
     def __init__(
         self,
@@ -33,7 +32,7 @@
         if azure_endpoint is None:
             raise ValueError("Azure endpoint API key cannot be 'None'.")
         try:
-            self.client = openai.AzureOpenAI(
+            self._client = openai.AzureOpenAI(
                 api_key=api_key, azure_endpoint=azure_endpoint, api_version=api_version
             )
         except Exception as e:
@@ -42,10 +41,10 @@
         self.max_tokens = max_tokens
 
     def __call__(self, messages: List[Message]) -> str:
-        if self.client is None:
+        if self._client is None:
             raise ValueError("AzureOpenAI client is not initialized.")
         try:
-            completion = self.client.chat.completions.create(
+            completion = self._client.chat.completions.create(
                 model=self.name,
                 messages=[m.to_openai() for m in messages],
                 temperature=self.temperature,
diff --git a/tests/unit/encoders/test_fastembed.py b/tests/unit/encoders/test_fastembed.py
index 9b0f32296117db7f9a16ea09284082ce5a24498f..35c05111e9ed37e1982fce3a0ef090169a6d4351 100644
--- a/tests/unit/encoders/test_fastembed.py
+++ b/tests/unit/encoders/test_fastembed.py
@@ -1,5 +1,8 @@
 from semantic_router.encoders import FastEmbedEncoder
 
+import pytest
+
+_ = pytest.importorskip("fastembed")
 
 class TestFastEmbedEncoder:
     def test_fastembed_encoder(self):
diff --git a/tests/unit/encoders/test_hfendpointencoder.py b/tests/unit/encoders/test_hfendpointencoder.py
index cb8dd16a661c4194a76935d51390ed4f5b3d3fcc..840df2a9fd0f27ce2cd0dab4ca46b15c994cd5ce 100644
--- a/tests/unit/encoders/test_hfendpointencoder.py
+++ b/tests/unit/encoders/test_hfendpointencoder.py
@@ -1,4 +1,5 @@
 import pytest
+
 from semantic_router.encoders.huggingface import HFEndpointEncoder
 
 
diff --git a/tests/unit/encoders/test_huggingface.py b/tests/unit/encoders/test_huggingface.py
index b615a87d13d64d0a45f6542f137d31ef9fc9a05f..7f496e39f8117dbe4242070e59517acf751ca9a0 100644
--- a/tests/unit/encoders/test_huggingface.py
+++ b/tests/unit/encoders/test_huggingface.py
@@ -4,7 +4,9 @@ import os
 import numpy as np
 import pytest
 
-from semantic_router.encoders.huggingface import HuggingFaceEncoder
+_ = pytest.importorskip("transformers")
+
+from semantic_router.encoders.huggingface import HuggingFaceEncoder  # noqa: E402
 
 
 test_model_name = "aurelio-ai/sr-test-huggingface"
diff --git a/tests/unit/encoders/test_mistral.py b/tests/unit/encoders/test_mistral.py
index 25dba6b759bc2a033c64ae0055e0b1b45d626f48..a236ad167139b5d5994ed8ad2d3610e4b49fc349 100644
--- a/tests/unit/encoders/test_mistral.py
+++ b/tests/unit/encoders/test_mistral.py
@@ -1,6 +1,7 @@
 from unittest.mock import patch
 
 import pytest
+
 from mistralai.exceptions import MistralException
 from mistralai.models.embeddings import EmbeddingObject, EmbeddingResponse, UsageInfo
 
diff --git a/tests/unit/llms/test_llm_ollama.py b/tests/unit/llms/test_llm_ollama.py
index 369e5f4444789c3bb70988455bb49c7df4792445..91af8a3c7b3a565ade17c5ecb529b080ecebf978 100644
--- a/tests/unit/llms/test_llm_ollama.py
+++ b/tests/unit/llms/test_llm_ollama.py
@@ -11,9 +11,8 @@ def ollama_llm():
 
 
 class TestOllamaLLM:
     def test_ollama_llm_init_success(self, ollama_llm):
-        assert ollama_llm.name == "ollama"
         assert ollama_llm.temperature == 0.2
-        assert ollama_llm.llm_name == "openhermes"
+        assert ollama_llm.name == "openhermes"
         assert ollama_llm.max_tokens == 200
         assert ollama_llm.stream is False
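
Reviewer note: the hunks above hoist `temperature` and `max_tokens` out of the individual LLM classes and onto `BaseLLM`, where they gain defaults (`0.0` and `None`), so `LlamaCppLLM`, `MistralAILLM`, `OpenAILLM`, `OpenRouterLLM`, and `AzureOpenAILLM` no longer redeclare them. A minimal sketch of the resulting inheritance pattern, using a hypothetical `ToyLLM` rather than any class from this repo:

```python
from typing import Optional

from pydantic import BaseModel


class BaseLLM(BaseModel):
    # Shared fields are declared once on the base model.
    name: str
    temperature: Optional[float] = 0.0
    max_tokens: Optional[int] = None

    class Config:
        arbitrary_types_allowed = True


class ToyLLM(BaseLLM):
    # Subclasses now only declare what is unique to them.
    stream: bool = False


llm = ToyLLM(name="openhermes", temperature=0.2, max_tokens=200)
assert llm.temperature == 0.2  # inherited from BaseLLM
assert llm.max_tokens == 200
```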
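The `OllamaLLM` change collapses the separate `llm_name` field into the inherited `name`, which is what actually gets sent as `model` in the request payload; the default `name` therefore moves from `"ollama"` to `"openhermes"`. A hedged usage sketch (assumes a local Ollama server is running and that `Message` lives at the package's usual `semantic_router.schema` path; untested):

```python
from semantic_router.llms.ollama import OllamaLLM
from semantic_router.schema import Message

# `name` now doubles as the Ollama model identifier (formerly `llm_name`).
llm = OllamaLLM(name="openhermes", temperature=0.2, max_tokens=200)

# Per-call overrides use `name` as well; the old `llm_name` kwarg is gone.
out = llm([Message(role="user", content="Hello!")], name="llama2")
print(out)
```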
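In `zure.py`, the public `client` field becomes a pydantic `PrivateAttr`. Private attributes are skipped by validation and serialization, which suits a raw SDK client handle; the trade-off is that `_client` disappears from the model's public schema. A self-contained sketch of the pattern, with a stand-in object in place of the real Azure client:

```python
from typing import Any, Optional

from pydantic import BaseModel, PrivateAttr


class Wrapper(BaseModel):
    name: str
    # Not a model field: excluded from validation and from .dict()/.json().
    _client: Optional[Any] = PrivateAttr(default=None)

    def __init__(self, name: str):
        super().__init__(name=name)
        self._client = object()  # stand-in for openai.AzureOpenAI(...)


w = Wrapper(name="azure")
assert "_client" not in w.dict()  # the client never leaks into dumps
assert w._client is not None
```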
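The test changes all follow one pattern: `pytest.importorskip` guards modules whose imports require an optional extra (`fastembed`, `transformers`), so the suite skips cleanly instead of erroring at collection when the extra is not installed. Illustrated below with a hypothetical test module; the `# noqa: E402` mirrors the repo's tests, since the guarded import follows executable code:

```python
import pytest

# Skip the entire module at collection time if the extra is missing;
# without this guard the import below would raise ModuleNotFoundError.
_ = pytest.importorskip("fastembed")

from semantic_router.encoders import FastEmbedEncoder  # noqa: E402


def test_fastembed_import_resolved():
    # Hypothetical smoke test: the guarded import succeeded.
    assert FastEmbedEncoder is not None
```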