From 866915670186448e112623c3b1570a6817810be6 Mon Sep 17 00:00:00 2001
From: jamescalam <james.briggs@hotmail.com>
Date: Sun, 1 Dec 2024 13:19:04 +0100
Subject: [PATCH] fix: types

Consolidate the shared `temperature` and `max_tokens` fields on `BaseLLM`
and drop the duplicate declarations from the LlamaCpp, Mistral, Ollama,
OpenAI, OpenRouter, and Azure OpenAI subclasses. Rename `OllamaLLM.llm_name`
to the inherited `name` field, store the Azure OpenAI client as a pydantic
`PrivateAttr`, and skip encoder tests whose optional dependencies
(`fastembed`, `transformers`) are not installed.

---
 semantic_router/llms/base.py                  |  2 ++
 semantic_router/llms/llamacpp.py              |  2 --
 semantic_router/llms/mistral.py               |  2 --
 semantic_router/llms/ollama.py                | 15 +++++----------
 semantic_router/llms/openai.py                |  2 --
 semantic_router/llms/openrouter.py            |  2 --
 semantic_router/llms/zure.py                  | 11 +++++------
 tests/unit/encoders/test_fastembed.py         |  3 +++
 tests/unit/encoders/test_hfendpointencoder.py |  1 +
 tests/unit/encoders/test_huggingface.py       |  4 +++-
 tests/unit/encoders/test_mistral.py           |  1 +
 tests/unit/llms/test_llm_ollama.py            |  3 +--
 12 files changed, 21 insertions(+), 27 deletions(-)

diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py
index 02ea2777..02df2764 100644
--- a/semantic_router/llms/base.py
+++ b/semantic_router/llms/base.py
@@ -9,6 +9,8 @@ from semantic_router.utils.logger import logger
 
 class BaseLLM(BaseModel):
     name: str
+    temperature: Optional[float] = 0.0
+    max_tokens: Optional[int] = None
 
     class Config:
         arbitrary_types_allowed = True
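
With `temperature` and `max_tokens` declared once on `BaseLLM`, the
subclasses below can drop their duplicate field declarations and still
expose both settings. A minimal sketch of the inheritance pattern, assuming
pydantic is installed (`DemoLLM` is an illustrative subclass, not part of
this patch):

    from typing import Optional

    from pydantic import BaseModel


    class BaseLLM(BaseModel):
        name: str
        temperature: Optional[float] = 0.0
        max_tokens: Optional[int] = None


    class DemoLLM(BaseLLM):
        # The shared fields are inherited; only subclass-specific ones remain.
        stream: bool = False


    llm = DemoLLM(name="demo", temperature=0.5)
    print(llm.temperature, llm.max_tokens, llm.stream)  # 0.5 None False
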
diff --git a/semantic_router/llms/llamacpp.py b/semantic_router/llms/llamacpp.py
index 8431be14..8257593c 100644
--- a/semantic_router/llms/llamacpp.py
+++ b/semantic_router/llms/llamacpp.py
@@ -11,8 +11,6 @@ from semantic_router.utils.logger import logger
 
 class LlamaCppLLM(BaseLLM):
     llm: Any
-    temperature: float
-    max_tokens: Optional[int] = 200
     grammar: Optional[Any] = None
     _llama_cpp: Any = PrivateAttr()
 
diff --git a/semantic_router/llms/mistral.py b/semantic_router/llms/mistral.py
index 732cd7b1..370fba65 100644
--- a/semantic_router/llms/mistral.py
+++ b/semantic_router/llms/mistral.py
@@ -11,8 +11,6 @@ from semantic_router.utils.logger import logger
 
 class MistralAILLM(BaseLLM):
     _client: Any = PrivateAttr()
-    temperature: Optional[float]
-    max_tokens: Optional[int]
     _mistralai: Any = PrivateAttr()
 
     def __init__(
diff --git a/semantic_router/llms/ollama.py b/semantic_router/llms/ollama.py
index df35ac06..f6e9779e 100644
--- a/semantic_router/llms/ollama.py
+++ b/semantic_router/llms/ollama.py
@@ -8,22 +8,17 @@ from semantic_router.utils.logger import logger
 
 
 class OllamaLLM(BaseLLM):
-    temperature: Optional[float]
-    llm_name: Optional[str]
-    max_tokens: Optional[int]
-    stream: Optional[bool]
+    stream: bool = False
 
     def __init__(
         self,
-        name: str = "ollama",
+        name: str = "openhermes",
         temperature: float = 0.2,
-        llm_name: str = "openhermes",
         max_tokens: Optional[int] = 200,
         stream: bool = False,
     ):
         super().__init__(name=name)
         self.temperature = temperature
-        self.llm_name = llm_name
         self.max_tokens = max_tokens
         self.stream = stream
 
@@ -31,19 +26,19 @@ class OllamaLLM(BaseLLM):
         self,
         messages: List[Message],
         temperature: Optional[float] = None,
-        llm_name: Optional[str] = None,
+        name: Optional[str] = None,
         max_tokens: Optional[int] = None,
         stream: Optional[bool] = None,
     ) -> str:
         # Use instance defaults if not overridden
         temperature = temperature if temperature is not None else self.temperature
-        llm_name = llm_name if llm_name is not None else self.llm_name
+        name = name if name is not None else self.name
         max_tokens = max_tokens if max_tokens is not None else self.max_tokens
         stream = stream if stream is not None else self.stream
 
         try:
             payload = {
-                "model": llm_name,
+                "model": name,
                 "messages": [m.to_openai() for m in messages],
                 "options": {"temperature": temperature, "num_predict": max_tokens},
                 "format": "json",
diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py
index dfff8096..3465eec9 100644
--- a/semantic_router/llms/openai.py
+++ b/semantic_router/llms/openai.py
@@ -23,8 +23,6 @@ from openai.types.chat.chat_completion_message_tool_call import (
 class OpenAILLM(BaseLLM):
     client: Optional[openai.OpenAI]
     async_client: Optional[openai.AsyncOpenAI]
-    temperature: Optional[float]
-    max_tokens: Optional[int]
 
     def __init__(
         self,
diff --git a/semantic_router/llms/openrouter.py b/semantic_router/llms/openrouter.py
index b00d68a4..608834dd 100644
--- a/semantic_router/llms/openrouter.py
+++ b/semantic_router/llms/openrouter.py
@@ -11,8 +11,6 @@ from semantic_router.utils.logger import logger
 class OpenRouterLLM(BaseLLM):
     client: Optional[openai.OpenAI]
     base_url: Optional[str]
-    temperature: Optional[float]
-    max_tokens: Optional[int]
 
     def __init__(
         self,
diff --git a/semantic_router/llms/zure.py b/semantic_router/llms/zure.py
index 26b7901f..ba833d04 100644
--- a/semantic_router/llms/zure.py
+++ b/semantic_router/llms/zure.py
@@ -1,5 +1,6 @@
 import os
 from typing import List, Optional
+from pydantic import PrivateAttr
 
 import openai
 
@@ -10,9 +11,7 @@ from semantic_router.utils.logger import logger
 
 
 class AzureOpenAILLM(BaseLLM):
-    client: Optional[openai.AzureOpenAI]
-    temperature: Optional[float]
-    max_tokens: Optional[int]
+    _client: Optional[openai.AzureOpenAI] = PrivateAttr(default=None)
 
     def __init__(
         self,
@@ -33,7 +32,7 @@ class AzureOpenAILLM(BaseLLM):
         if azure_endpoint is None:
             raise ValueError("Azure endpoint API key cannot be 'None'.")
         try:
-            self.client = openai.AzureOpenAI(
+            self._client = openai.AzureOpenAI(
                 api_key=api_key, azure_endpoint=azure_endpoint, api_version=api_version
             )
         except Exception as e:
@@ -42,10 +41,10 @@ class AzureOpenAILLM(BaseLLM):
         self.max_tokens = max_tokens
 
     def __call__(self, messages: List[Message]) -> str:
-        if self.client is None:
+        if self._client is None:
             raise ValueError("AzureOpenAI client is not initialized.")
         try:
-            completion = self.client.chat.completions.create(
+            completion = self._client.chat.completions.create(
                 model=self.name,
                 messages=[m.to_openai() for m in messages],
                 temperature=self.temperature,
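
Declaring the Azure client as a pydantic `PrivateAttr` keeps it out of the
model's public fields, so pydantic neither validates nor serializes the raw
SDK object. A minimal sketch of the pattern with a stand-in client type
(`FakeClient` and `DemoAzureLLM` are illustrative names):

    from typing import Optional

    from pydantic import BaseModel, PrivateAttr


    class FakeClient:
        """Stand-in for openai.AzureOpenAI in this sketch."""


    class DemoAzureLLM(BaseModel):
        name: str
        # Private attrs bypass field validation and stay out of dumps/repr.
        _client: Optional[FakeClient] = PrivateAttr(default=None)

        def __init__(self, name: str):
            super().__init__(name=name)
            self._client = FakeClient()  # assigned after pydantic init


    llm = DemoAzureLLM(name="gpt-demo")
    print(llm)          # name='gpt-demo'; the client is not a public field
    print(llm._client)  # the stand-in client instance
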
diff --git a/tests/unit/encoders/test_fastembed.py b/tests/unit/encoders/test_fastembed.py
index 9b0f3229..35c05111 100644
--- a/tests/unit/encoders/test_fastembed.py
+++ b/tests/unit/encoders/test_fastembed.py
@@ -1,5 +1,8 @@
 from semantic_router.encoders import FastEmbedEncoder
 
+import pytest
+
+_ = pytest.importorskip("fastembed")
 
 class TestFastEmbedEncoder:
     def test_fastembed_encoder(self):
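
`pytest.importorskip` skips every test in the module at collection time when
the optional dependency is missing, instead of failing with an ImportError;
assigning the result to `_` just marks the returned module as intentionally
unused. A self-contained sketch of the idiom (any optional package works in
place of `fastembed`):

    import pytest

    # Skip the whole module if the optional extra is not installed;
    # otherwise importorskip returns the imported module object.
    fastembed = pytest.importorskip("fastembed")


    def test_optional_dependency_importable():
        assert fastembed is not None
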
diff --git a/tests/unit/encoders/test_hfendpointencoder.py b/tests/unit/encoders/test_hfendpointencoder.py
index cb8dd16a..840df2a9 100644
--- a/tests/unit/encoders/test_hfendpointencoder.py
+++ b/tests/unit/encoders/test_hfendpointencoder.py
@@ -1,4 +1,5 @@
 import pytest
+
 from semantic_router.encoders.huggingface import HFEndpointEncoder
 
 
diff --git a/tests/unit/encoders/test_huggingface.py b/tests/unit/encoders/test_huggingface.py
index b615a87d..7f496e39 100644
--- a/tests/unit/encoders/test_huggingface.py
+++ b/tests/unit/encoders/test_huggingface.py
@@ -4,7 +4,9 @@ import os
 import numpy as np
 import pytest
 
-from semantic_router.encoders.huggingface import HuggingFaceEncoder
+_ = pytest.importorskip("transformers")
+
+from semantic_router.encoders.huggingface import HuggingFaceEncoder  # noqa: E402
 
 test_model_name = "aurelio-ai/sr-test-huggingface"
 
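Here the skip guard must run before the import that pulls in `transformers`,
so that import cannot sit at the top of the file; the `# noqa: E402` comment
silences flake8's module-level-import rule for that single line. The
resulting module preamble, shown in isolation (running it requires
semantic_router and pytest):

    import pytest

    # Guard first, heavy import second; E402 is expected and suppressed.
    _ = pytest.importorskip("transformers")

    from semantic_router.encoders.huggingface import HuggingFaceEncoder  # noqa: E402
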
diff --git a/tests/unit/encoders/test_mistral.py b/tests/unit/encoders/test_mistral.py
index 25dba6b7..a236ad16 100644
--- a/tests/unit/encoders/test_mistral.py
+++ b/tests/unit/encoders/test_mistral.py
@@ -1,6 +1,7 @@
 from unittest.mock import patch
 
 import pytest
+
 from mistralai.exceptions import MistralException
 from mistralai.models.embeddings import EmbeddingObject, EmbeddingResponse, UsageInfo
 
diff --git a/tests/unit/llms/test_llm_ollama.py b/tests/unit/llms/test_llm_ollama.py
index 369e5f44..91af8a3c 100644
--- a/tests/unit/llms/test_llm_ollama.py
+++ b/tests/unit/llms/test_llm_ollama.py
@@ -11,9 +11,8 @@ def ollama_llm():
 
 class TestOllamaLLM:
     def test_ollama_llm_init_success(self, ollama_llm):
-        assert ollama_llm.name == "ollama"
         assert ollama_llm.temperature == 0.2
-        assert ollama_llm.llm_name == "openhermes"
+        assert ollama_llm.name == "openhermes"
         assert ollama_llm.max_tokens == 200
         assert ollama_llm.stream is False
 
-- 
GitLab