diff --git a/llama-index-core/llama_index/core/indices/prompt_helper.py b/llama-index-core/llama_index/core/indices/prompt_helper.py index af9be20e9cc81d82f7d0ecb74b1145a95dd04dae..f6d0555de18729e88f24898ab5f2adb188acded7 100644 --- a/llama-index-core/llama_index/core/indices/prompt_helper.py +++ b/llama-index-core/llama_index/core/indices/prompt_helper.py @@ -28,7 +28,7 @@ from llama_index.core.prompts import ( SelectorPromptTemplate, ) from llama_index.core.prompts.prompt_utils import get_empty_prompt_txt -from llama_index.core.prompts.utils import format_string +from llama_index.core.prompts.utils import format_content_blocks from llama_index.core.schema import BaseComponent from llama_index.core.utilities.token_counting import TokenCounter @@ -198,9 +198,10 @@ class PromptHelper(BaseComponent): for message in messages: partial_message = deepcopy(message) + # TODO: This does not count tokens in non-text blocks prompt_kwargs = prompt.kwargs or {} - partial_message.content = format_string( - partial_message.content or "", **prompt_kwargs + partial_message.blocks = format_content_blocks( + partial_message.blocks, **prompt_kwargs ) # add to list of partial messages diff --git a/llama-index-core/llama_index/core/prompts/utils.py b/llama-index-core/llama_index/core/prompts/utils.py index 824cfed1521eb88b5f853e31b28e0f2494fe0c35..8edf6ac09122158025ef5d30d194c000d58cb509 100644 --- a/llama-index-core/llama_index/core/prompts/utils.py +++ b/llama-index-core/llama_index/core/prompts/utils.py @@ -2,6 +2,7 @@ from typing import Dict, List, Optional import re from llama_index.core.base.llms.base import BaseLLM +from llama_index.core.base.llms.types import ContentBlock, TextBlock class SafeFormatter: @@ -27,6 +28,21 @@ def format_string(string_to_format: str, **kwargs: str) -> str: return formatter.format(string_to_format) +def format_content_blocks( + content_blocks: List[ContentBlock], **kwargs: str +) -> List[ContentBlock]: + """Format content blocks with kwargs.""" + formatter = SafeFormatter(format_dict=kwargs) + formatted_blocks: List[ContentBlock] = [] + for block in content_blocks: + if isinstance(block, TextBlock): + formatted_blocks.append(TextBlock(text=formatter.format(block.text))) + else: + formatted_blocks.append(block) + + return formatted_blocks + + def get_template_vars(template_str: str) -> List[str]: """Get template variables from a template string.""" variables = []