From 58caf3f89a6408e5f3f3c0027094db384fcc4eb0 Mon Sep 17 00:00:00 2001
From: Logan <logan.markewich@live.com>
Date: Sun, 16 Feb 2025 22:30:50 -0600
Subject: [PATCH] improve prompt helper multimodal support (#17831)

---
 .../llama_index/core/indices/prompt_helper.py    |  7 ++++---
 .../llama_index/core/prompts/utils.py            | 16 ++++++++++++++++
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/llama-index-core/llama_index/core/indices/prompt_helper.py b/llama-index-core/llama_index/core/indices/prompt_helper.py
index af9be20e9c..f6d0555de1 100644
--- a/llama-index-core/llama_index/core/indices/prompt_helper.py
+++ b/llama-index-core/llama_index/core/indices/prompt_helper.py
@@ -28,7 +28,7 @@ from llama_index.core.prompts import (
     SelectorPromptTemplate,
 )
 from llama_index.core.prompts.prompt_utils import get_empty_prompt_txt
-from llama_index.core.prompts.utils import format_string
+from llama_index.core.prompts.utils import format_content_blocks
 from llama_index.core.schema import BaseComponent
 from llama_index.core.utilities.token_counting import TokenCounter
 
@@ -198,9 +198,10 @@ class PromptHelper(BaseComponent):
             for message in messages:
                 partial_message = deepcopy(message)
 
+                # TODO: This does not count tokens in non-text blocks
                 prompt_kwargs = prompt.kwargs or {}
-                partial_message.content = format_string(
-                    partial_message.content or "", **prompt_kwargs
+                partial_message.blocks = format_content_blocks(
+                    partial_message.blocks, **prompt_kwargs
                 )
 
                 # add to list of partial messages
diff --git a/llama-index-core/llama_index/core/prompts/utils.py b/llama-index-core/llama_index/core/prompts/utils.py
index 824cfed152..8edf6ac091 100644
--- a/llama-index-core/llama_index/core/prompts/utils.py
+++ b/llama-index-core/llama_index/core/prompts/utils.py
@@ -2,6 +2,7 @@ from typing import Dict, List, Optional
 import re
 
 from llama_index.core.base.llms.base import BaseLLM
+from llama_index.core.base.llms.types import ContentBlock, TextBlock
 
 
 class SafeFormatter:
@@ -27,6 +28,21 @@ def format_string(string_to_format: str, **kwargs: str) -> str:
     return formatter.format(string_to_format)
 
 
+def format_content_blocks(
+    content_blocks: List[ContentBlock], **kwargs: str
+) -> List[ContentBlock]:
+    """Format content blocks with kwargs."""
+    formatter = SafeFormatter(format_dict=kwargs)
+    formatted_blocks: List[ContentBlock] = []
+    for block in content_blocks:
+        if isinstance(block, TextBlock):
+            formatted_blocks.append(TextBlock(text=formatter.format(block.text)))
+        else:
+            formatted_blocks.append(block)
+
+    return formatted_blocks
+
+
 def get_template_vars(template_str: str) -> List[str]:
     """Get template variables from a template string."""
     variables = []
-- 
GitLab