diff --git a/docs/examples/multi_modal/ollama_multi_modal.ipynb b/docs/examples/multi_modal/ollama_multi_modal.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..3dff3b46aafa261314a85c9e1e8c9fa497fcb841 --- /dev/null +++ b/docs/examples/multi_modal/ollama_multi_modal.ipynb @@ -0,0 +1,271 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "03e3af8d-a850-49bd-9b0b-707b29ee320e", + "metadata": {}, + "source": [ + "# Multimodal Ollama\n", + "\n", + "<a href=\"https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/examples/multi_modal/ollama_multi_modal.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n", + "\n", + "This notebook shows you how to use our Ollama multimodal integration.\n", + "\n", + "Supports complete, stream_complete, chat, stream_chat methods (async support coming soon).\n", + "\n", + "Use on its own or plug into broader [multi-modal use cases](https://docs.llamaindex.ai/en/stable/use_cases/multimodal.html)" + ] + }, + { + "cell_type": "markdown", + "id": "f49494ce-2162-490d-8d42-e3a39ecc498f", + "metadata": {}, + "source": [ + "#### Define Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0c8d5f34-7131-4470-879f-480c60f55250", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.multi_modal_llms import OllamaMultiModal" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cac7cff1-f4ce-4b1e-b5f1-8b62fc2e4505", + "metadata": {}, + "outputs": [], + "source": [ + "mm_model = OllamaMultiModal(model=\"llava\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d720d38-3611-464a-ae8c-491b2dd2bf04", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2024-02-03 11:41:04-- https://res.cloudinary.com/hello-tickets/image/upload/c_limit,f_auto,q_auto,w_1920/v1640835927/o3pfl41q7m5bj8jardk0.jpg\n", + "Resolving res.cloudinary.com (res.cloudinary.com)... 2606:4700::6813:a641, 2606:4700::6813:a741, 104.19.166.65, ...\n", + "Connecting to res.cloudinary.com (res.cloudinary.com)|2606:4700::6813:a641|:443... connected.\n", + "HTTP request sent, awaiting response... 
200 OK\n", + "Length: 181517 (177K) [image/jpeg]\n", + "Saving to: ‘jerry_images/test.png’\n", + "\n", + "jerry_images/test.p 100%[===================>] 177.26K --.-KB/s in 0.01s \n", + "\n", + "2024-02-03 11:41:04 (14.6 MB/s) - ‘jerry_images/test.png’ saved [181517/181517]\n", + "\n" + ] + } + ], + "source": [ + "!wget \"https://res.cloudinary.com/hello-tickets/image/upload/c_limit,f_auto,q_auto,w_1920/v1640835927/o3pfl41q7m5bj8jardk0.jpg\" -O jerry_images/test.png" + ] + }, + { + "cell_type": "markdown", + "id": "d0c667a2-e0fc-4929-adf9-dfbc0017359b", + "metadata": {}, + "source": [ + "#### Load Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "427443b3-0c26-4e52-ab5d-82e63839428b", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.multi_modal_llms.generic_utils import load_image_urls\n", + "\n", + "image_urls = [\n", + " # \"https://www.visualcapitalist.com/wp-content/uploads/2023/10/US_Mortgage_Rate_Surge-Sept-11-1.jpg\",\n", + " # \"https://www.sportsnet.ca/wp-content/uploads/2023/11/CP1688996471-1040x572.jpg\",\n", + " \"https://res.cloudinary.com/hello-tickets/image/upload/c_limit,f_auto,q_auto,w_1920/v1640835927/o3pfl41q7m5bj8jardk0.jpg\",\n", + " # \"https://www.cleverfiles.com/howto/wp-content/uploads/2018/03/minion.jpg\",\n", + "]\n", + "\n", + "image_documents = load_image_urls(image_urls)" + ] + }, + { + "cell_type": "markdown", + "id": "f47de9bf-1a31-4b12-b388-a23fc5905c22", + "metadata": {}, + "source": [ + "#### Completion (Non-Streaming/Streaming)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96ebffbe-be55-48fc-9cc9-305654c9e20a", + "metadata": {}, + "outputs": [], + "source": [ + "complete_response = mm_model.complete(\n", + " prompt=\"Tell me more about this image\", image_documents=image_documents\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "509de1e2-319d-4f30-81ce-212a9a9025e8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " The image shows the ancient Greek landmark known as the Colosseum, located in Rome, Italy. It is lit up with colorful lights and appears to be illuminated against a night sky. The Colosseum is a distinctive oval structure that has been used for various purposes over the centuries, including gladiatorial contests and public events. The colors of light suggest they could be representing national colors, such as those of Italy (red), white, and green. The photo captures the grandeur and historical significance of this iconic monument. \n" + ] + } + ], + "source": [ + "print(str(complete_response))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "483b3222-11ee-4ab1-9e3c-f223e55aa7c3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " This is an image of the Colosseum, a famous landmark located in Rome, Italy. It's illuminated at night with colorful lights, which gives it a festive appearance. The Colosseum is a significant historical and architectural structure that was used for gladiatorial contests and other public spectacles during the Roman Empire. 
" + ] + } + ], + "source": [ + "response_gen = mm_model.stream_complete(\n", + " prompt=\"Tell me more about this image\",\n", + " image_documents=image_documents,\n", + ")\n", + "for r in response_gen:\n", + " print(r.delta, end=\"\")" + ] + }, + { + "cell_type": "markdown", + "id": "1a9f9c44-659c-4e71-9728-1388e9160c09", + "metadata": {}, + "source": [ + "#### Chat (Non-Streaming/Streaming)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "af3203ce-4a43-44e2-b70b-17df390d9dc1", + "metadata": {}, + "outputs": [], + "source": [ + "# chat\n", + "from llama_index.llms import ChatMessage, MessageRole\n", + "\n", + "image_bytes_io = [d.resolve_image() for d in image_documents]\n", + "\n", + "chat_response = mm_model.chat(\n", + " [\n", + " ChatMessage(\n", + " role=MessageRole.USER,\n", + " content=\"Tell me more about this image\",\n", + " additional_kwargs={\"images\": image_bytes_io},\n", + " )\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1839bc6-3b02-4ff3-a0c4-dad3e7f995a7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "assistant: This is an image of the Colosseum, also known as the Flavian Amphitheatre, which is a renowned landmark located in Rome, Italy. The Colosseum is one of the most famous and well-preserved ancient structures in the world. It was built during the Roman Empire and was used for public entertainment such as gladiatorial contests, reenactments of battles, and dramas based on classical mythology.\n", + "\n", + "The structure has a distinctive elliptical shape with tiered seating for spectators, which is clearly visible in this image. The photo is taken at night, and the Colosseum is illuminated by colorful lights that highlight its arches and the overall outline of the building. The colors of the lights correspond to those of the Italian flag, symbolizing a sense of national pride or celebration.\n", + "\n", + "The architecture and design of the Colosseum are indicative of the engineering prowess of the Romans during their peak period. It has become an iconic symbol of Roman civilization and continues to be a popular tourist attraction in Rome. \n" + ] + } + ], + "source": [ + "print(str(chat_response))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "977f1d6a-bde2-41be-8229-684f21f44014", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " The image shows the Colosseum, a well-known landmark and amphitheater in Rome, Italy. It is illuminated with colorful lights, which appear to be red, green, and yellow, possibly indicating a special event or celebration, such as a national holiday given the colors of the Italian flag (red, white, and green). The Colosseum is captured at night under a clear sky, which adds to its dramatic presentation. 
" + ] + } + ], + "source": [ + "# stream chat\n", + "from llama_index.llms import ChatMessage, MessageRole\n", + "\n", + "\n", + "image_bytes_io = [d.resolve_image() for d in image_documents]\n", + "\n", + "chat_gen = mm_model.stream_chat(\n", + " [\n", + " ChatMessage(\n", + " role=MessageRole.USER,\n", + " content=\"Tell me more about this image\",\n", + " additional_kwargs={\"images\": image_bytes_io},\n", + " )\n", + " ]\n", + ")\n", + "for r in chat_gen:\n", + " print(r.delta, end=\"\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "llama_index_v2", + "language": "python", + "name": "llama_index_v2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/llama_index/multi_modal_llms/__init__.py b/llama_index/multi_modal_llms/__init__.py index c45ab8c230f1c01c88b04aaeb3b3872245615ca8..3ca6b25c8f93c8f222e953c10525d73a2061272f 100644 --- a/llama_index/multi_modal_llms/__init__.py +++ b/llama_index/multi_modal_llms/__init__.py @@ -7,6 +7,7 @@ from llama_index.multi_modal_llms.dashscope import ( DashScopeMultiModalModels, ) from llama_index.multi_modal_llms.gemini import GeminiMultiModal +from llama_index.multi_modal_llms.ollama import OllamaMultiModal from llama_index.multi_modal_llms.openai import OpenAIMultiModal from llama_index.multi_modal_llms.replicate_multi_modal import ReplicateMultiModal @@ -18,4 +19,5 @@ __all__ = [ "GeminiMultiModal", "DashScopeMultiModal", "DashScopeMultiModalModels", + "OllamaMultiModal", ] diff --git a/llama_index/multi_modal_llms/generic_utils.py b/llama_index/multi_modal_llms/generic_utils.py index 69800b2ba03e8387ea8ce5d6525c33e292463549..512a0712c8fad1446c0699b12de866ee1a19594d 100644 --- a/llama_index/multi_modal_llms/generic_utils.py +++ b/llama_index/multi_modal_llms/generic_utils.py @@ -1,8 +1,13 @@ import base64 -from typing import List +import logging +from typing import List, Sequence + +import requests from llama_index.schema import ImageDocument +logger = logging.getLogger(__name__) + def load_image_urls(image_urls: List[str]) -> List[ImageDocument]: # load remote image urls into image documents @@ -17,3 +22,30 @@ def load_image_urls(image_urls: List[str]) -> List[ImageDocument]: def encode_image(image_path: str) -> str: with open(image_path, "rb") as image_file: return base64.b64encode(image_file.read()).decode("utf-8") + + +# Supporting Ollama like Multi-Modal images base64 encoding +def image_documents_to_base64( + image_documents: Sequence[ImageDocument], +) -> List[str]: + image_encodings = [] + # encode image documents to base64 + for image_document in image_documents: + if image_document.image: + image_encodings.append(image_document.image) + elif image_document.image_path: + image_encodings.append(encode_image(image_document.image_path)) + elif ( + "file_path" in image_document.metadata + and image_document.metadata["file_path"] != "" + ): + image_encodings.append(encode_image(image_document.metadata["file_path"])) + elif image_document.image_url: + response = requests.get(image_document.image_url) + try: + image_encodings.append( + base64.b64encode(response.content).decode("utf-8") + ) + except Exception as e: + logger.warning(f"Cannot encode the image url-> {e}") + return image_encodings diff --git a/llama_index/multi_modal_llms/ollama.py b/llama_index/multi_modal_llms/ollama.py 
new file mode 100644 index 0000000000000000000000000000000000000000..5bff2ba72a2d610e75b5dd2d1163a659f3b74e16 --- /dev/null +++ b/llama_index/multi_modal_llms/ollama.py @@ -0,0 +1,219 @@ +from typing import Any, Dict, Sequence, Tuple + +from llama_index.bridge.pydantic import Field +from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS +from llama_index.core.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + MessageRole, +) +from llama_index.multi_modal_llms import ( + MultiModalLLM, + MultiModalLLMMetadata, +) +from llama_index.multi_modal_llms.generic_utils import image_documents_to_base64 +from llama_index.schema import ImageDocument + + +def get_additional_kwargs( + response: Dict[str, Any], exclude: Tuple[str, ...] +) -> Dict[str, Any]: + return {k: v for k, v in response.items() if k not in exclude} + + +def _messages_to_dicts(messages: Sequence[ChatMessage]) -> Sequence[Dict[str, Any]]: + """Convert messages to dicts. + + For use in ollama API + + """ + results = [] + for message in messages: + # TODO: just pass through the image arg for now. + # TODO: have a consistent interface between multimodal models + images = message.additional_kwargs.get("images") + results.append( + { + "role": message.role.value, + "content": message.content, + "images": images, + } + ) + return results + + +class OllamaMultiModal(MultiModalLLM): + model: str = Field(description="The MultiModal Ollama model to use.") + temperature: float = Field( + default=0.75, + description="The temperature to use for sampling.", + gte=0.0, + lte=1.0, + ) + context_window: int = Field( + default=DEFAULT_CONTEXT_WINDOW, + description="The maximum number of context tokens for the model.", + gt=0, + ) + additional_kwargs: Dict[str, Any] = Field( + default_factory=dict, + description="Additional model parameters for the Ollama API.", + ) + + def __init__(self, **kwargs: Any) -> None: + """Init params.""" + # make sure that ollama is installed + try: + import ollama # noqa: F401 + except ImportError: + raise ImportError( + "Ollama is not installed. Please install it using `pip install ollama`." 
+ ) + super().__init__(**kwargs) + + @classmethod + def class_name(cls) -> str: + return "Ollama_multi_modal_llm" + + @property + def metadata(self) -> MultiModalLLMMetadata: + """LLM metadata.""" + return MultiModalLLMMetadata( + context_window=self.context_window, + num_output=DEFAULT_NUM_OUTPUTS, + model_name=self.model, + is_chat_model=True, # Ollama supports chat API for all models + ) + + @property + def _model_kwargs(self) -> Dict[str, Any]: + base_kwargs = { + "temperature": self.temperature, + "num_ctx": self.context_window, + } + return { + **base_kwargs, + **self.additional_kwargs, + } + + def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: + """Chat.""" + import ollama + + ollama_messages = _messages_to_dicts(messages) + response = ollama.chat(model=self.model, messages=ollama_messages, stream=False) + return ChatResponse( + message=ChatMessage( + content=response["message"]["content"], + role=MessageRole(response["message"]["role"]), + additional_kwargs=get_additional_kwargs(response, ("message",)), + ), + raw=response["message"], + additional_kwargs=get_additional_kwargs(response, ("message",)), + ) + + def stream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseGen: + """Stream chat.""" + import ollama + + ollama_messages = _messages_to_dicts(messages) + response = ollama.chat(model=self.model, messages=ollama_messages, stream=True) + text = "" + for chunk in response: + if "done" in chunk and chunk["done"]: + break + message = chunk["message"] + delta = message.get("content") + text += delta + yield ChatResponse( + message=ChatMessage( + content=text, + role=MessageRole(message["role"]), + additional_kwargs=get_additional_kwargs( + message, ("content", "role") + ), + ), + delta=delta, + raw=message, + additional_kwargs=get_additional_kwargs(chunk, ("message",)), + ) + + def complete( + self, + prompt: str, + image_documents: Sequence[ImageDocument], + formatted: bool = False, + **kwargs: Any, + ) -> CompletionResponse: + """Complete.""" + import ollama + + response = ollama.generate( + model=self.model, + prompt=prompt, + images=image_documents_to_base64(image_documents), + stream=False, + options=self._model_kwargs, + ) + return CompletionResponse( + text=response["response"], + raw=response, + additional_kwargs=get_additional_kwargs(response, ("response",)), + ) + + def stream_complete( + self, + prompt: str, + image_documents: Sequence[ImageDocument], + formatted: bool = False, + **kwargs: Any, + ) -> CompletionResponseGen: + """Stream complete.""" + import ollama + + response = ollama.generate( + model=self.model, + prompt=prompt, + images=image_documents_to_base64(image_documents), + stream=True, + options=self._model_kwargs, + ) + text = "" + for chunk in response: + if "done" in chunk and chunk["done"]: + break + delta = chunk.get("response") + text += delta + yield CompletionResponse( + text=str(chunk["response"]), + delta=delta, + raw=chunk, + additional_kwargs=get_additional_kwargs(chunk, ("response",)), + ) + + async def acomplete( + self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any + ) -> CompletionResponse: + raise NotImplementedError("Ollama does not support async completion.") + + async def achat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponse: + raise NotImplementedError("Ollama does not support async chat.") + + async def astream_complete( + self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any + ) -> 
CompletionResponseAsyncGen: + raise NotImplementedError("Ollama does not support async streaming completion.") + + async def astream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseAsyncGen: + raise NotImplementedError("Ollama does not support async streaming chat.")
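
For a quick end-to-end check of the code added in this diff, the sketch below mirrors the notebook flow in plain-script form. It is only a smoke-test sketch, not part of the change itself: it assumes an Ollama server is running locally with the `llava` model already pulled (`ollama pull llava`), and it reuses the same Colosseum image URL as the notebook. All imports, classes, and methods referenced here come from this diff or from the existing `load_image_urls` helper.

```python
# Minimal smoke test for the new OllamaMultiModal integration.
# Assumes a local Ollama server and that `ollama pull llava` has been run.
from llama_index.llms import ChatMessage, MessageRole
from llama_index.multi_modal_llms import OllamaMultiModal
from llama_index.multi_modal_llms.generic_utils import load_image_urls

# Load a remote image into ImageDocument objects (same URL as the notebook).
image_documents = load_image_urls(
    [
        "https://res.cloudinary.com/hello-tickets/image/upload/c_limit,f_auto,q_auto,w_1920/v1640835927/o3pfl41q7m5bj8jardk0.jpg"
    ]
)

mm_model = OllamaMultiModal(model="llava")

# Non-streaming completion: returns a CompletionResponse whose `text`
# field holds the model's answer.
response = mm_model.complete(
    prompt="Describe this image in one sentence.",
    image_documents=image_documents,
)
print(response.text)

# Streaming chat: images are passed via `additional_kwargs`, as in the notebook.
image_bytes_io = [d.resolve_image() for d in image_documents]
for chunk in mm_model.stream_chat(
    [
        ChatMessage(
            role=MessageRole.USER,
            content="What landmark is shown here?",
            additional_kwargs={"images": image_bytes_io},
        )
    ]
):
    print(chunk.delta, end="")
```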