diff --git a/semantic_router/llms/ollama.py b/semantic_router/llms/ollama.py
index 5b9dfc7a490aaf1bc434104247696b59300b5518..8aa0045bb7ce6cb9a11c286eb6a7ffe4e72f8bc5 100644
--- a/semantic_router/llms/ollama.py
+++ b/semantic_router/llms/ollama.py
@@ -9,22 +9,50 @@ from semantic_router.utils.logger import logger
 
 
 class OllamaLLM(BaseLLM):
-    max_tokens: Optional[int] = 200
+    temperature: Optional[float]
+    llm_name: Optional[str]
+    max_tokens: Optional[int]
+    stream: Optional[bool]
+
+    def __init__(
+        self,
+        name: str = "ollama",
+        temperature: float = 0.2,
+        llm_name: str = "openhermes",
+        max_tokens: Optional[int] = 200,
+        stream: bool = False,
+    ):
+        super().__init__(name=name)
+        self.temperature = temperature
+        self.llm_name = llm_name
+        self.max_tokens = max_tokens
+        self.stream = stream
 
-    def __call__(self, messages: List[Message]) -> str:
-        try:
+    def __call__(
+        self,
+        messages: List[Message],
+        temperature: Optional[float] = None,
+        llm_name: Optional[str] = None,
+        max_tokens: Optional[int] = None,
+        stream: Optional[bool] = None,
+    ) -> str:
+        # Use instance defaults if not overridden
+        temperature = temperature if temperature is not None else self.temperature
+        llm_name = llm_name if llm_name is not None else self.llm_name
+        max_tokens = max_tokens if max_tokens is not None else self.max_tokens
+        stream = stream if stream is not None else self.stream
+        try:
             payload = {
-                "model": self.name,
+                "model": llm_name,
                 "messages": [m.to_openai() for m in messages],
-                "options":{
-                    "temperature":0.0,
-                    "num_predict":self.max_tokens
+                "options": {
+                    "temperature": temperature,
+                    "num_predict": max_tokens
                 },
-                "format":"json",
-                "stream":False
+                "format": "json",
+                "stream": stream
             }
             response = requests.post("http://localhost:11434/api/chat", json=payload)
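
A minimal usage sketch of the revised class (illustrative, not part of the patch): it assumes a local Ollama server listening on http://localhost:11434 with the "openhermes" model pulled, and that `Message` lives in `semantic_router.schema` as the file's imports suggest.

```python
# Usage sketch (hypothetical): exercises the new constructor defaults
# and the per-call overrides this patch introduces.
from semantic_router.llms.ollama import OllamaLLM
from semantic_router.schema import Message

# Instance defaults: temperature=0.2, llm_name="openhermes", max_tokens=200
llm = OllamaLLM()

messages = [
    Message(role="system", content="You are a helpful assistant."),
    Message(role="user", content="What does semantic routing do?"),
]

# Uses the instance defaults set in __init__
out = llm(messages)

# Per-call overrides take precedence over the instance defaults
out = llm(messages, llm_name="mistral", max_tokens=512)
print(out)
```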