From d4d29a4522ee4878611695d223169d3e31c22362 Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Wed, 21 Feb 2024 17:04:31 +0400
Subject: [PATCH] OllamaLLM Updates

Ensured that we have default values for the attributes name, temperature, llm_name, max_tokens and stream.

Users can override the values actually used on the fly via new optional arguments to __call__.

Also, _call_ has been renamed to __call__, in line with the other LLMs.

Previously self.name was used to identify the model to call, but self.name is intended to identify the instance as an OllamaLLM. It now defaults to "ollama", and the model is selected via the new llm_name attribute (or the llm_name argument to __call__).
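
Example usage (a rough sketch; the Message import path and its role/content fields are assumptions, and a local Ollama server is expected at localhost:11434):

    from semantic_router.llms.ollama import OllamaLLM
    from semantic_router.schema import Message

    # Instance defaults: llm_name="openhermes", temperature=0.2, max_tokens=200, stream=False
    llm = OllamaLLM()

    messages = [Message(role="user", content="Hello!")]

    # Use the instance defaults.
    out = llm(messages)

    # Override any of them on the fly for a single call.
    out = llm(messages, temperature=0.0, llm_name="llama2", max_tokens=50)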
---
 semantic_router/llms/ollama.py | 46 +++++++++++++++++++++++++++-------
 1 file changed, 37 insertions(+), 9 deletions(-)

diff --git a/semantic_router/llms/ollama.py b/semantic_router/llms/ollama.py
index 5b9dfc7a..8aa0045b 100644
--- a/semantic_router/llms/ollama.py
+++ b/semantic_router/llms/ollama.py
@@ -9,22 +9,50 @@ from semantic_router.utils.logger import logger
 
 
 class OllamaLLM(BaseLLM):
-    max_tokens: Optional[int] = 200
+    temperature: Optional[float]
+    llm_name: Optional[str]
+    max_tokens: Optional[int]
+    stream: Optional[bool]
 
+    def __init__(
+        self,
+        name: str = "ollama",
+        temperature: float = 0.2,
+        llm_name: str = "openhermes",
+        max_tokens: Optional[int] = 200,
+        stream: bool = False,
+    ):
+        super().__init__(name=name)
+        self.temperature = temperature
+        self.llm_name = llm_name
+        self.max_tokens = max_tokens
+        self.stream = stream
 
-    def _call_(self, messages: List[Message]) -> str:
+    def __call__(
+        self,
+        messages: List[Message],
+        temperature: Optional[float] = None,
+        llm_name: Optional[str] = None,
+        max_tokens: Optional[int] = None,
+        stream: Optional[bool] = None,
+    ) -> str:
         
-        try:
+        # Use instance defaults if not overridden
+        temperature = temperature if temperature is not None else self.temperature
+        llm_name = llm_name if llm_name is not None else self.llm_name
+        max_tokens = max_tokens if max_tokens is not None else self.max_tokens
+        stream = stream if stream is not None else self.stream
 
+        try:
             payload = {
-                "model": self.name,
+                "model": llm_name,
                 "messages": [m.to_openai() for m in messages],
-                "options":{
-                    "temperature":0.0,
-                    "num_predict":self.max_tokens
+                "options": {
+                    "temperature": temperature,
+                    "num_predict": max_tokens
                 },
-                "format":"json",
-                "stream":False
+                "format": "json",
+                "stream": stream
             }
 
             response = requests.post("http://localhost:11434/api/chat", json=payload)
-- 
GitLab