Unverified Commit f9a057dd authored by Huu Le, committed by GitHub

feat: add support for multimodal indexes (#453)


---------
Co-authored-by: thucpn <thucsh2@gmail.com>
Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
parent aedd73d8
Tags v0.3.23
---
"create-llama": patch
---
Add support for multimodal indexes (e.g. from LlamaCloud)
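Before the diff itself, a hedged usage sketch of the feature being added. `get_index()` and the module paths other than `app.settings` are assumptions about the generated project's layout, not part of this change:

from app.settings import init_settings, get_multi_modal_llm
from app.engine.index import get_index  # assumed loader that returns the (LlamaCloud) index
from app.engine.tools.query_engine import get_query_engine_tool  # assumed module path

init_settings()  # with MODEL set to a vision-capable model, a MultiModalLLM is also registered

index = get_index()
tool = get_query_engine_tool(index)  # builds the query engine via create_query_engine()

if get_multi_modal_llm() is not None:
    # For LlamaCloud indexes, image nodes are retrieved and passed to the multimodal LLM.
    print(tool.query_engine.query("Summarize the figures in the uploaded report"))

The diff below touches two files: the query engine factory (which gains a new MultiModalSynthesizer) and the settings module.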
import os
from typing import Any, Dict, List, Optional, Sequence

from llama_index.core import get_response_synthesizer
from llama_index.core.base.base_query_engine import BaseQueryEngine
from llama_index.core.base.response.schema import RESPONSE_TYPE, Response
from llama_index.core.multi_modal_llms import MultiModalLLM
from llama_index.core.prompts.base import BasePromptTemplate
from llama_index.core.prompts.default_prompt_selectors import (
    DEFAULT_TEXT_QA_PROMPT_SEL,
)
from llama_index.core.query_engine.multi_modal import _get_image_and_text_nodes
from llama_index.core.response_synthesizers.base import BaseSynthesizer, QueryTextType
from llama_index.core.schema import (
    ImageNode,
    NodeWithScore,
)
from llama_index.core.tools.query_engine import QueryEngineTool
from llama_index.core.types import RESPONSE_TEXT_TYPE

from app.settings import get_multi_modal_llm


def create_query_engine(index, **kwargs) -> BaseQueryEngine:
    """
    Create a query engine for the given index.

@@ -12,16 +29,23 @@ def create_query_engine(index, **kwargs):
        index: The index to create a query engine for.
        params (optional): Additional parameters for the query engine, e.g.: similarity_top_k
    """
    top_k = int(os.getenv("TOP_K", 0))
    if top_k != 0 and kwargs.get("filters") is None:
        kwargs["similarity_top_k"] = top_k

    # If a multimodal LLM is configured, use a synthesizer that can route
    # retrieved image nodes to it.
    multimodal_llm = get_multi_modal_llm()
    if multimodal_llm:
        kwargs["response_synthesizer"] = MultiModalSynthesizer(
            multimodal_model=multimodal_llm,
        )

    # If the index is a LlamaCloudIndex, use auto_routed retrieval mode for
    # better query results and retrieve image nodes when a multimodal LLM
    # is available.
    if index.__class__.__name__ == "LlamaCloudIndex":
        if kwargs.get("retrieval_mode") is None:
            kwargs["retrieval_mode"] = "auto_routed"
        if multimodal_llm:
            kwargs["retrieve_image_nodes"] = True

    return index.as_query_engine(**kwargs)
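To make the kwarg handling concrete, a small hypothetical illustration; it assumes TOP_K is unset, a multimodal LLM is configured, and `llama_cloud_index` is an existing LlamaCloudIndex:

# Hypothetical call; `llama_cloud_index` stands for a LlamaCloudIndex instance.
engine = create_query_engine(llama_cloud_index)

# With no explicit kwargs, the call above ends up roughly equivalent to:
#
#   llama_cloud_index.as_query_engine(
#       response_synthesizer=MultiModalSynthesizer(multimodal_model=get_multi_modal_llm()),
#       retrieval_mode="auto_routed",   # filled in only if the caller did not set it
#       retrieve_image_nodes=True,      # only when a multimodal LLM is available
#   )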

@@ -51,3 +75,113 @@ def get_query_engine_tool(
        name=name,
        description=description,
    )


class MultiModalSynthesizer(BaseSynthesizer):
    """
    A synthesizer that summarizes text nodes and uses a multi-modal LLM to generate a response.
    """

    def __init__(
        self,
        multimodal_model: MultiModalLLM,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        text_qa_template: Optional[BasePromptTemplate] = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self._multi_modal_llm = multimodal_model
        self._response_synthesizer = response_synthesizer or get_response_synthesizer()
        self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT_SEL

    def _get_prompts(self, **kwargs) -> Dict[str, Any]:
        return {
            "text_qa_template": self._text_qa_template,
        }

    def _update_prompts(self, prompts: Dict[str, Any]) -> None:
        if "text_qa_template" in prompts:
            self._text_qa_template = prompts["text_qa_template"]

    async def aget_response(
        self,
        *args,
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        # Plain text response generation is delegated to the wrapped synthesizer.
        return await self._response_synthesizer.aget_response(*args, **response_kwargs)

    def get_response(self, *args, **kwargs) -> RESPONSE_TEXT_TYPE:
        return self._response_synthesizer.get_response(*args, **kwargs)
    async def asynthesize(
        self,
        query: QueryTextType,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TYPE:
        image_nodes, text_nodes = _get_image_and_text_nodes(nodes)
        if len(image_nodes) == 0:
            # No images retrieved: fall back to plain text synthesis.
            return await self._response_synthesizer.asynthesize(query, text_nodes)

        # Summarize the text nodes first to avoid exceeding the token limit,
        # then pass the summary plus the image nodes to the multi-modal LLM.
        text_response = str(
            await self._response_synthesizer.asynthesize(query, text_nodes)
        )
        fmt_prompt = self._text_qa_template.format(
            context_str=text_response,
            query_str=query.query_str,  # type: ignore
        )
        llm_response = await self._multi_modal_llm.acomplete(
            prompt=fmt_prompt,
            image_documents=[
                image_node.node
                for image_node in image_nodes
                if isinstance(image_node.node, ImageNode)
            ],
        )
        return Response(
            response=str(llm_response),
            source_nodes=nodes,
            metadata={"text_nodes": text_nodes, "image_nodes": image_nodes},
        )
    def synthesize(
        self,
        query: QueryTextType,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TYPE:
        image_nodes, text_nodes = _get_image_and_text_nodes(nodes)
        if len(image_nodes) == 0:
            return self._response_synthesizer.synthesize(query, text_nodes)

        # Summarize the text nodes to avoid exceeding the token limit
        text_response = str(self._response_synthesizer.synthesize(query, text_nodes))
        fmt_prompt = self._text_qa_template.format(
            context_str=text_response,
            query_str=query.query_str,  # type: ignore
        )
        llm_response = self._multi_modal_llm.complete(
            prompt=fmt_prompt,
            image_documents=[
                image_node.node
                for image_node in image_nodes
                if isinstance(image_node.node, ImageNode)
            ],
        )
        return Response(
            response=str(llm_response),
            source_nodes=nodes,
            metadata={"text_nodes": text_nodes, "image_nodes": image_nodes},
        )
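A standalone sketch of how MultiModalSynthesizer behaves (not part of the commit): it assumes an OpenAI API key is configured, that the chosen model is accepted by OpenAIMultiModal, and uses a placeholder image URL:

from llama_index.core.schema import ImageNode, NodeWithScore, QueryBundle, TextNode
from llama_index.multi_modal_llms.openai import OpenAIMultiModal

synthesizer = MultiModalSynthesizer(
    multimodal_model=OpenAIMultiModal(model="gpt-4o-mini"),
)

nodes = [
    NodeWithScore(node=TextNode(text="Q3 revenue grew 12% quarter over quarter."), score=0.9),
    NodeWithScore(node=ImageNode(image_url="https://example.com/q3-chart.png"), score=0.8),  # placeholder URL
]

# Text nodes are summarized by the wrapped synthesizer first; the summary and
# the image nodes are then sent to the multimodal LLM in one completion call.
response = synthesizer.synthesize(
    QueryBundle(query_str="What does the revenue chart show?"),
    nodes=nodes,
)
print(response.response)
print(response.metadata["image_nodes"])  # the image nodes that were forwarded

The second file in the diff, the settings module (imported above as `app.settings`), follows.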
import os
from typing import Dict, Optional

from llama_index.core.multi_modal_llms import MultiModalLLM
from llama_index.core.settings import Settings

# `Settings` does not support setting a `MultiModalLLM`,
# so we use a module-level global variable to store it.
_multi_modal_llm: Optional[MultiModalLLM] = None


def get_multi_modal_llm():
    return _multi_modal_llm


def init_settings():
    model_provider = os.getenv("MODEL_PROVIDER")

@@ -60,14 +69,21 @@ def init_openai():
    from llama_index.core.constants import DEFAULT_TEMPERATURE
    from llama_index.embeddings.openai import OpenAIEmbedding
    from llama_index.llms.openai import OpenAI
    from llama_index.multi_modal_llms.openai import OpenAIMultiModal
    from llama_index.multi_modal_llms.openai.utils import GPT4V_MODELS

    max_tokens = os.getenv("LLM_MAX_TOKENS")
    model_name = os.getenv("MODEL", "gpt-4o-mini")
    Settings.llm = OpenAI(
        model=model_name,
        temperature=float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
        max_tokens=int(max_tokens) if max_tokens is not None else None,
    )

    # Only register a multi-modal LLM for vision-capable OpenAI models.
    if model_name in GPT4V_MODELS:
        global _multi_modal_llm
        _multi_modal_llm = OpenAIMultiModal(model=model_name)

    dimensions = os.getenv("EMBEDDING_DIM")
    Settings.embed_model = OpenAIEmbedding(
        model=os.getenv("EMBEDDING_MODEL", "text-embedding-3-small"),
        # ... (remainder of the file elided)
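Finally, a hedged sketch of how the new multimodal support is toggled purely through environment variables; the values are placeholders and `app.settings` is the module shown above:

import os

# Placeholder values; any model listed in GPT4V_MODELS enables the multimodal LLM.
os.environ["MODEL_PROVIDER"] = "openai"
os.environ["MODEL"] = "gpt-4o"
os.environ["OPENAI_API_KEY"] = "sk-placeholder"

from app.settings import get_multi_modal_llm, init_settings

init_settings()

# Because "gpt-4o" is vision-capable, a multimodal LLM is now registered and
# create_query_engine() will attach the MultiModalSynthesizer automatically.
assert get_multi_modal_llm() is not None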