Unverified commit 9ae36d70, authored by Grigoriy, committed by GitHub

add base_url and request_timeout to the Ollama multi-modal LLM (#11526)

parent 5a7b2636
@@ -1,3 +1,5 @@
-from typing import Any, Dict, Sequence, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple
+
+from ollama import Client
 from llama_index.core.base.llms.types import (
     ChatMessage,
@@ -10,7 +12,7 @@ from llama_index.core.base.llms.types import (
     CompletionResponseGen,
     MessageRole,
 )
-from llama_index.core.bridge.pydantic import Field
+from llama_index.core.bridge.pydantic import Field, PrivateAttr
 from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
 from llama_index.core.multi_modal_llms import (
     MultiModalLLM,
@@ -48,6 +50,10 @@ def _messages_to_dicts(messages: Sequence[ChatMessage]) -> Sequence[Dict[str, Any]]:
 
 class OllamaMultiModal(MultiModalLLM):
+    base_url: str = Field(
+        default="http://localhost:11434",
+        description="Base url the model is hosted under.",
+    )
     model: str = Field(description="The MultiModal Ollama model to use.")
     temperature: float = Field(
         default=0.75,
@@ -60,21 +66,19 @@
         description="The maximum number of context tokens for the model.",
         gt=0,
     )
+    request_timeout: Optional[float] = Field(
+        description="The timeout for making http request to Ollama API server",
+    )
     additional_kwargs: Dict[str, Any] = Field(
         default_factory=dict,
         description="Additional model parameters for the Ollama API.",
     )
+    _client: Client = PrivateAttr()
 
     def __init__(self, **kwargs: Any) -> None:
-        """Init params."""
-        # make sure that ollama is installed
-        try:
-            import ollama  # noqa: F401
-        except ImportError:
-            raise ImportError(
-                "Ollama is not installed. Please install it using `pip install ollama`."
-            )
+        """Init params and ollama client."""
         super().__init__(**kwargs)
+        self._client = Client(host=self.base_url, timeout=self.request_timeout)
 
     @classmethod
     def class_name(cls) -> str:
@@ -103,10 +107,8 @@
 
     def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
         """Chat."""
-        import ollama
-
         ollama_messages = _messages_to_dicts(messages)
-        response = ollama.chat(
+        response = self._client.chat(
             model=self.model, messages=ollama_messages, stream=False, **kwargs
         )
         return ChatResponse(
@@ -123,10 +125,8 @@
         self, messages: Sequence[ChatMessage], **kwargs: Any
     ) -> ChatResponseGen:
         """Stream chat."""
-        import ollama
-
         ollama_messages = _messages_to_dicts(messages)
-        response = ollama.chat(
+        response = self._client.chat(
             model=self.model, messages=ollama_messages, stream=True, **kwargs
         )
         text = ""
@@ -157,9 +157,7 @@
         **kwargs: Any,
     ) -> CompletionResponse:
         """Complete."""
-        import ollama
-
-        response = ollama.generate(
+        response = self._client.generate(
             model=self.model,
             prompt=prompt,
             images=image_documents_to_base64(image_documents),
@@ -181,9 +179,7 @@
         **kwargs: Any,
    ) -> CompletionResponseGen:
         """Stream complete."""
-        import ollama
-
-        response = ollama.generate(
+        response = self._client.generate(
             model=self.model,
             prompt=prompt,
             images=image_documents_to_base64(image_documents),
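
For reference, a minimal usage sketch of the two fields this commit adds. This is an illustration under assumptions, not part of the change: the model name ("llava"), the timeout value, and the image path are placeholders.

from llama_index.core.schema import ImageDocument
from llama_index.multi_modal_llms.ollama import OllamaMultiModal

# base_url and request_timeout are the new fields; __init__ now builds an
# ollama.Client(host=base_url, timeout=request_timeout) that all
# chat/generate calls are routed through.
mm_llm = OllamaMultiModal(
    model="llava",                      # placeholder model name
    base_url="http://localhost:11434",  # point this at a remote Ollama server if needed
    request_timeout=120.0,              # seconds, forwarded to the HTTP client
)

response = mm_llm.complete(
    prompt="Describe this image.",
    image_documents=[ImageDocument(image_path="photo.jpg")],  # placeholder path
)
print(response.text)

Note that request_timeout is Optional with no default, so leaving it unset passes timeout=None through to the client.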