Unverified commit 9ae36d70, authored by Grigoriy, committed by GitHub

add baseurl and request timeout to Ollama MML (#11526)

parent 5a7b2636
-from typing import Any, Dict, Sequence, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple
+
+from ollama import Client
 from llama_index.core.base.llms.types import (
     ChatMessage,
@@ -10,7 +12,7 @@ from llama_index.core.base.llms.types import (
     CompletionResponseGen,
     MessageRole,
 )
-from llama_index.core.bridge.pydantic import Field
+from llama_index.core.bridge.pydantic import Field, PrivateAttr
 from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
 from llama_index.core.multi_modal_llms import (
     MultiModalLLM,
@@ -48,6 +50,10 @@ def _messages_to_dicts(messages: Sequence[ChatMessage]) -> Sequence[Dict[str, Any]]:
 class OllamaMultiModal(MultiModalLLM):
+    base_url: str = Field(
+        default="http://localhost:11434",
+        description="Base url the model is hosted under.",
+    )
     model: str = Field(description="The MultiModal Ollama model to use.")
     temperature: float = Field(
         default=0.75,
@@ -60,21 +66,19 @@ class OllamaMultiModal(MultiModalLLM):
         description="The maximum number of context tokens for the model.",
         gt=0,
     )
+    request_timeout: Optional[float] = Field(
+        description="The timeout for making http request to Ollama API server",
+    )
     additional_kwargs: Dict[str, Any] = Field(
         default_factory=dict,
         description="Additional model parameters for the Ollama API.",
     )
+    _client: Client = PrivateAttr()

     def __init__(self, **kwargs: Any) -> None:
-        """Init params."""
-        # make sure that ollama is installed
-        try:
-            import ollama  # noqa: F401
-        except ImportError:
-            raise ImportError(
-                "Ollama is not installed. Please install it using `pip install ollama`."
-            )
+        """Init params and ollama client."""
         super().__init__(**kwargs)
+        self._client = Client(host=self.base_url, timeout=self.request_timeout)

     @classmethod
     def class_name(cls) -> str:
@@ -103,10 +107,8 @@ class OllamaMultiModal(MultiModalLLM):
     def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
         """Chat."""
-        import ollama
-
         ollama_messages = _messages_to_dicts(messages)
-        response = ollama.chat(
+        response = self._client.chat(
             model=self.model, messages=ollama_messages, stream=False, **kwargs
         )
         return ChatResponse(
@@ -123,10 +125,8 @@ class OllamaMultiModal(MultiModalLLM):
         self, messages: Sequence[ChatMessage], **kwargs: Any
     ) -> ChatResponseGen:
         """Stream chat."""
-        import ollama
-
         ollama_messages = _messages_to_dicts(messages)
-        response = ollama.chat(
+        response = self._client.chat(
             model=self.model, messages=ollama_messages, stream=True, **kwargs
         )
         text = ""
@@ -157,9 +157,7 @@ class OllamaMultiModal(MultiModalLLM):
         **kwargs: Any,
     ) -> CompletionResponse:
         """Complete."""
-        import ollama
-
-        response = ollama.generate(
+        response = self._client.generate(
             model=self.model,
             prompt=prompt,
             images=image_documents_to_base64(image_documents),
@@ -181,9 +179,7 @@ class OllamaMultiModal(MultiModalLLM):
         **kwargs: Any,
    ) -> CompletionResponseGen:
         """Stream complete."""
-        import ollama
-
-        response = ollama.generate(
+        response = self._client.generate(
             model=self.model,
             prompt=prompt,
             images=image_documents_to_base64(image_documents),
...
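For context, a minimal usage sketch of the new fields. The import paths, model name, endpoint, and image path below are illustrative assumptions; only base_url and request_timeout (forwarded to ollama.Client in __init__) come from the diff above.

from llama_index.core.schema import ImageDocument
from llama_index.multi_modal_llms.ollama import OllamaMultiModal  # assumed package path

# Point the client at a non-default Ollama server and cap request time.
# base_url and request_timeout are the fields added in this commit; both are
# handed to ollama.Client(host=..., timeout=...) during __init__.
mm_llm = OllamaMultiModal(
    model="llava",                          # illustrative multimodal model
    base_url="http://ollama-host:11434",    # illustrative remote endpoint
    request_timeout=120.0,                  # seconds before the HTTP call times out
)

# Single completion against a local image (path is illustrative).
response = mm_llm.complete(
    prompt="Describe this image in one sentence.",
    image_documents=[ImageDocument(image_path="./photo.jpg")],
)
print(response.text)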