diff --git a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-ollama/llama_index/multi_modal_llms/ollama/base.py b/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-ollama/llama_index/multi_modal_llms/ollama/base.py
index a1cd409edc75aaa7cd2dccd397723ac88bdd1ff4..c4e44dca3e769de7dd6d71e290343082b12bf808 100644
--- a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-ollama/llama_index/multi_modal_llms/ollama/base.py
+++ b/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-ollama/llama_index/multi_modal_llms/ollama/base.py
@@ -1,4 +1,6 @@
-from typing import Any, Dict, Sequence, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple
+
+from ollama import Client
 
 from llama_index.core.base.llms.types import (
     ChatMessage,
@@ -10,7 +12,7 @@ from llama_index.core.base.llms.types import (
     CompletionResponseGen,
     MessageRole,
 )
-from llama_index.core.bridge.pydantic import Field
+from llama_index.core.bridge.pydantic import Field, PrivateAttr
 from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
 from llama_index.core.multi_modal_llms import (
     MultiModalLLM,
@@ -48,6 +50,10 @@ def _messages_to_dicts(messages: Sequence[ChatMessage]) -> Sequence[Dict[str, Any]]:
 
 
 class OllamaMultiModal(MultiModalLLM):
+    base_url: str = Field(
+        default="http://localhost:11434",
+        description="Base url the model is hosted under.",
+    )
     model: str = Field(description="The MultiModal Ollama model to use.")
     temperature: float = Field(
         default=0.75,
@@ -60,21 +66,19 @@ class OllamaMultiModal(MultiModalLLM):
         description="The maximum number of context tokens for the model.",
         gt=0,
     )
+    request_timeout: Optional[float] = Field(
+        description="The timeout for making http request to Ollama API server",
+    )
     additional_kwargs: Dict[str, Any] = Field(
         default_factory=dict,
         description="Additional model parameters for the Ollama API.",
     )
+    _client: Client = PrivateAttr()
 
     def __init__(self, **kwargs: Any) -> None:
-        """Init params."""
-        # make sure that ollama is installed
-        try:
-            import ollama  # noqa: F401
-        except ImportError:
-            raise ImportError(
-                "Ollama is not installed. Please install it using `pip install ollama`."
-            )
+        """Init params and ollama client."""
         super().__init__(**kwargs)
+        self._client = Client(host=self.base_url, timeout=self.request_timeout)
 
     @classmethod
     def class_name(cls) -> str:
@@ -103,10 +107,8 @@ class OllamaMultiModal(MultiModalLLM):
 
     def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
         """Chat."""
-        import ollama
-
         ollama_messages = _messages_to_dicts(messages)
-        response = ollama.chat(
+        response = self._client.chat(
             model=self.model, messages=ollama_messages, stream=False, **kwargs
         )
         return ChatResponse(
@@ -123,10 +125,8 @@ class OllamaMultiModal(MultiModalLLM):
         self, messages: Sequence[ChatMessage], **kwargs: Any
     ) -> ChatResponseGen:
         """Stream chat."""
-        import ollama
-
         ollama_messages = _messages_to_dicts(messages)
-        response = ollama.chat(
+        response = self._client.chat(
             model=self.model, messages=ollama_messages, stream=True, **kwargs
         )
         text = ""
@@ -157,9 +157,7 @@ class OllamaMultiModal(MultiModalLLM):
         **kwargs: Any,
     ) -> CompletionResponse:
         """Complete."""
-        import ollama
-
-        response = ollama.generate(
+        response = self._client.generate(
             model=self.model,
             prompt=prompt,
             images=image_documents_to_base64(image_documents),
@@ -181,9 +179,7 @@ class OllamaMultiModal(MultiModalLLM):
         **kwargs: Any,
     ) -> CompletionResponseGen:
         """Stream complete."""
-        import ollama
-
-        response = ollama.generate(
+        response = self._client.generate(
             model=self.model,
             prompt=prompt,
             images=image_documents_to_base64(image_documents),