diff --git a/CHANGELOG.md b/CHANGELOG.md index ea419b6af6a5d2908eebfb68bf208274672491a3..60f56b6e4970b1cf9d21d7c43945df735f81b71f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,16 @@ - Ensure temperature is a float for openai (#7382) - Remove duplicate subjects in knowledge graph retriever (#7378) +### Breaking/Deprecated API Changes +- Refactor prompt template (#7319) + - Use `BasePromptTemplate` for generic typing + - Use `PromptTemplate`, `ChatPromptTemplate`, `SelectorPromptTemplate` as core implementations + - Use `LangchainPromptTemplate` for compatibility with Langchain prompt templates + - Fully replace specific prompt classes (e.g. `SummaryPrompt`) with generic `BasePromptTemplate` for typing in codebase. + - Keep `Prompt` as an alias for `PromptTemplate` for backwards compatibility. + - BREAKING CHANGE: remove support for `Prompt.from_langchain_prompt`, please use `template=LangchainPromptTemplate(lc_template)` instead. + + ## [0.8.8] - 2023-08-23 ### New Features diff --git a/docs/api_reference/prompts.rst b/docs/api_reference/prompts.rst index 9bc92c46b0961d1425281d0c755b3991bb8339ff..1b64c244554fed5a6a18ecdc43f21d33df001d7e 100644 --- a/docs/api_reference/prompts.rst +++ b/docs/api_reference/prompts.rst @@ -7,26 +7,30 @@ These are the reference prompt templates. We first show links to default prompts. -We then show the base prompt class, -derived from `Langchain <https://langchain.readthedocs.io/en/latest/modules/prompt.html>`_. +We then show the base prompt template class and its subclasses. Default Prompts ^^^^^^^^^^^^^^^^^ -The list of default prompts can be `found here <https://github.com/jerryjliu/llama_index/blob/main/llama_index/prompts/default_prompts.py>`_. -**NOTE**: we've also curated a set of refine prompts for ChatGPT use cases. -The list of ChatGPT refine prompts can be -`found here <https://github.com/jerryjliu/llama_index/blob/main/llama_index/prompts/chat_prompts.py>`_. +* `Completion prompt templates <https://github.com/jerryjliu/llama_index/blob/main/llama_index/prompts/default_prompts.py>`_. +* `Chat prompt templates <https://github.com/jerryjliu/llama_index/blob/main/llama_index/prompts/chat_prompts.py>`_. +* `Selector prompt templates <https://github.com/jerryjliu/llama_index/blob/main/llama_index/prompts/default_prompt_selectors.py>`_. -Base Prompt Class + +Prompt Classes ^^^^^^^^^^^^^^^^^ -.. automodule:: llama_index.prompts - :members: - :inherited-members: - :exclude-members: Config, construct, copy, dict, from_examples, from_file, get_full_format_args, output_parser, save, template, template_format, update_forward_refs, validate_variable_names, json, template_is_valid +.. autopydantic_model:: llama_index.prompts.base.BasePromptTemplate + +.. autopydantic_model:: llama_index.prompts.base.PromptTemplate + +.. autopydantic_model:: llama_index.prompts.base.ChatPromptTemplate + +.. autopydantic_model:: llama_index.prompts.base.SelectorPromptTemplate + +.. autopydantic_model:: llama_index.prompts.base.LangchainPromptTemplate Subclass Prompts (deprecated) diff --git a/docs/core_modules/model_modules/llms/usage_custom.md b/docs/core_modules/model_modules/llms/usage_custom.md index 9b92ac4865d0d79f65af37b985a68e9ac8df0bc4..f24a12d793c12055e9086b47694c168af2755404 100644 --- a/docs/core_modules/model_modules/llms/usage_custom.md +++ b/docs/core_modules/model_modules/llms/usage_custom.md @@ -119,7 +119,7 @@ Many open-source models from HuggingFace require either some preamble before eac Below, this example uses both the `system_prompt` and `query_wrapper_prompt`, using specific prompts from the model card found [here](https://huggingface.co/stabilityai/stablelm-tuned-alpha-3b). ```python -from llama_index.prompts.prompts import SimpleInputPrompt +from llama_index.prompts import PromptTemplate system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version) - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI. @@ -129,7 +129,7 @@ system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version) """ # This will wrap the default prompts that are internal to llama-index -query_wrapper_prompt = SimpleInputPrompt("<|USER|>{query_str}<|ASSISTANT|>") +query_wrapper_prompt = PromptTemplate("<|USER|>{query_str}<|ASSISTANT|>") import torch from llama_index.llms import HuggingFaceLLM diff --git a/docs/core_modules/model_modules/prompts.md b/docs/core_modules/model_modules/prompts.md index ca551647e522e6eea37922edaf9263d8ccfe7eee..1d078d2a781feb5172a10b102a6b0740cc9dd121 100644 --- a/docs/core_modules/model_modules/prompts.md +++ b/docs/core_modules/model_modules/prompts.md @@ -18,7 +18,7 @@ Users may also provide their own prompt templates to further customize the behav Defining a custom prompt is as simple as creating a format string ```python -from llama_index import Prompt +from llama_index.prompts import PromptTemplate template = ( "We have provided context information below. \n" @@ -27,10 +27,36 @@ template = ( "\n---------------------\n" "Given this information, please answer the question: {query_str}\n" ) -qa_template = Prompt(template) +qa_template = PromptTemplate(template) + +# you can create text prompt (for completion API) +prompt = qa_template.format(context_str=..., query_str=...) + +# or easily convert to message prompts (for chat API) +messages = qa_template.format_messages(context_str=..., query_str=...) ``` -> Note: you may see references to legacy prompt subclasses such as `QuestionAnswerPrompt`, `RefinePrompt`. These have been deprecated (and now are type aliases of `Prompt`). Now you can directly specify `Prompt(template)` to construct custom prompts. But you still have to make sure the template string contains the expected parameters (e.g. `{context_str}` and `{query_str}`) when replacing a default question answer prompt. +> Note: you may see references to legacy prompt subclasses such as `QuestionAnswerPrompt`, `RefinePrompt`. These have been deprecated (and now are type aliases of `PromptTemplate`). Now you can directly specify `PromptTemplate(template)` to construct custom prompts. But you still have to make sure the template string contains the expected parameters (e.g. `{context_str}` and `{query_str}`) when replacing a default question answer prompt. + +You can also define a template from chat messages +```python +from llama_index.prompts import ChatPromptTemplate, ChatMessage, MessageRole + +message_templates = [ + ChatMessage(content="You are an expert system.", role=MessageRole.SYSTEM), + ChatMessage( + content="Generate a short story about {topic}", + role=MessageRole.USER, + ), +] +chat_template = ChatPromptTemplate(message_templates=message_templates) + +# you can create message prompts (for chat API) +messages = chat_template.format_messages(topic=...) + +# or easily convert to text prompt (for completion API) +prompt = chat_template.format(topic=...) +``` ### Passing custom prompts into the pipeline @@ -45,8 +71,8 @@ The most commonly used prompts will be the `text_qa_template` and the `refine_te #### Modify prompts used in index construction Different indices use different types of prompts during construction (some don't use prompts at all). -For instance, `TreeIndex` uses a `SummaryPrompt` to hierarchically -summarize the nodes, and `KeywordTableIndex` uses a `KeywordExtractPrompt` to extract keywords. +For instance, `TreeIndex` uses a summary prompt to hierarchically +summarize the nodes, and `KeywordTableIndex` uses a keyword extract prompt to extract keywords. There are two equivalent ways to override the prompts: diff --git a/docs/core_modules/query_modules/chat_engines/usage_pattern.md b/docs/core_modules/query_modules/chat_engines/usage_pattern.md index 63fefcc6f19f29d6db69df3374ae4d238b609dbf..883a57fc7e0241f919fb6e3213d1ba4571f61fe3 100644 --- a/docs/core_modules/query_modules/chat_engines/usage_pattern.md +++ b/docs/core_modules/query_modules/chat_engines/usage_pattern.md @@ -63,10 +63,10 @@ Here's an example where we configure the following: * print verbose debug message. ```python -from llama_index.prompts import Prompt +from llama_index.prompts import PromptTemplate from llama_index.llms import ChatMessage, MessageRole -custom_prompt = Prompt("""\ +custom_prompt = PromptTemplate("""\ Given a conversation (between Human and Assistant) and a follow up message from Human, \ rewrite the message to be a standalone question that captures all relevant context \ from the conversation. diff --git a/docs/core_modules/query_modules/structured_outputs/output_parser.md b/docs/core_modules/query_modules/structured_outputs/output_parser.md index c5fcc15e7d5db9b153c18aeebac791df2bb52a1d..f2ba5328b2570353c7b0860760a3b61c100e95b9 100644 --- a/docs/core_modules/query_modules/structured_outputs/output_parser.md +++ b/docs/core_modules/query_modules/structured_outputs/output_parser.md @@ -15,7 +15,7 @@ Guardrails is an open-source Python package for specification/validation/correct from llama_index import VectorStoreIndex, SimpleDirectoryReader from llama_index.output_parsers import GuardrailsOutputParser from llama_index.llm_predictor import StructuredLLMPredictor -from llama_index.prompts.prompts import QuestionAnswerPrompt, RefinePrompt +from llama_index.prompts import PromptTemplate from llama_index.prompts.default_prompts import DEFAULT_TEXT_QA_PROMPT_TMPL, DEFAULT_REFINE_PROMPT_TMPL @@ -62,8 +62,8 @@ output_parser = GuardrailsOutputParser.from_rail_string(rail_spec, llm=llm_predi fmt_qa_tmpl = output_parser.format(DEFAULT_TEXT_QA_PROMPT_TMPL) fmt_refine_tmpl = output_parser.format(DEFAULT_REFINE_PROMPT_TMPL) -qa_prompt = QuestionAnswerPrompt(fmt_qa_tmpl, output_parser=output_parser) -refine_prompt = RefinePrompt(fmt_refine_tmpl, output_parser=output_parser) +qa_prompt = PromptTemplate(fmt_qa_tmpl, output_parser=output_parser) +refine_prompt = PromptTemplate(fmt_refine_tmpl, output_parser=output_parser) # obtain a structured response query_engine = index.as_query_engine( @@ -94,7 +94,7 @@ Langchain also offers output parsing modules that you can use within LlamaIndex. from llama_index import VectorStoreIndex, SimpleDirectoryReader from llama_index.output_parsers import LangchainOutputParser from llama_index.llm_predictor import StructuredLLMPredictor -from llama_index.prompts.prompts import QuestionAnswerPrompt, RefinePrompt +from llama_index.prompts import PromptTemplate from llama_index.prompts.default_prompts import DEFAULT_TEXT_QA_PROMPT_TMPL, DEFAULT_REFINE_PROMPT_TMPL from langchain.output_parsers import StructuredOutputParser, ResponseSchema @@ -117,8 +117,8 @@ output_parser = LangchainOutputParser(lc_output_parser) # format each prompt with output parser instructions fmt_qa_tmpl = output_parser.format(DEFAULT_TEXT_QA_PROMPT_TMPL) fmt_refine_tmpl = output_parser.format(DEFAULT_REFINE_PROMPT_TMPL) -qa_prompt = QuestionAnswerPrompt(fmt_qa_tmpl, output_parser=output_parser) -refine_prompt = RefinePrompt(fmt_refine_tmpl, output_parser=output_parser) +qa_prompt = PromptTemplate(fmt_qa_tmpl, output_parser=output_parser) +refine_prompt = PromptTemplate(fmt_refine_tmpl, output_parser=output_parser) # query index query_engine = index.as_query_engine( diff --git a/docs/end_to_end_tutorials/question_and_answer/terms_definitions_tutorial.md b/docs/end_to_end_tutorials/question_and_answer/terms_definitions_tutorial.md index 98f79ee68290e4a1112a302ce3c677c5368c5c34..0f576b5ffc208e46b9ba3b904dae09fec2f8c215 100644 --- a/docs/end_to_end_tutorials/question_and_answer/terms_definitions_tutorial.md +++ b/docs/end_to_end_tutorials/question_and_answer/terms_definitions_tutorial.md @@ -298,14 +298,9 @@ This is due to the concept of "refining" answers in Llama Index. Since we are qu So, the refine process seems to be messing with our results! Rather than appending extra instructions to the `query_str`, remove that, and Llama Index will let us provide our own custom prompts! Let's create those now, using the [default prompts](https://github.com/jerryjliu/llama_index/blob/main/llama_index/prompts/default_prompts.py) and [chat specific prompts](https://github.com/jerryjliu/llama_index/blob/main/llama_index/prompts/chat_prompts.py) as a guide. Using a new file `constants.py`, let's create some new query templates: ```python -from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model -from langchain.prompts.chat import ( - AIMessagePromptTemplate, - ChatPromptTemplate, - HumanMessagePromptTemplate, -) - -from llama_index.prompts.prompts import QuestionAnswerPrompt, RefinePrompt +from llama_index.prompts import PromptTemplate, SelectorPromptTemplate, ChatPromptTemplate +from llama_index.prompts.utils import is_chat_model +from llama_index.llms.base import ChatMessage, MessageRole # Text QA templates DEFAULT_TEXT_QA_PROMPT_TMPL = ( @@ -316,7 +311,7 @@ DEFAULT_TEXT_QA_PROMPT_TMPL = ( "Given the context information answer the following question " "(if you don't know the answer, use the best of your knowledge): {query_str}\n" ) -TEXT_QA_TEMPLATE = QuestionAnswerPrompt(DEFAULT_TEXT_QA_PROMPT_TMPL) +TEXT_QA_TEMPLATE = PromptTemplate(DEFAULT_TEXT_QA_PROMPT_TMPL) # Refine templates DEFAULT_REFINE_PROMPT_TMPL = ( @@ -330,32 +325,29 @@ DEFAULT_REFINE_PROMPT_TMPL = ( "Given the new context and using the best of your knowledge, improve the existing answer. " "If you can't improve the existing answer, just repeat it again." ) -DEFAULT_REFINE_PROMPT = RefinePrompt(DEFAULT_REFINE_PROMPT_TMPL) +DEFAULT_REFINE_PROMPT = PromptTemplate(DEFAULT_REFINE_PROMPT_TMPL) CHAT_REFINE_PROMPT_TMPL_MSGS = [ - HumanMessagePromptTemplate.from_template("{query_str}"), - AIMessagePromptTemplate.from_template("{existing_answer}"), - HumanMessagePromptTemplate.from_template( - "We have the opportunity to refine the above answer " + ChatMessage(content="{query_str}", role=MessageRole.USER), + ChatMessage(content="{existing_answer}", role=MessageRole.ASSISTANT), + ChatMessage( + content="We have the opportunity to refine the above answer " "(only if needed) with some more context below.\n" "------------\n" "{context_msg}\n" "------------\n" "Given the new context and using the best of your knowledge, improve the existing answer. " - "If you can't improve the existing answer, just repeat it again." + "If you can't improve the existing answer, just repeat it again.", + role=MessageRole.USER, ), ] -CHAT_REFINE_PROMPT_LC = ChatPromptTemplate.from_messages(CHAT_REFINE_PROMPT_TMPL_MSGS) -CHAT_REFINE_PROMPT = RefinePrompt.from_langchain_prompt(CHAT_REFINE_PROMPT_LC) +CHAT_REFINE_PROMPT = ChatPromptTemplate(CHAT_REFINE_PROMPT_TMPL_MSGS) # refine prompt selector -DEFAULT_REFINE_PROMPT_SEL_LC = ConditionalPromptSelector( - default_prompt=DEFAULT_REFINE_PROMPT.get_langchain_prompt(), - conditionals=[(is_chat_model, CHAT_REFINE_PROMPT.get_langchain_prompt())], -) -REFINE_TEMPLATE = RefinePrompt( - langchain_prompt_selector=DEFAULT_REFINE_PROMPT_SEL_LC +REFINE_TEMPLATE = SelectorPromptTemplate( + default_template=DEFAULT_REFINE_PROMPT, + conditionals=[(is_chat_model, CHAT_REFINE_PROMPT)], ) ``` diff --git a/docs/examples/customization/llms/SimpleIndexDemo-Huggingface_camel.ipynb b/docs/examples/customization/llms/SimpleIndexDemo-Huggingface_camel.ipynb index a91fb6fdd0b19b152528e1eb6d31ef50b33982d6..0fce5ea5094999cb678e39a5723b3a2cebc7fe21 100644 --- a/docs/examples/customization/llms/SimpleIndexDemo-Huggingface_camel.ipynb +++ b/docs/examples/customization/llms/SimpleIndexDemo-Huggingface_camel.ipynb @@ -75,11 +75,11 @@ "outputs": [], "source": [ "# setup prompts - specific to StableLM\n", - "from llama_index.prompts.prompts import SimpleInputPrompt\n", + "from llama_index.prompts import PromptTemplate\n", "\n", "# This will wrap the default prompts that are internal to llama-index\n", "# taken from https://huggingface.co/Writer/camel-5b-hf\n", - "query_wrapper_prompt = SimpleInputPrompt(\n", + "query_wrapper_prompt = PromptTemplate(\n", " \"Below is an instruction that describes a task. \"\n", " \"Write a response that appropriately completes the request.\\n\\n\"\n", " \"### Instruction:\\n{query_str}\\n\\n### Response:\"\n", diff --git a/docs/examples/customization/llms/SimpleIndexDemo-Huggingface_stablelm.ipynb b/docs/examples/customization/llms/SimpleIndexDemo-Huggingface_stablelm.ipynb index 77955e9da59760470db271224f777deda9fa9db5..6919c882c1b7f6e23c196870ae9933ea8283932d 100644 --- a/docs/examples/customization/llms/SimpleIndexDemo-Huggingface_stablelm.ipynb +++ b/docs/examples/customization/llms/SimpleIndexDemo-Huggingface_stablelm.ipynb @@ -75,7 +75,7 @@ "outputs": [], "source": [ "# setup prompts - specific to StableLM\n", - "from llama_index.prompts.prompts import SimpleInputPrompt\n", + "from llama_index.prompts import PromptTemplate\n", "\n", "system_prompt = \"\"\"<|SYSTEM|># StableLM Tuned (Alpha version)\n", "- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.\n", @@ -85,7 +85,7 @@ "\"\"\"\n", "\n", "# This will wrap the default prompts that are internal to llama-index\n", - "query_wrapper_prompt = SimpleInputPrompt(\"<|USER|>{query_str}<|ASSISTANT|>\")" + "query_wrapper_prompt = PromptTemplate(\"<|USER|>{query_str}<|ASSISTANT|>\")" ] }, { diff --git a/docs/examples/customization/prompts/chat_prompts.ipynb b/docs/examples/customization/prompts/chat_prompts.ipynb index eae5407cc8a1d0c8975e7631249494704512d4a6..1ff7f444ce0afdb4c815ee30aff3d151095f980f 100644 --- a/docs/examples/customization/prompts/chat_prompts.ipynb +++ b/docs/examples/customization/prompts/chat_prompts.ipynb @@ -20,54 +20,55 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ - "from langchain.prompts.chat import (\n", - " ChatPromptTemplate,\n", - " HumanMessagePromptTemplate,\n", - " SystemMessagePromptTemplate,\n", - ")\n", - "from llama_index.prompts import Prompt\n", + "from llama_index.llms import ChatMessage, MessageRole\n", + "from llama_index.prompts import ChatPromptTemplate\n", "\n", + "# Text QA Prompt\n", "chat_text_qa_msgs = [\n", - " SystemMessagePromptTemplate.from_template(\n", - " \"Always answer the question, even if the context isn't helpful.\"\n", + " ChatMessage(\n", + " role=MessageRole.SYSTEM,\n", + " content=\"Always answer the question, even if the context isn't helpful.\",\n", " ),\n", - " HumanMessagePromptTemplate.from_template(\n", - " \"Context information is below.\\n\"\n", - " \"---------------------\\n\"\n", - " \"{context_str}\\n\"\n", - " \"---------------------\\n\"\n", - " \"Given the context information and not prior knowledge, \"\n", - " \"answer the question: {query_str}\\n\"\n", + " ChatMessage(\n", + " role=MessageRole.USER,\n", + " content=(\n", + " \"Context information is below.\\n\"\n", + " \"---------------------\\n\"\n", + " \"{context_str}\\n\"\n", + " \"---------------------\\n\"\n", + " \"Given the context information and not prior knowledge, \"\n", + " \"answer the question: {query_str}\\n\"\n", + " ),\n", " ),\n", "]\n", - "chat_text_qa_msgs_lc = ChatPromptTemplate.from_messages(chat_text_qa_msgs)\n", - "text_qa_template = Prompt.from_langchain_prompt(chat_text_qa_msgs_lc)\n", + "text_qa_template = ChatPromptTemplate(chat_text_qa_msgs)\n", "\n", "# Refine Prompt\n", "chat_refine_msgs = [\n", - " SystemMessagePromptTemplate.from_template(\n", - " \"Always answer the question, even if the context isn't helpful.\"\n", + " ChatMessage(\n", + " role=MessageRole.SYSTEM,\n", + " content=\"Always answer the question, even if the context isn't helpful.\",\n", " ),\n", - " HumanMessagePromptTemplate.from_template(\n", - " \"We have the opportunity to refine the original answer \"\n", - " \"(only if needed) with some more context below.\\n\"\n", - " \"------------\\n\"\n", - " \"{context_msg}\\n\"\n", - " \"------------\\n\"\n", - " \"Given the new context, refine the original answer to better \"\n", - " \"answer the question: {query_str}. \"\n", - " \"If the context isn't useful, output the original answer again.\\n\"\n", - " \"Original Answer: {existing_answer}\"\n", + " ChatMessage(\n", + " role=MessageRole.USER,\n", + " content=(\n", + " \"We have the opportunity to refine the original answer \"\n", + " \"(only if needed) with some more context below.\\n\"\n", + " \"------------\\n\"\n", + " \"{context_msg}\\n\"\n", + " \"------------\\n\"\n", + " \"Given the new context, refine the original answer to better \"\n", + " \"answer the question: {query_str}. \"\n", + " \"If the context isn't useful, output the original answer again.\\n\"\n", + " \"Original Answer: {existing_answer}\"\n", + " ),\n", " ),\n", "]\n", - "\n", - "\n", - "chat_refine_msgs_lc = ChatPromptTemplate.from_messages(chat_refine_msgs)\n", - "refine_template = Prompt.from_langchain_prompt(chat_refine_msgs_lc)" + "refine_template = ChatPromptTemplate(chat_refine_msgs)" ] }, { @@ -82,20 +83,20 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "import openai\n", "import os\n", "\n", - "os.environ[\"OPENAI_API_KEY\"] = \"YOUR_API_KEY\"\n", + "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"\n", "openai.api_key = os.environ[\"OPENAI_API_KEY\"]" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -106,7 +107,7 @@ "\n", "# Create an index using a chat model, so that we can use the chat prompts!\n", "service_context = ServiceContext.from_defaults(\n", - " llm=OpenAI(model=\"gpt-3.5-turbo\", temperature=0.0)\n", + " llm=OpenAI(model=\"gpt-3.5-turbo\", temperature=0.1)\n", ")\n", "\n", "index = VectorStoreIndex.from_documents(documents, service_context=service_context)" @@ -122,14 +123,14 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Based on the given context information, there is no mention of Joe Biden. Therefore, it is not possible to determine who Joe Biden is based on this information alone.\n" + "I'm sorry, but the given context does not provide any information about Joe Biden.\n" ] } ], @@ -147,14 +148,14 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Joe Biden is a politician who served as the 46th President of the United States.\n" + "Joe Biden is the 46th President of the United States.\n" ] } ], @@ -176,9 +177,9 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": "llama-index", "language": "python", - "name": "python3" + "name": "llama-index" }, "language_info": { "codemirror_mode": { @@ -190,7 +191,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.6" + "version": "3.11.0" }, "orig_nbformat": 4 }, diff --git a/docs/examples/customization/prompts/completion_prompts.ipynb b/docs/examples/customization/prompts/completion_prompts.ipynb index c900a7184e23d675d12e632c0c3cce50d8ec2e95..a4adf4cf223530a7877741ab60b43213a5adbf85 100644 --- a/docs/examples/customization/prompts/completion_prompts.ipynb +++ b/docs/examples/customization/prompts/completion_prompts.ipynb @@ -20,11 +20,11 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ - "from llama_index.prompts import Prompt\n", + "from llama_index.prompts import PromptTemplate\n", "\n", "text_qa_template_str = (\n", " \"Context information is below.\\n\"\n", @@ -35,7 +35,7 @@ " \"answer the question: {query_str}\\n\"\n", " \"If the context isn't helpful, you can also answer the question on your own.\\n\"\n", ")\n", - "text_qa_template = Prompt(text_qa_template_str)\n", + "text_qa_template = PromptTemplate(text_qa_template_str)\n", "\n", "refine_template_str = (\n", " \"The original question is as follows: {query_str}\\n\"\n", @@ -47,7 +47,7 @@ " \"------------\\n\"\n", " \"Using both the new context and your own knowledege, update or repeat the existing answer.\\n\"\n", ")\n", - "refine_template = Prompt(refine_template_str)" + "refine_template = PromptTemplate(refine_template_str)" ] }, { @@ -62,28 +62,31 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "import openai\n", "import os\n", "\n", - "os.environ[\"OPENAI_API_KEY\"] = \"YOU_API_KEY\"\n", + "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"\n", "openai.api_key = os.environ[\"OPENAI_API_KEY\"]" ] }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ - "from llama_index import VectorStoreIndex, SimpleDirectoryReader\n", + "from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext\n", + "from llama_index.llms import OpenAI\n", + "\n", + "service_context = ServiceContext.from_defaults(llm=OpenAI(model=\"text-davinci-003\"))\n", "\n", "documents = SimpleDirectoryReader(\"../../data/paul_graham/\").load_data()\n", "\n", - "index = VectorStoreIndex.from_documents(documents)" + "index = VectorStoreIndex.from_documents(documents, service_context=service_context)" ] }, { @@ -96,15 +99,14 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\n", - "Joe Biden is not mentioned in the context information.\n" + " Joe Biden is not mentioned in the context information.\n" ] } ], @@ -122,7 +124,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -152,9 +154,9 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": "llama-index", "language": "python", - "name": "python3" + "name": "llama-index" }, "language_info": { "codemirror_mode": { @@ -166,7 +168,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.6" + "version": "3.11.0" }, "orig_nbformat": 4 }, diff --git a/docs/examples/output_parsing/GuardrailsDemo.ipynb b/docs/examples/output_parsing/GuardrailsDemo.ipynb index 4edd43c4a55a97875bf77f03be8f3ffc583fb189..b3818525066cd063ca3930b1cf91ec389a4c4df0 100644 --- a/docs/examples/output_parsing/GuardrailsDemo.ipynb +++ b/docs/examples/output_parsing/GuardrailsDemo.ipynb @@ -117,7 +117,7 @@ }, "outputs": [], "source": [ - "from llama_index.prompts.prompts import QuestionAnswerPrompt, RefinePrompt\n", + "from llama_index.prompts import PromptTemplate\n", "from llama_index.prompts.default_prompts import (\n", " DEFAULT_TEXT_QA_PROMPT_TMPL,\n", " DEFAULT_REFINE_PROMPT_TMPL,\n", @@ -190,8 +190,8 @@ "fmt_qa_tmpl = output_parser.format(DEFAULT_TEXT_QA_PROMPT_TMPL)\n", "fmt_refine_tmpl = output_parser.format(DEFAULT_REFINE_PROMPT_TMPL)\n", "\n", - "qa_prompt = QuestionAnswerPrompt(fmt_qa_tmpl, output_parser=output_parser)\n", - "refine_prompt = RefinePrompt(fmt_refine_tmpl, output_parser=output_parser)" + "qa_prompt = PromptTemplate(fmt_qa_tmpl, output_parser=output_parser)\n", + "refine_prompt = PromptTemplate(fmt_refine_tmpl, output_parser=output_parser)" ] }, { diff --git a/docs/examples/output_parsing/LangchainOutputParserDemo.ipynb b/docs/examples/output_parsing/LangchainOutputParserDemo.ipynb index 907a24a563bd2e0e1b0f5521a6d4827f6e4d643b..85fe9231f2f7e97ec9a8e0957658fe4605593850 100644 --- a/docs/examples/output_parsing/LangchainOutputParserDemo.ipynb +++ b/docs/examples/output_parsing/LangchainOutputParserDemo.ipynb @@ -119,7 +119,7 @@ }, "outputs": [], "source": [ - "from llama_index.prompts.prompts import QuestionAnswerPrompt, RefinePrompt\n", + "from llama_index.prompts import PromptTemplate\n", "from llama_index.prompts.default_prompts import (\n", " DEFAULT_TEXT_QA_PROMPT_TMPL,\n", " DEFAULT_REFINE_PROMPT_TMPL,\n", @@ -174,8 +174,8 @@ "fmt_qa_tmpl = output_parser.format(DEFAULT_TEXT_QA_PROMPT_TMPL)\n", "fmt_refine_tmpl = output_parser.format(DEFAULT_REFINE_PROMPT_TMPL)\n", "\n", - "qa_prompt = QuestionAnswerPrompt(fmt_qa_tmpl, output_parser=output_parser)\n", - "refine_prompt = RefinePrompt(fmt_refine_tmpl, output_parser=output_parser)" + "qa_prompt = PromptTemplate(fmt_qa_tmpl, output_parser=output_parser)\n", + "refine_prompt = PromptTemplate(fmt_refine_tmpl, output_parser=output_parser)" ] }, { diff --git a/docs/examples/query_engine/ensemble_query_engine.ipynb b/docs/examples/query_engine/ensemble_query_engine.ipynb index 33e623fe2292e0ca06c00f5d3bfa21223a620cb1..7825e1a52607e633256bf9c9176728ca2013d940 100644 --- a/docs/examples/query_engine/ensemble_query_engine.ipynb +++ b/docs/examples/query_engine/ensemble_query_engine.ipynb @@ -208,7 +208,7 @@ }, "outputs": [], "source": [ - "from llama_index import Prompt\n", + "from llama_index.prompts import PromptTemplate\n", "\n", "QA_PROMPT_TMPL = (\n", " \"Context information is below.\\n\"\n", @@ -223,7 +223,7 @@ " \"Question: {query_str}\\n\"\n", " \"Answer (including relevance score): \"\n", ")\n", - "QA_PROMPT = Prompt(QA_PROMPT_TMPL)\n", + "QA_PROMPT = PromptTemplate(QA_PROMPT_TMPL)\n", "\n", "keyword_query_engine = keyword_index.as_query_engine(text_qa_template=QA_PROMPT)\n", "vector_query_engine = vector_index.as_query_engine(text_qa_template=QA_PROMPT)" @@ -368,7 +368,9 @@ " \"Answer: \"\n", ")\n", "\n", - "tree_summarize = TreeSummarize(summary_template=Prompt(TREE_SUMMARIZE_PROMPT_TMPL))\n", + "tree_summarize = TreeSummarize(\n", + " summary_template=PromptTemplate(TREE_SUMMARIZE_PROMPT_TMPL)\n", + ")\n", "\n", "query_engine = RouterQueryEngine(\n", " selector=LLMMultiSelector.from_defaults(),\n", diff --git a/docs/examples/vector_stores/SimpleIndexDemoLlama-Local.ipynb b/docs/examples/vector_stores/SimpleIndexDemoLlama-Local.ipynb index 3fe80977ed152c46385399e71332677598a06be9..060d3bef0e9e4f90df43ad5a82f871caf3665936 100644 --- a/docs/examples/vector_stores/SimpleIndexDemoLlama-Local.ipynb +++ b/docs/examples/vector_stores/SimpleIndexDemoLlama-Local.ipynb @@ -177,7 +177,7 @@ "source": [ "import torch\n", "from llama_index.llms import HuggingFaceLLM\n", - "from llama_index.prompts.prompts import SimpleInputPrompt\n", + "from llama_index.prompts import PromptTemplate\n", "\n", "# Model names (make sure you have access on HF)\n", "LLAMA2_7B = \"meta-llama/Llama-2-7b-hf\"\n", @@ -197,7 +197,7 @@ "- Never generate offensive or foul language.\n", "\"\"\"\n", "\n", - "query_wrapper_prompt = SimpleInputPrompt(\n", + "query_wrapper_prompt = PromptTemplate(\n", " \"[INST]<<SYS>>\\n\" + SYSTEM_PROMPT + \"<</SYS>>\\n\\n{query_str}[/INST]\"\n", ")\n", "\n", diff --git a/examples/paul_graham_essay/TestEssay.ipynb b/examples/paul_graham_essay/TestEssay.ipynb index 4a1a20116cbf45ce9258631b611057f1071327c6..404778d020042e5a292ad6926d66b0b5cc01a2e5 100644 --- a/examples/paul_graham_essay/TestEssay.ipynb +++ b/examples/paul_graham_essay/TestEssay.ipynb @@ -292,7 +292,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index import Prompt" + "from llama_index.prompts import PromptTemplate" ] }, { @@ -328,7 +328,7 @@ " \"Given the context information and not prior knowledge, \"\n", " f\"answer the question: {query_str}\\n\"\n", ")\n", - "SUMMARY_PROMPT = Prompt(SUMMARY_PROMPT_TMPL)\n", + "SUMMARY_PROMPT = PromptTemplate(SUMMARY_PROMPT_TMPL)\n", "index_with_query = TreeIndex.from_documents(documents, summary_template=SUMMARY_PROMPT)" ] }, diff --git a/examples/test_wiki/TestNYC-Benchmark-GPT4.ipynb b/examples/test_wiki/TestNYC-Benchmark-GPT4.ipynb index 01dd4e954eb5c2cba76bd4bf7a0295bd3d766f77..25bfc407c2df1e555a35b2d5f3199b3fcf250a2a 100644 --- a/examples/test_wiki/TestNYC-Benchmark-GPT4.ipynb +++ b/examples/test_wiki/TestNYC-Benchmark-GPT4.ipynb @@ -55,7 +55,7 @@ " LLMPredictor,\n", " VectorStoreIndex,\n", " ListIndex,\n", - " Prompt,\n", + " PromptTemplate,\n", " ServiceContext,\n", ")\n", "from llama_index.indices.base import BaseIndex\n", @@ -324,7 +324,7 @@ " \"and explain why.\"\n", ")\n", "\n", - "DEFAULT_EVAL_PROMPT = Prompt(EVAL_PROMPT_TMPL)" + "DEFAULT_EVAL_PROMPT = PromptTemplate(EVAL_PROMPT_TMPL)" ] }, { diff --git a/experimental/classifier/utils.py b/experimental/classifier/utils.py index ece63a75dbbe9621684f8fa5d7c41718abce367c..2f65448921a8f3af178f501d852d88e232442ba1 100644 --- a/experimental/classifier/utils.py +++ b/experimental/classifier/utils.py @@ -9,7 +9,7 @@ from sklearn.model_selection import train_test_split from llama_index.indices.utils import extract_numbers_given_response from llama_index.llm_predictor import LLMPredictor -from llama_index.prompts.base import Prompt +from llama_index.prompts import BasePromptTemplate, PromptTemplate def get_train_and_eval_data( @@ -74,7 +74,7 @@ def extract_float_given_response(response: str, n: int = 1) -> Optional[float]: def get_eval_preds( - train_prompt: Prompt, train_str: str, eval_df: pd.DataFrame, n: int = 20 + train_prompt: BasePromptTemplate, train_str: str, eval_df: pd.DataFrame, n: int = 20 ) -> List: """Get eval preds.""" llm_predictor = LLMPredictor() @@ -111,9 +111,7 @@ train_prompt_str = ( "Survived: " ) -train_prompt = Prompt( - input_variables=["train_str", "eval_str"], template=train_prompt_str -) +train_prompt = PromptTemplate(template=train_prompt_str) # prompt to summarize the data @@ -130,9 +128,7 @@ qa_data_str = ( "Given this, answer the question: {query_str}" ) -qa_data_prompt = Prompt( - input_variables=["context_str", "query_str"], template=qa_data_str -) +qa_data_prompt = PromptTemplate(template=qa_data_str) # prompt to refine the answer refine_str = ( @@ -151,10 +147,7 @@ refine_str = ( "answer the question. " "If the context isn't useful, return the original answer." ) -refine_prompt = Prompt( - input_variables=["query_str", "existing_answer", "context_msg"], - template=refine_str, -) +refine_prompt = PromptTemplate(template=refine_str) # train prompt with refined context @@ -174,6 +167,4 @@ train_prompt_with_context_str = ( "Survived: " ) -train_prompt_with_context = Prompt( - input_variables=["train_str", "eval_str"], template=train_prompt_with_context_str -) +train_prompt_with_context = PromptTemplate(template=train_prompt_with_context_str) diff --git a/llama_index/__init__.py b/llama_index/__init__.py index 27e426ff332d579a4b3247bf3ded7aa3f49dc717..19ff6e0d8933dbfc7681ceebf029adc099de11df 100644 --- a/llama_index/__init__.py +++ b/llama_index/__init__.py @@ -76,7 +76,14 @@ from llama_index.langchain_helpers.memory_wrapper import GPTIndexMemory from llama_index.langchain_helpers.sql_wrapper import SQLDatabase # prompts -from llama_index.prompts.base import Prompt +from llama_index.prompts import ( + BasePromptTemplate, + PromptTemplate, + ChatPromptTemplate, + SelectorPromptTemplate, + # backwards compatibility + Prompt, +) from llama_index.prompts.prompts import ( KeywordExtractPrompt, QueryKeywordExtractPrompt, @@ -168,6 +175,10 @@ __all__ = [ "GPTSQLStructStoreIndex", "GPTDocumentSummaryIndex", "Prompt", + "PromptTemplate", + "BasePromptTemplate", + "ChatPromptTemplate", + "SelectorPromptTemplate", "LangchainEmbedding", "OpenAIEmbedding", "SummaryPrompt", diff --git a/llama_index/agent/context_retriever_agent.py b/llama_index/agent/context_retriever_agent.py index 0bc92662aaaa1ac115cb5559793dbaa24e72d057..7c59adc4eef9a26a138e659ddaa21b8dee611f33 100644 --- a/llama_index/agent/context_retriever_agent.py +++ b/llama_index/agent/context_retriever_agent.py @@ -17,7 +17,7 @@ from llama_index.llms.base import LLM, ChatMessage from llama_index.llms.openai import OpenAI from llama_index.llms.openai_utils import is_function_calling_model from llama_index.memory import BaseMemory, ChatMemoryBuffer -from llama_index.prompts.prompts import QuestionAnswerPrompt +from llama_index.prompts import PromptTemplate from llama_index.schema import NodeWithScore from llama_index.tools import BaseTool @@ -30,7 +30,7 @@ DEFAULT_QA_PROMPT_TMPL = ( "Given the context information and not prior knowledge, " "either pick the corresponding tool or answer the function: {query_str}\n" ) -DEFAULT_QA_PROMPT = QuestionAnswerPrompt(DEFAULT_QA_PROMPT_TMPL) +DEFAULT_QA_PROMPT = PromptTemplate(DEFAULT_QA_PROMPT_TMPL) class ContextRetrieverOpenAIAgent(BaseOpenAIAgent): @@ -44,7 +44,7 @@ class ContextRetrieverOpenAIAgent(BaseOpenAIAgent): Args: tools (List[BaseTool]): A list of tools. retriever (BaseRetriever): A retriever. - qa_prompt (Optional[QuestionAnswerPrompt]): A QA prompt. + qa_prompt (Optional[PromptTemplate]): A QA prompt. context_separator (str): A context separator. llm (Optional[OpenAI]): An OpenAI LLM. chat_history (Optional[List[ChatMessage]]): A chat history. @@ -59,7 +59,7 @@ class ContextRetrieverOpenAIAgent(BaseOpenAIAgent): self, tools: List[BaseTool], retriever: BaseRetriever, - qa_prompt: QuestionAnswerPrompt, + qa_prompt: PromptTemplate, context_separator: str, llm: OpenAI, memory: BaseMemory, @@ -86,7 +86,7 @@ class ContextRetrieverOpenAIAgent(BaseOpenAIAgent): cls, tools: List[BaseTool], retriever: BaseRetriever, - qa_prompt: Optional[QuestionAnswerPrompt] = None, + qa_prompt: Optional[PromptTemplate] = None, context_separator: str = "\n", llm: Optional[LLM] = None, chat_history: Optional[List[ChatMessage]] = None, @@ -102,7 +102,7 @@ class ContextRetrieverOpenAIAgent(BaseOpenAIAgent): Args: retriever (BaseRetriever): A retriever. - qa_prompt (Optional[QuestionAnswerPrompt]): A QA prompt. + qa_prompt (Optional[PromptTemplate]): A QA prompt. context_separator (str): A context separator. llm (Optional[OpenAI]): An OpenAI LLM. chat_history (Optional[ChatMessageHistory]): A chat history. diff --git a/llama_index/chat_engine/condense_question.py b/llama_index/chat_engine/condense_question.py index e00cf0a9f936f47c8b0b0d82614c4aba72b4e525..79f5cd88c7fed235c075177a859227b60f7c5734 100644 --- a/llama_index/chat_engine/condense_question.py +++ b/llama_index/chat_engine/condense_question.py @@ -2,6 +2,7 @@ import logging from threading import Thread from typing import Any, List, Optional, Type +from llama_index.callbacks import CallbackManager, trace_method from llama_index.chat_engine.types import ( AgentChatResponse, BaseChatEngine, @@ -13,10 +14,9 @@ from llama_index.indices.service_context import ServiceContext from llama_index.llms.base import ChatMessage, MessageRole from llama_index.llms.generic_utils import messages_to_history_str from llama_index.memory import BaseMemory, ChatMemoryBuffer -from llama_index.prompts.base import Prompt +from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.response.schema import RESPONSE_TYPE, StreamingResponse from llama_index.tools import ToolOutput -from llama_index.callbacks import CallbackManager, trace_method logger = logging.getLogger(__name__) @@ -35,7 +35,7 @@ from the conversation. <Standalone question> """ -DEFAULT_PROMPT = Prompt(DEFAULT_TEMPLATE) +DEFAULT_PROMPT = PromptTemplate(DEFAULT_TEMPLATE) class CondenseQuestionChatEngine(BaseChatEngine): @@ -48,7 +48,7 @@ class CondenseQuestionChatEngine(BaseChatEngine): def __init__( self, query_engine: BaseQueryEngine, - condense_question_prompt: Prompt, + condense_question_prompt: BasePromptTemplate, memory: BaseMemory, service_context: ServiceContext, verbose: bool = False, @@ -65,7 +65,7 @@ class CondenseQuestionChatEngine(BaseChatEngine): def from_defaults( cls, query_engine: BaseQueryEngine, - condense_question_prompt: Optional[Prompt] = None, + condense_question_prompt: Optional[BasePromptTemplate] = None, chat_history: Optional[List[ChatMessage]] = None, memory: Optional[BaseMemory] = None, memory_cls: Type[BaseMemory] = ChatMemoryBuffer, diff --git a/llama_index/evaluation/base.py b/llama_index/evaluation/base.py index 85d0378cdde44f8ce043602eaab590f6646e1064..047fec057bd5f324127c58a3976d8e1c97fef54b 100644 --- a/llama_index/evaluation/base.py +++ b/llama_index/evaluation/base.py @@ -7,7 +7,7 @@ from typing import List, Optional from llama_index.indices.base import ServiceContext from llama_index.indices.list.base import ListIndex -from llama_index.prompts.prompts import QuestionAnswerPrompt, RefinePrompt +from llama_index.prompts import PromptTemplate from llama_index.schema import Document from llama_index.response.schema import Response @@ -155,8 +155,8 @@ class ResponseEvaluator: index = ListIndex.from_documents(context, service_context=self.service_context) response_txt = "" - EVAL_PROMPT_TMPL = QuestionAnswerPrompt(DEFAULT_EVAL_PROMPT) - REFINE_PROMPT_TMPL = RefinePrompt(DEFAULT_REFINE_PROMPT) + EVAL_PROMPT_TMPL = PromptTemplate(DEFAULT_EVAL_PROMPT) + REFINE_PROMPT_TMPL = PromptTemplate(DEFAULT_REFINE_PROMPT) query_engine = index.as_query_engine( text_qa_template=EVAL_PROMPT_TMPL, @@ -199,8 +199,8 @@ class ResponseEvaluator: ) response_txt = "" - EVAL_PROMPT_TMPL = QuestionAnswerPrompt(DEFAULT_EVAL_PROMPT) - REFINE_PROMPT_TMPL = RefinePrompt(DEFAULT_REFINE_PROMPT) + EVAL_PROMPT_TMPL = PromptTemplate(DEFAULT_EVAL_PROMPT) + REFINE_PROMPT_TMPL = PromptTemplate(DEFAULT_REFINE_PROMPT) query_engine = index.as_query_engine( text_qa_template=EVAL_PROMPT_TMPL, @@ -285,10 +285,8 @@ class QueryResponseEvaluator(BaseEvaluator): context = self.get_context(response) index = ListIndex.from_documents(context, service_context=self.service_context) - QUERY_RESPONSE_EVAL_PROMPT_TMPL = QuestionAnswerPrompt( - QUERY_RESPONSE_EVAL_PROMPT - ) - QUERY_RESPONSE_REFINE_PROMPT_TMPL = RefinePrompt(QUERY_RESPONSE_REFINE_PROMPT) + QUERY_RESPONSE_EVAL_PROMPT_TMPL = PromptTemplate(QUERY_RESPONSE_EVAL_PROMPT) + QUERY_RESPONSE_REFINE_PROMPT_TMPL = PromptTemplate(QUERY_RESPONSE_REFINE_PROMPT) query_response = f"Question: {query}\nResponse: {answer}" @@ -337,10 +335,8 @@ class QueryResponseEvaluator(BaseEvaluator): ) response_txt = "" - QUERY_RESPONSE_EVAL_PROMPT_TMPL = QuestionAnswerPrompt( - QUERY_RESPONSE_EVAL_PROMPT - ) - QUERY_RESPONSE_REFINE_PROMPT_TMPL = RefinePrompt( + QUERY_RESPONSE_EVAL_PROMPT_TMPL = PromptTemplate(QUERY_RESPONSE_EVAL_PROMPT) + QUERY_RESPONSE_REFINE_PROMPT_TMPL = PromptTemplate( QUERY_RESPONSE_REFINE_PROMPT ) diff --git a/llama_index/evaluation/dataset_generation.py b/llama_index/evaluation/dataset_generation.py index 566ca814909f93999b0fbdb916ceeb6ab509091d..195714d484b737691db9f080fbfd6d772a64367f 100644 --- a/llama_index/evaluation/dataset_generation.py +++ b/llama_index/evaluation/dataset_generation.py @@ -8,10 +8,10 @@ from typing import List, Optional from llama_index import ( Document, ListIndex, - QuestionAnswerPrompt, ServiceContext, ) from llama_index.llms.openai import OpenAI +from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.schema import BaseNode, NodeWithScore, MetadataMode from llama_index.indices.postprocessor.node import KeywordNodePostprocessor @@ -49,7 +49,7 @@ class DatasetGenerator: nodes: List[BaseNode], service_context: Optional[ServiceContext] = None, num_questions_per_chunk: int = 10, - text_question_template: Optional[QuestionAnswerPrompt] = None, + text_question_template: Optional[BasePromptTemplate] = None, question_gen_query: Optional[str] = None, required_keywords: Optional[List[str]] = None, exclude_keywords: Optional[List[str]] = None, @@ -58,7 +58,7 @@ class DatasetGenerator: if service_context is None: service_context = _get_default_service_context() self.service_context = service_context - self.text_question_template = text_question_template or QuestionAnswerPrompt( + self.text_question_template = text_question_template or PromptTemplate( DEFAULT_QUESTION_GENERATION_PROMPT ) self.question_gen_query = ( @@ -77,7 +77,7 @@ class DatasetGenerator: documents: List[Document], service_context: Optional[ServiceContext] = None, num_questions_per_chunk: int = 10, - text_question_template: Optional[QuestionAnswerPrompt] = None, + text_question_template: Optional[BasePromptTemplate] = None, question_gen_query: Optional[str] = None, required_keywords: Optional[List[str]] = None, exclude_keywords: Optional[List[str]] = None, diff --git a/llama_index/evaluation/guideline_eval.py b/llama_index/evaluation/guideline_eval.py index 8142efdc94c187ee8d421ce0136d5acdb5b83665..0e378133b8fcfbb055a4a814da1b914d6dbd59a7 100644 --- a/llama_index/evaluation/guideline_eval.py +++ b/llama_index/evaluation/guideline_eval.py @@ -1,12 +1,12 @@ import logging from typing import Optional -from llama_index.bridge.langchain import PydanticOutputParser from pydantic import BaseModel, Field +from llama_index.bridge.langchain import PydanticOutputParser from llama_index.evaluation.base import BaseEvaluator, Evaluation from llama_index.indices.base import ServiceContext -from llama_index.prompts.base import Prompt +from llama_index.prompts.base import PromptTemplate from llama_index.response.schema import Response logger = logging.getLogger(__name__) @@ -38,7 +38,7 @@ class GuidelineEvaluator(BaseEvaluator): ) format_instructions = parser.get_format_instructions() response_str = response.response - prompt = Prompt(self.eval_template) + prompt = PromptTemplate(self.eval_template) logger.debug("prompt: %s", prompt) logger.debug("query: %s", query) logger.debug("response: %s", response_str) diff --git a/llama_index/indices/common/struct_store/base.py b/llama_index/indices/common/struct_store/base.py index beff18c0341cb66711fdcbe2d39b71dd53555920..9afa5b187f72c8817bce59117aa93651d637174d 100644 --- a/llama_index/indices/common/struct_store/base.py +++ b/llama_index/indices/common/struct_store/base.py @@ -9,6 +9,7 @@ from llama_index.data_structs.table import StructDatapoint from llama_index.indices.service_context import ServiceContext from llama_index.langchain_helpers.sql_wrapper import SQLDatabase from llama_index.llm_predictor.base import BaseLLMPredictor +from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompt_selectors import ( DEFAULT_REFINE_TABLE_CONTEXT_PROMPT_SEL, ) @@ -17,13 +18,6 @@ from llama_index.prompts.default_prompts import ( DEFAULT_TABLE_CONTEXT_QUERY, ) from llama_index.prompts.prompt_type import PromptType -from llama_index.prompts.prompts import ( - QuestionAnswerPrompt, - RefinePrompt, - RefineTableContextPrompt, - SchemaExtractPrompt, - TableContextPrompt, -) from llama_index.response_synthesizers import get_response_synthesizer from llama_index.schema import BaseNode, MetadataMode from llama_index.text_splitter import TextSplitter @@ -40,9 +34,9 @@ class SQLDocumentContextBuilder: llm_predictor (Optional[BaseLLMPredictor]): LLM Predictor to use. prompt_helper (Optional[PromptHelper]): Prompt Helper to use. text_splitter (Optional[TextSplitter]): Text Splitter to use. - table_context_prompt (Optional[TableContextPrompt]): A + table_context_prompt (Optional[BasePromptTemplate]): A Table Context Prompt (see :ref:`Prompt-Templates`). - refine_table_context_prompt (Optional[RefineTableContextPrompt]): + refine_table_context_prompt (Optional[BasePromptTemplate]): A Refine Table Context Prompt (see :ref:`Prompt-Templates`). table_context_task (Optional[str]): The query to perform on the table context. A default query string is used @@ -54,8 +48,8 @@ class SQLDocumentContextBuilder: sql_database: SQLDatabase, service_context: Optional[ServiceContext] = None, text_splitter: Optional[TextSplitter] = None, - table_context_prompt: Optional[TableContextPrompt] = None, - refine_table_context_prompt: Optional[RefineTableContextPrompt] = None, + table_context_prompt: Optional[BasePromptTemplate] = None, + refine_table_context_prompt: Optional[BasePromptTemplate] = None, table_context_task: Optional[str] = None, ) -> None: """Initialize params.""" @@ -92,14 +86,13 @@ class SQLDocumentContextBuilder: ) -> str: """Build context from documents for a single table.""" schema = self._sql_database.get_single_table_info(table_name) - prompt_with_schema = QuestionAnswerPrompt.from_prompt( - self._table_context_prompt.partial_format(schema=schema), - prompt_type=PromptType.QUESTION_ANSWER, - ) - refine_prompt_with_schema = RefinePrompt.from_prompt( - self._refine_table_context_prompt.partial_format(schema=schema), - prompt_type=PromptType.REFINE, + prompt_with_schema = self._table_context_prompt.partial_format(schema=schema) + prompt_with_schema.metadata["prompt_type"] = PromptType.QUESTION_ANSWER + refine_prompt_with_schema = self._refine_table_context_prompt.partial_format( + schema=schema ) + refine_prompt_with_schema.metadata["prompt_type"] = PromptType.REFINE + text_splitter = ( self._text_splitter or self._service_context.prompt_helper.get_text_splitter_given_prompt( @@ -143,7 +136,7 @@ class BaseStructDatapointExtractor: def __init__( self, llm_predictor: BaseLLMPredictor, - schema_extract_prompt: SchemaExtractPrompt, + schema_extract_prompt: BasePromptTemplate, output_parser: OUTPUT_PARSER_TYPE, ) -> None: """Initialize params.""" diff --git a/llama_index/indices/common/struct_store/sql.py b/llama_index/indices/common/struct_store/sql.py index a7524d231e947f8f8a3bccbfebb4e95a5f7c091e..5ec501fcd5ad9279cc8f746745b3acab8592537d 100644 --- a/llama_index/indices/common/struct_store/sql.py +++ b/llama_index/indices/common/struct_store/sql.py @@ -11,7 +11,7 @@ from llama_index.indices.common.struct_store.base import ( ) from llama_index.langchain_helpers.sql_wrapper import SQLDatabase from llama_index.llm_predictor.base import BaseLLMPredictor -from llama_index.prompts.prompts import SchemaExtractPrompt +from llama_index.prompts import BasePromptTemplate class SQLStructDatapointExtractor(BaseStructDatapointExtractor): @@ -20,7 +20,7 @@ class SQLStructDatapointExtractor(BaseStructDatapointExtractor): def __init__( self, llm_predictor: BaseLLMPredictor, - schema_extract_prompt: SchemaExtractPrompt, + schema_extract_prompt: BasePromptTemplate, output_parser: OUTPUT_PARSER_TYPE, sql_database: SQLDatabase, table_name: Optional[str] = None, diff --git a/llama_index/indices/common_tree/base.py b/llama_index/indices/common_tree/base.py index 454327ecd5288b09051eb74477c1ca77f672c430..e59a27673b2cb64d6d1b3b936d18814bd01369a0 100644 --- a/llama_index/indices/common_tree/base.py +++ b/llama_index/indices/common_tree/base.py @@ -8,14 +8,13 @@ from typing import Dict, List, Optional, Sequence, Tuple from llama_index.async_utils import run_async_tasks from llama_index.callbacks.schema import CBEventType, EventPayload from llama_index.data_structs.data_structs import IndexGraph -from llama_index.schema import BaseNode, TextNode -from llama_index.storage.docstore import BaseDocumentStore -from llama_index.storage.docstore.registry import get_default_docstore from llama_index.indices.service_context import ServiceContext from llama_index.indices.utils import get_sorted_node_list, truncate_text -from llama_index.prompts.prompts import SummaryPrompt +from llama_index.prompts import BasePromptTemplate +from llama_index.schema import BaseNode, MetadataMode, TextNode +from llama_index.storage.docstore import BaseDocumentStore +from llama_index.storage.docstore.registry import get_default_docstore from llama_index.utils import get_tqdm_iterable -from llama_index.schema import MetadataMode logger = logging.getLogger(__name__) @@ -31,7 +30,7 @@ class GPTTreeIndexBuilder: def __init__( self, num_children: int, - summary_prompt: SummaryPrompt, + summary_prompt: BasePromptTemplate, service_context: ServiceContext, docstore: Optional[BaseDocumentStore] = None, show_progress: bool = False, diff --git a/llama_index/indices/document_summary/base.py b/llama_index/indices/document_summary/base.py index cd6d973b067b5f5e32ee70f2ca136d2e53e71ae0..d6db59ee50b925057d3ff4c559cb2a6b21463a39 100644 --- a/llama_index/indices/document_summary/base.py +++ b/llama_index/indices/document_summary/base.py @@ -9,7 +9,6 @@ import logging from collections import defaultdict from enum import Enum from typing import Any, Dict, Optional, Sequence, Union, cast -from llama_index.utils import get_tqdm_iterable from llama_index.data_structs.document_summary import IndexDocumentSummary from llama_index.indices.base import BaseIndex @@ -18,17 +17,18 @@ from llama_index.indices.service_context import ServiceContext from llama_index.response.schema import Response from llama_index.response_synthesizers import ( BaseSynthesizer, - get_response_synthesizer, ResponseMode, + get_response_synthesizer, ) from llama_index.schema import ( BaseNode, - NodeWithScore, NodeRelationship, + NodeWithScore, RelatedNodeInfo, TextNode, ) from llama_index.storage.docstore.types import RefDocInfo +from llama_index.utils import get_tqdm_iterable logger = logging.getLogger(__name__) @@ -51,7 +51,7 @@ class DocumentSummaryIndex(BaseIndex[IndexDocumentSummary]): """Document Summary Index. Args: - summary_template (Optional[SummaryPrompt]): A Summary Prompt + summary_template (Optional[BasePromptTemplate]): A Summary Prompt (see :ref:`Prompt-Templates`). show_progress (bool): Whether to show tqdm progress bars. Defaults to False. diff --git a/llama_index/indices/document_summary/retrievers.py b/llama_index/indices/document_summary/retrievers.py index c0e4a21b654b1da1f188a691d5a6dbcfd3d38516..8e729e901fc9d1e186434cb6bebe26c4f0059ae6 100644 --- a/llama_index/indices/document_summary/retrievers.py +++ b/llama_index/indices/document_summary/retrievers.py @@ -16,9 +16,9 @@ from llama_index.indices.utils import ( default_format_node_batch_fn, default_parse_choice_select_answer_fn, ) -from llama_index.prompts.choice_select import ( +from llama_index.prompts import PromptTemplate +from llama_index.prompts.default_prompts import ( DEFAULT_CHOICE_SELECT_PROMPT, - ChoiceSelectPrompt, ) from llama_index.schema import BaseNode, MetadataMode, NodeWithScore @@ -38,7 +38,7 @@ class DocumentSummaryIndexRetriever(BaseRetriever): def __init__( self, index: DocumentSummaryIndex, - choice_select_prompt: Optional[ChoiceSelectPrompt] = None, + choice_select_prompt: Optional[PromptTemplate] = None, choice_batch_size: int = 10, format_node_batch_fn: Optional[Callable] = None, parse_choice_select_answer_fn: Optional[Callable] = None, diff --git a/llama_index/indices/empty/retrievers.py b/llama_index/indices/empty/retrievers.py index 5c2cb66525255dd8e3d16e78ed5226da2c1ca9d3..f4c09a76315156c333749af33b956e5448f03134 100644 --- a/llama_index/indices/empty/retrievers.py +++ b/llama_index/indices/empty/retrievers.py @@ -4,8 +4,8 @@ from typing import Any, List, Optional from llama_index.indices.base_retriever import BaseRetriever from llama_index.indices.empty.base import EmptyIndex from llama_index.indices.query.schema import QueryBundle +from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_SIMPLE_INPUT_PROMPT -from llama_index.prompts.prompts import SimpleInputPrompt from llama_index.schema import NodeWithScore @@ -15,7 +15,7 @@ class EmptyIndexRetriever(BaseRetriever): Passes the raw LLM call to the underlying LLM model. Args: - input_prompt (Optional[SimpleInputPrompt]): A Simple Input Prompt + input_prompt (Optional[BasePromptTemplate]): A Simple Input Prompt (see :ref:`Prompt-Templates`). """ @@ -23,7 +23,7 @@ class EmptyIndexRetriever(BaseRetriever): def __init__( self, index: EmptyIndex, - input_prompt: Optional[SimpleInputPrompt] = None, + input_prompt: Optional[BasePromptTemplate] = None, **kwargs: Any, ) -> None: """Initialize params.""" diff --git a/llama_index/indices/keyword_table/base.py b/llama_index/indices/keyword_table/base.py index f8ec8a9d103d5b33fe3324fe3eadecf4675810aa..2655b69c0de7f57a0fefa2408b27d35dc36e9917 100644 --- a/llama_index/indices/keyword_table/base.py +++ b/llama_index/indices/keyword_table/base.py @@ -11,7 +11,6 @@ existing keywords in the table. from abc import abstractmethod from enum import Enum from typing import Any, Dict, Optional, Sequence, Set, Union -from llama_index.utils import get_tqdm_iterable from llama_index.async_utils import run_async_tasks from llama_index.data_structs.data_structs import KeywordTable @@ -19,13 +18,14 @@ from llama_index.indices.base import BaseIndex from llama_index.indices.base_retriever import BaseRetriever from llama_index.indices.keyword_table.utils import extract_keywords_given_response from llama_index.indices.service_context import ServiceContext +from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import ( DEFAULT_KEYWORD_EXTRACT_TEMPLATE, DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE, ) -from llama_index.prompts.prompts import KeywordExtractPrompt from llama_index.schema import BaseNode, MetadataMode from llama_index.storage.docstore.types import RefDocInfo +from llama_index.utils import get_tqdm_iterable DQKET = DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE @@ -49,7 +49,7 @@ class BaseKeywordTableIndex(BaseIndex[KeywordTable]): are then used to answer the query. Args: - keyword_extract_template (Optional[KeywordExtractPrompt]): A Keyword + keyword_extract_template (Optional[BasePromptTemplate]): A Keyword Extraction Prompt (see :ref:`Prompt-Templates`). use_async (bool): Whether to use asynchronous calls. Defaults to False. @@ -64,7 +64,7 @@ class BaseKeywordTableIndex(BaseIndex[KeywordTable]): nodes: Optional[Sequence[BaseNode]] = None, index_struct: Optional[KeywordTable] = None, service_context: Optional[ServiceContext] = None, - keyword_extract_template: Optional[KeywordExtractPrompt] = None, + keyword_extract_template: Optional[BasePromptTemplate] = None, max_keywords_per_chunk: int = 10, use_async: bool = False, show_progress: bool = False, diff --git a/llama_index/indices/keyword_table/retrievers.py b/llama_index/indices/keyword_table/retrievers.py index 07a7ea2c9ec32ba376deb217c8d896109f9855dd..2c03664b525cf3c68bf7bcacf1b4b72372ee712d 100644 --- a/llama_index/indices/keyword_table/retrievers.py +++ b/llama_index/indices/keyword_table/retrievers.py @@ -12,11 +12,11 @@ from llama_index.indices.keyword_table.utils import ( simple_extract_keywords, ) from llama_index.indices.query.schema import QueryBundle +from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import ( DEFAULT_KEYWORD_EXTRACT_TEMPLATE, DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE, ) -from llama_index.prompts.prompts import KeywordExtractPrompt, QueryKeywordExtractPrompt from llama_index.schema import NodeWithScore from llama_index.utils import truncate_text @@ -31,15 +31,15 @@ class BaseKeywordTableRetriever(BaseRetriever): Arguments are shared among subclasses. Args: - keyword_extract_template (Optional[KeywordExtractPrompt]): A Keyword + keyword_extract_template (Optional[BasePromptTemplate]): A Keyword Extraction Prompt (see :ref:`Prompt-Templates`). - query_keyword_extract_template (Optional[QueryKeywordExtractPrompt]): A Query + query_keyword_extract_template (Optional[BasePromptTemplate]): A Query Keyword Extraction Prompt (see :ref:`Prompt-Templates`). - refine_template (Optional[RefinePrompt]): A Refinement Prompt + refine_template (Optional[BasePromptTemplate]): A Refinement Prompt (see :ref:`Prompt-Templates`). - text_qa_template (Optional[QuestionAnswerPrompt]): A Question Answering Prompt + text_qa_template (Optional[BasePromptTemplate]): A Question Answering Prompt (see :ref:`Prompt-Templates`). max_keywords_per_query (int): Maximum number of keywords to extract from query. num_chunks_per_query (int): Maximum number of text chunks to query. @@ -49,8 +49,8 @@ class BaseKeywordTableRetriever(BaseRetriever): def __init__( self, index: BaseKeywordTableIndex, - keyword_extract_template: Optional[KeywordExtractPrompt] = None, - query_keyword_extract_template: Optional[QueryKeywordExtractPrompt] = None, + keyword_extract_template: Optional[BasePromptTemplate] = None, + query_keyword_extract_template: Optional[BasePromptTemplate] = None, max_keywords_per_query: int = 10, num_chunks_per_query: int = 10, **kwargs: Any, diff --git a/llama_index/indices/knowledge_graph/base.py b/llama_index/indices/knowledge_graph/base.py index e807b22b749346492e034748f3c457aec9451500..7e1a0fae5e83f0c5485486652cc8c8538e043e21 100644 --- a/llama_index/indices/knowledge_graph/base.py +++ b/llama_index/indices/knowledge_graph/base.py @@ -18,8 +18,8 @@ from llama_index.graph_stores.types import GraphStore from llama_index.indices.base import BaseIndex from llama_index.indices.base_retriever import BaseRetriever from llama_index.indices.service_context import ServiceContext +from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_KG_TRIPLET_EXTRACT_PROMPT -from llama_index.prompts.prompts import KnowledgeGraphPrompt from llama_index.schema import BaseNode, MetadataMode from llama_index.storage.docstore.types import RefDocInfo from llama_index.storage.storage_context import StorageContext @@ -34,7 +34,7 @@ class KnowledgeGraphIndex(BaseIndex[KG]): Build a KG by extracting triplets, and leveraging the KG during query-time. Args: - kg_triple_extract_template (KnowledgeGraphPrompt): The prompt to use for + kg_triple_extract_template (BasePromptTemplate): The prompt to use for extracting triplets. max_triplets_per_chunk (int): The maximum number of triplets to extract. service_context (Optional[ServiceContext]): The service context to use. @@ -58,7 +58,7 @@ class KnowledgeGraphIndex(BaseIndex[KG]): index_struct: Optional[KG] = None, service_context: Optional[ServiceContext] = None, storage_context: Optional[StorageContext] = None, - kg_triple_extract_template: Optional[KnowledgeGraphPrompt] = None, + kg_triple_extract_template: Optional[BasePromptTemplate] = None, max_triplets_per_chunk: int = 10, include_embeddings: bool = False, show_progress: bool = False, diff --git a/llama_index/indices/knowledge_graph/retrievers.py b/llama_index/indices/knowledge_graph/retrievers.py index a060d66222d49cb65e93b6a79a9ec950e343b89a..fa8205ccc820fe8765a1ba36cef96d89e70339e5 100644 --- a/llama_index/indices/knowledge_graph/retrievers.py +++ b/llama_index/indices/knowledge_graph/retrievers.py @@ -11,9 +11,8 @@ from llama_index.indices.knowledge_graph.base import KnowledgeGraphIndex from llama_index.indices.query.embedding_utils import get_top_k_embeddings from llama_index.indices.query.schema import QueryBundle from llama_index.indices.service_context import ServiceContext -from llama_index.prompts.base import Prompt, PromptType +from llama_index.prompts import BasePromptTemplate, PromptTemplate, PromptType from llama_index.prompts.default_prompts import DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE -from llama_index.prompts.prompts import QueryKeywordExtractPrompt from llama_index.schema import BaseNode, MetadataMode, NodeWithScore, TextNode from llama_index.storage.storage_context import StorageContext from llama_index.utils import truncate_text @@ -53,9 +52,9 @@ class KGTableRetriever(BaseRetriever): query_keyword_extract_template (Optional[QueryKGExtractPrompt]): A Query KG Extraction Prompt (see :ref:`Prompt-Templates`). - refine_template (Optional[RefinePrompt]): A Refinement Prompt + refine_template (Optional[BasePromptTemplate]): A Refinement Prompt (see :ref:`Prompt-Templates`). - text_qa_template (Optional[QuestionAnswerPrompt]): A Question Answering Prompt + text_qa_template (Optional[BasePromptTemplate]): A Question Answering Prompt (see :ref:`Prompt-Templates`). max_keywords_per_query (int): Maximum number of keywords to extract from query. num_chunks_per_query (int): Maximum number of text chunks to query. @@ -77,7 +76,7 @@ class KGTableRetriever(BaseRetriever): def __init__( self, index: KnowledgeGraphIndex, - query_keyword_extract_template: Optional[QueryKeywordExtractPrompt] = None, + query_keyword_extract_template: Optional[BasePromptTemplate] = None, max_keywords_per_query: int = 10, num_chunks_per_query: int = 10, include_text: bool = True, @@ -328,7 +327,7 @@ KEYWORDS: {question} ---- """ -DEFAULT_SYNONYM_EXPAND_PROMPT = Prompt( +DEFAULT_SYNONYM_EXPAND_PROMPT = PromptTemplate( DEFAULT_SYNONYM_EXPAND_TEMPLATE, prompt_type=PromptType.QUERY_KEYWORD_EXTRACT, ) @@ -344,7 +343,7 @@ class KnowledgeGraphRAGRetriever(BaseRetriever): service_context (Optional[ServiceContext]): A service context to use. storage_context (Optional[StorageContext]): A storage context to use. entity_extract_fn (Optional[Callable]): A function to extract entities. - entity_extract_template Optional[QueryKeywordExtractPrompt]): A Query Key Entity + entity_extract_template Optional[BasePromptTemplate]): A Query Key Entity Extraction Prompt (see :ref:`Prompt-Templates`). entity_extract_policy (Optional[str]): The entity extraction policy to use. default: "union" @@ -376,10 +375,10 @@ class KnowledgeGraphRAGRetriever(BaseRetriever): service_context: Optional[ServiceContext] = None, storage_context: Optional[StorageContext] = None, entity_extract_fn: Optional[Callable] = None, - entity_extract_template: Optional[QueryKeywordExtractPrompt] = None, + entity_extract_template: Optional[BasePromptTemplate] = None, entity_extract_policy: Optional[str] = "union", synonym_expand_fn: Optional[Callable] = None, - synonym_expand_template: Optional[Prompt] = None, + synonym_expand_template: Optional[BasePromptTemplate] = None, synonym_expand_policy: Optional[str] = "union", max_entities: int = 5, max_synonyms: int = 5, @@ -451,7 +450,7 @@ class KnowledgeGraphRAGRetriever(BaseRetriever): self, query_str: str, handle_fn: Optional[Callable], - handle_llm_prompt_template: Optional[Prompt], + handle_llm_prompt_template: Optional[BasePromptTemplate], cross_handle_policy: Optional[str] = "union", max_items: Optional[int] = 5, result_start_token: str = "KEYWORDS:", @@ -501,7 +500,7 @@ class KnowledgeGraphRAGRetriever(BaseRetriever): self, query_str: str, handle_fn: Optional[Callable], - handle_llm_prompt_template: Optional[Prompt], + handle_llm_prompt_template: Optional[BasePromptTemplate], cross_handle_policy: Optional[str] = "union", max_items: Optional[int] = 5, result_start_token: str = "KEYWORDS:", diff --git a/llama_index/indices/list/base.py b/llama_index/indices/list/base.py index a55270308a40d89e8b69871a43ad14b3642b2d50..1eb213cb6d1b34d24364068e9451177fbaa5c43c 100644 --- a/llama_index/indices/list/base.py +++ b/llama_index/indices/list/base.py @@ -7,7 +7,6 @@ in sequence in order to answer a given query. from enum import Enum from typing import Any, Dict, Optional, Sequence, Union -from llama_index.utils import get_tqdm_iterable from llama_index.data_structs.data_structs import IndexList from llama_index.indices.base import BaseIndex @@ -15,6 +14,7 @@ from llama_index.indices.base_retriever import BaseRetriever from llama_index.indices.service_context import ServiceContext from llama_index.schema import BaseNode from llama_index.storage.docstore.types import RefDocInfo +from llama_index.utils import get_tqdm_iterable class ListRetrieverMode(str, Enum): @@ -35,7 +35,7 @@ class ListIndex(BaseIndex[IndexList]): answer from all the nodes. Args: - text_qa_template (Optional[QuestionAnswerPrompt]): A Question-Answer Prompt + text_qa_template (Optional[BasePromptTemplate]): A Question-Answer Prompt (see :ref:`Prompt-Templates`). NOTE: this is a deprecated field. show_progress (bool): Whether to show tqdm progress bars. Defaults to False. diff --git a/llama_index/indices/list/retrievers.py b/llama_index/indices/list/retrievers.py index 6c05987096b2e26a010b08e3bbe541937b2513b4..7d1eefcc7e9f7f5e8552421ea92119d0fa66dd74 100644 --- a/llama_index/indices/list/retrievers.py +++ b/llama_index/indices/list/retrievers.py @@ -11,9 +11,9 @@ from llama_index.indices.utils import ( default_format_node_batch_fn, default_parse_choice_select_answer_fn, ) -from llama_index.prompts.choice_select import ( +from llama_index.prompts import PromptTemplate +from llama_index.prompts.default_prompts import ( DEFAULT_CHOICE_SELECT_PROMPT, - ChoiceSelectPrompt, ) from llama_index.schema import BaseNode, NodeWithScore, MetadataMode @@ -123,7 +123,7 @@ class ListIndexLLMRetriever(BaseRetriever): Args: index (ListIndex): The index to retrieve from. - choice_select_prompt (Optional[ChoiceSelectPrompt]): A Choice-Select Prompt + choice_select_prompt (Optional[PromptTemplate]): A Choice-Select Prompt (see :ref:`Prompt-Templates`).) choice_batch_size (int): The number of nodes to query at a time. format_node_batch_fn (Optional[Callable]): A function that formats a @@ -137,7 +137,7 @@ class ListIndexLLMRetriever(BaseRetriever): def __init__( self, index: ListIndex, - choice_select_prompt: Optional[ChoiceSelectPrompt] = None, + choice_select_prompt: Optional[PromptTemplate] = None, choice_batch_size: int = 10, format_node_batch_fn: Optional[Callable] = None, parse_choice_select_answer_fn: Optional[Callable] = None, diff --git a/llama_index/indices/postprocessor/llm_rerank.py b/llama_index/indices/postprocessor/llm_rerank.py index 6343c455e5d1f18c0858584193d074ab9b58f29e..d22b98aa0eacc2701840c33fee79352ad0049b4d 100644 --- a/llama_index/indices/postprocessor/llm_rerank.py +++ b/llama_index/indices/postprocessor/llm_rerank.py @@ -8,8 +8,8 @@ from llama_index.indices.utils import ( default_format_node_batch_fn, default_parse_choice_select_answer_fn, ) -from llama_index.prompts.choice_select import DEFAULT_CHOICE_SELECT_PROMPT -from llama_index.prompts.prompts import QuestionAnswerPrompt +from llama_index.prompts import BasePromptTemplate +from llama_index.prompts.default_prompts import DEFAULT_CHOICE_SELECT_PROMPT from llama_index.schema import NodeWithScore @@ -18,7 +18,7 @@ class LLMRerank(BaseNodePostprocessor): def __init__( self, - choice_select_prompt: Optional[QuestionAnswerPrompt] = None, + choice_select_prompt: Optional[BasePromptTemplate] = None, choice_batch_size: int = 10, format_node_batch_fn: Optional[Callable] = None, parse_choice_select_answer_fn: Optional[Callable] = None, diff --git a/llama_index/indices/postprocessor/node.py b/llama_index/indices/postprocessor/node.py index 2a55fa73ac5ccb960242b0c6d3af7ec727a61293..683f6e5a00f102a09ea4be66765e4311775c1eaf 100644 --- a/llama_index/indices/postprocessor/node.py +++ b/llama_index/indices/postprocessor/node.py @@ -10,7 +10,7 @@ from pydantic import BaseModel, Field, validator from llama_index.indices.postprocessor.types import BaseNodePostprocessor from llama_index.indices.query.schema import QueryBundle from llama_index.indices.service_context import ServiceContext -from llama_index.prompts.prompts import QuestionAnswerPrompt, RefinePrompt +from llama_index.prompts.base import PromptTemplate from llama_index.response_synthesizers import ResponseMode, get_response_synthesizer from llama_index.schema import NodeRelationship, NodeWithScore from llama_index.storage.docstore import BaseDocumentStore @@ -310,10 +310,10 @@ class AutoPrevNextNodePostprocessor(BasePydanticNodePostprocessor): if query_bundle is None: raise ValueError("Missing query bundle.") - infer_prev_next_prompt = QuestionAnswerPrompt( + infer_prev_next_prompt = PromptTemplate( self.infer_prev_next_tmpl, ) - refine_infer_prev_next_prompt = RefinePrompt(self.refine_prev_next_tmpl) + refine_infer_prev_next_prompt = PromptTemplate(self.refine_prev_next_tmpl) all_nodes: Dict[str, NodeWithScore] = {} for node in nodes: diff --git a/llama_index/indices/postprocessor/node_recency.py b/llama_index/indices/postprocessor/node_recency.py index 60db27f2cbdf67c2522182d519f40dbfe67afc69..dd18fa23f3571e1a81886f409345cce4347aa02f 100644 --- a/llama_index/indices/postprocessor/node_recency.py +++ b/llama_index/indices/postprocessor/node_recency.py @@ -1,15 +1,15 @@ """Node recency post-processor.""" -from pydantic import Field -from typing import Optional, List, Set -import pandas as pd -import numpy as np from datetime import datetime +from typing import List, Optional, Set + +import numpy as np +import pandas as pd +from pydantic import Field from llama_index.indices.postprocessor.node import BasePydanticNodePostprocessor from llama_index.indices.query.schema import QueryBundle from llama_index.indices.service_context import ServiceContext -from llama_index.schema import NodeWithScore, MetadataMode - +from llama_index.schema import MetadataMode, NodeWithScore # NOTE: currently not being used # DEFAULT_INFER_RECENCY_TMPL = ( @@ -67,17 +67,6 @@ class FixedRecencyPostprocessor(BasePydanticNodePostprocessor): if query_bundle is None: raise ValueError("Missing query bundle in extra info.") - # query_bundle = cast(QueryBundle, metadata["query_bundle"]) - # infer_recency_prompt = SimpleInputPrompt(self.infer_recency_tmpl) - # raw_pred = self.service_context.llm_predictor.predict( - # prompt=infer_recency_prompt, - # query_str=query_bundle.query_str, - # ) - # pred = parse_recency_pred(raw_pred) - # # if no need to use recency post-processor, return nodes as is - # if not pred: - # return nodes - # sort nodes by date node_dates = pd.to_datetime( [node.node.metadata[self.date_key] for node in nodes] @@ -130,17 +119,6 @@ class EmbeddingRecencyPostprocessor(BasePydanticNodePostprocessor): if query_bundle is None: raise ValueError("Missing query bundle in extra info.") - # query_bundle = cast(QueryBundle, metadata["query_bundle"]) - # infer_recency_prompt = SimpleInputPrompt(self.infer_recency_tmpl) - # raw_pred = self.service_context.llm_predictor.predict( - # prompt=infer_recency_prompt, - # query_str=query_bundle.query_str, - # ) - # pred = parse_recency_pred(raw_pred) - # # if no need to use recency post-processor, return nodes as is - # if not pred: - # return nodes - # sort nodes by date node_dates = pd.to_datetime( [node.node.metadata[self.date_key] for node in nodes] diff --git a/llama_index/indices/postprocessor/pii.py b/llama_index/indices/postprocessor/pii.py index 152932af3b4137efb8acc892dfdc81b9ecb13c6d..adadfacfb48049b2f520afa2056c4aaa2fbe1484 100644 --- a/llama_index/indices/postprocessor/pii.py +++ b/llama_index/indices/postprocessor/pii.py @@ -6,7 +6,7 @@ from typing import List, Optional, Dict, Tuple, Callable from llama_index.indices.postprocessor.node import BasePydanticNodePostprocessor from llama_index.indices.query.schema import QueryBundle from llama_index.indices.service_context import ServiceContext -from llama_index.prompts.prompts import QuestionAnswerPrompt +from llama_index.prompts.base import PromptTemplate from llama_index.schema import NodeWithScore, MetadataMode @@ -56,7 +56,7 @@ class PIINodePostprocessor(BasePydanticNodePostprocessor): def mask_pii(self, text: str) -> Tuple[str, Dict]: """Mask PII in text.""" - pii_prompt = QuestionAnswerPrompt(self.pii_str_tmpl) + pii_prompt = PromptTemplate(self.pii_str_tmpl) # TODO: allow customization task_str = ( "Mask out the PII, replace each PII with a tag, and return the text. " diff --git a/llama_index/indices/prompt_helper.py b/llama_index/indices/prompt_helper.py index 057c18910e2fbed77bb5007dfc3c88ebdc0a487d..4aee137e4b70b8dc09ceed2f087829892f5b2d8c 100644 --- a/llama_index/indices/prompt_helper.py +++ b/llama_index/indices/prompt_helper.py @@ -12,11 +12,12 @@ import logging from pydantic import BaseModel, Field, PrivateAttr from typing import Callable, List, Optional, Sequence +from llama_index.prompts import BasePromptTemplate + from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS from llama_index.llms.openai_utils import is_chat_model from llama_index.llm_predictor.base import LLMMetadata -from llama_index.prompts.base import Prompt -from llama_index.prompts.utils import get_empty_prompt_txt +from llama_index.prompts.prompt_utils import get_empty_prompt_txt from llama_index.text_splitter import TokenTextSplitter from llama_index.text_splitter.utils import truncate_text from llama_index.utils import globals_helper @@ -127,7 +128,7 @@ class PromptHelper(BaseModel): separator=separator, ) - def _get_available_context_size(self, prompt: Prompt) -> int: + def _get_available_context_size(self, prompt: BasePromptTemplate) -> int: """Get available context size. This is calculated as: @@ -142,7 +143,7 @@ class PromptHelper(BaseModel): return self.context_window - num_prompt_tokens - self.num_output def _get_available_chunk_size( - self, prompt: Prompt, num_chunks: int = 1, padding: int = 5 + self, prompt: BasePromptTemplate, num_chunks: int = 1, padding: int = 5 ) -> int: """Get available chunk size. @@ -169,7 +170,10 @@ class PromptHelper(BaseModel): return result def get_text_splitter_given_prompt( - self, prompt: Prompt, num_chunks: int = 1, padding: int = DEFAULT_PADDING + self, + prompt: BasePromptTemplate, + num_chunks: int = 1, + padding: int = DEFAULT_PADDING, ) -> TokenTextSplitter: """Get text splitter configured to maximally pack available context window, taking into account of given prompt, and desired number of chunks. @@ -187,7 +191,10 @@ class PromptHelper(BaseModel): return text_splitter def truncate( - self, prompt: Prompt, text_chunks: Sequence[str], padding: int = DEFAULT_PADDING + self, + prompt: BasePromptTemplate, + text_chunks: Sequence[str], + padding: int = DEFAULT_PADDING, ) -> List[str]: """Truncate text chunks to fit available context window.""" text_splitter = self.get_text_splitter_given_prompt( @@ -198,7 +205,10 @@ class PromptHelper(BaseModel): return [truncate_text(chunk, text_splitter) for chunk in text_chunks] def repack( - self, prompt: Prompt, text_chunks: Sequence[str], padding: int = DEFAULT_PADDING + self, + prompt: BasePromptTemplate, + text_chunks: Sequence[str], + padding: int = DEFAULT_PADDING, ) -> List[str]: """Repack text chunks to fit available context window. diff --git a/llama_index/indices/query/query_transform/base.py b/llama_index/indices/query/query_transform/base.py index d643bf7909cf225a623ca77e714cb45f6a5d6b23..26779911680c52d25e9b35dc3db28fa2b529cdd6 100644 --- a/llama_index/indices/query/query_transform/base.py +++ b/llama_index/indices/query/query_transform/base.py @@ -17,7 +17,7 @@ from llama_index.indices.query.query_transform.prompts import ( from llama_index.indices.query.schema import QueryBundle, QueryType from llama_index.llm_predictor import LLMPredictor from llama_index.llm_predictor.base import BaseLLMPredictor -from llama_index.prompts.base import Prompt +from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_HYDE_PROMPT from llama_index.response.schema import Response @@ -87,7 +87,7 @@ class HyDEQueryTransform(BaseQueryTransform): def __init__( self, llm_predictor: Optional[BaseLLMPredictor] = None, - hyde_prompt: Optional[Prompt] = None, + hyde_prompt: Optional[BasePromptTemplate] = None, include_original: bool = True, ) -> None: """Initialize HyDEQueryTransform. @@ -95,7 +95,7 @@ class HyDEQueryTransform(BaseQueryTransform): Args: llm_predictor (Optional[LLMPredictor]): LLM for generating hypothetical documents - hyde_prompt (Optional[Prompt]): Custom prompt for HyDE + hyde_prompt (Optional[BasePromptTemplate]): Custom prompt for HyDE include_original (bool): Whether to include original query string as one of the embedding strings """ diff --git a/llama_index/indices/query/query_transform/feedback_transform.py b/llama_index/indices/query/query_transform/feedback_transform.py index 4ba335c41d15b4b57d373a31d655bb5d32943f74..0874ec9170bed1c27877f66e4fa99947337214bd 100644 --- a/llama_index/indices/query/query_transform/feedback_transform.py +++ b/llama_index/indices/query/query_transform/feedback_transform.py @@ -6,7 +6,7 @@ from llama_index.indices.query.query_transform.base import BaseQueryTransform from llama_index.indices.query.schema import QueryBundle from llama_index.llm_predictor import LLMPredictor from llama_index.llm_predictor.base import BaseLLMPredictor -from llama_index.prompts.base import Prompt +from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.response.schema import Response logger = logging.getLogger(__name__) @@ -22,7 +22,7 @@ DEFAULT_RESYNTHESIS_PROMPT_TMPL = ( "Otherwise, please return the original query.\n" ) -DEFAULT_RESYNTHESIS_PROMPT = Prompt(DEFAULT_RESYNTHESIS_PROMPT_TMPL) +DEFAULT_RESYNTHESIS_PROMPT = PromptTemplate(DEFAULT_RESYNTHESIS_PROMPT_TMPL) class FeedbackQueryTransformation(BaseQueryTransform): @@ -32,7 +32,7 @@ class FeedbackQueryTransformation(BaseQueryTransform): eval(Evaluation): An evaluation object. llm_predictor(BaseLLMPredictor): An LLM predictor. resynthesize_query(bool): Whether to resynthesize the query. - resynthesis_prompt(Prompt): A prompt for resynthesizing the query. + resynthesis_prompt(BasePromptTemplate): A prompt for resynthesizing the query. """ @@ -40,7 +40,7 @@ class FeedbackQueryTransformation(BaseQueryTransform): self, llm_predictor: Optional[BaseLLMPredictor] = None, resynthesize_query: bool = False, - resynthesis_prompt: Optional[Prompt] = None, + resynthesis_prompt: Optional[BasePromptTemplate] = None, ) -> None: super().__init__() self.llm_predictor = llm_predictor or LLMPredictor() diff --git a/llama_index/indices/query/query_transform/prompts.py b/llama_index/indices/query/query_transform/prompts.py index b9632226743a26cb006d68bfdd3d5ad15cfa84a2..fd65cb49c7d85a15e2c6a6ffabd8f3e4e77e2dd3 100644 --- a/llama_index/indices/query/query_transform/prompts.py +++ b/llama_index/indices/query/query_transform/prompts.py @@ -1,35 +1,35 @@ """Query transform prompts.""" -from llama_index.prompts.base import Prompt +from llama_index.prompts.base import PromptTemplate from llama_index.prompts.prompt_type import PromptType # deprecated, kept for backwards compatibility """Decompose prompt for query transformation. -Prompt to "decompose" a query into another query +PromptTemplate to "decompose" a query into another query given the existing context. Required template variables: `context_str`, `query_str` """ -DecomposeQueryTransformPrompt = Prompt +DecomposeQueryTransformPrompt = PromptTemplate """Step Decompose prompt for query transformation. -Prompt to "decompose" a query into another query +PromptTemplate to "decompose" a query into another query given the existing context + previous reasoning (the previous steps). Required template variables: `context_str`, `query_str`, `prev_reasoning` """ -StepDecomposeQueryTransformPrompt = Prompt +StepDecomposeQueryTransformPrompt = PromptTemplate """Image output prompt for query transformation. -Prompt to add instructions for formatting image output. +PromptTemplate to add instructions for formatting image output. Required template variables: `query_str`, `image_width` """ -ImageOutputQueryTransformPrompt = Prompt +ImageOutputQueryTransformPrompt = PromptTemplate DEFAULT_DECOMPOSE_QUERY_TRANSFORM_TMPL = ( @@ -59,7 +59,7 @@ DEFAULT_DECOMPOSE_QUERY_TRANSFORM_TMPL = ( "New question: " ) -DEFAULT_DECOMPOSE_QUERY_TRANSFORM_PROMPT = Prompt( +DEFAULT_DECOMPOSE_QUERY_TRANSFORM_PROMPT = PromptTemplate( DEFAULT_DECOMPOSE_QUERY_TRANSFORM_TMPL, prompt_type=PromptType.DECOMPOSE ) @@ -70,7 +70,7 @@ DEFAULT_IMAGE_OUTPUT_TMPL = ( 'e.g., <image src="data/img.jpg" width="{image_width}" />.' ) -DEFAULT_IMAGE_OUTPUT_PROMPT = Prompt(DEFAULT_IMAGE_OUTPUT_TMPL) +DEFAULT_IMAGE_OUTPUT_PROMPT = PromptTemplate(DEFAULT_IMAGE_OUTPUT_TMPL) DEFAULT_STEP_DECOMPOSE_QUERY_TRANSFORM_TMPL = ( @@ -125,6 +125,6 @@ DEFAULT_STEP_DECOMPOSE_QUERY_TRANSFORM_TMPL = ( "New question: " ) -DEFAULT_STEP_DECOMPOSE_QUERY_TRANSFORM_PROMPT = Prompt( +DEFAULT_STEP_DECOMPOSE_QUERY_TRANSFORM_PROMPT = PromptTemplate( DEFAULT_STEP_DECOMPOSE_QUERY_TRANSFORM_TMPL ) diff --git a/llama_index/indices/service_context.py b/llama_index/indices/service_context.py index 496a09077fb1302a60e62a9d41fc27ca28213c0e..6ae333bf241c009c588a76489d87d6cf3a1023a8 100644 --- a/llama_index/indices/service_context.py +++ b/llama_index/indices/service_context.py @@ -5,6 +5,7 @@ from typing import Optional import llama_index from llama_index.callbacks.base import CallbackManager from llama_index.embeddings.base import BaseEmbedding +from llama_index.embeddings.utils import EmbedType, resolve_embed_model from llama_index.indices.prompt_helper import PromptHelper from llama_index.llm_predictor import LLMPredictor from llama_index.llm_predictor.base import BaseLLMPredictor, LLMMetadata @@ -13,8 +14,7 @@ from llama_index.llms.utils import LLMType, resolve_llm from llama_index.logger import LlamaLogger from llama_index.node_parser.interface import NodeParser from llama_index.node_parser.simple import SimpleNodeParser -from llama_index.prompts.prompts import Prompt -from llama_index.embeddings.utils import resolve_embed_model, EmbedType +from llama_index.prompts.base import BasePromptTemplate logger = logging.getLogger(__name__) @@ -78,7 +78,7 @@ class ServiceContext: llama_logger: Optional[LlamaLogger] = None, callback_manager: Optional[CallbackManager] = None, system_prompt: Optional[str] = None, - query_wrapper_prompt: Optional[Prompt] = None, + query_wrapper_prompt: Optional[BasePromptTemplate] = None, # node parser kwargs chunk_size: Optional[int] = None, chunk_overlap: Optional[int] = None, @@ -106,7 +106,7 @@ class ServiceContext: callback_manager (Optional[CallbackManager]): CallbackManager system_prompt (Optional[str]): System-wide prompt to be prepended to all input prompts, used to guide system "decision making" - query_wrapper_prompt (Optional[SimpleInputPrompt]): A format to wrap + query_wrapper_prompt (Optional[BasePromptTemplate]): A format to wrap passed-in input queries. Deprecated Args: @@ -184,7 +184,7 @@ class ServiceContext: llama_logger: Optional[LlamaLogger] = None, callback_manager: Optional[CallbackManager] = None, system_prompt: Optional[str] = None, - query_wrapper_prompt: Optional[Prompt] = None, + query_wrapper_prompt: Optional[BasePromptTemplate] = None, # node parser kwargs chunk_size: Optional[int] = None, chunk_overlap: Optional[int] = None, diff --git a/llama_index/indices/struct_store/base.py b/llama_index/indices/struct_store/base.py index 7bda01f48eb326d4010dddbf36abbe7096b3074d..1657c4197ec97d498fd8a1eede79292757cf9dbd 100644 --- a/llama_index/indices/struct_store/base.py +++ b/llama_index/indices/struct_store/base.py @@ -6,8 +6,8 @@ from typing import Any, Callable, Dict, Generic, Optional, Sequence, TypeVar from llama_index.data_structs.table import BaseStructTable from llama_index.indices.base import BaseIndex from llama_index.indices.service_context import ServiceContext +from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_SCHEMA_EXTRACT_PROMPT -from llama_index.prompts.prompts import SchemaExtractPrompt from llama_index.schema import BaseNode from llama_index.storage.docstore.types import RefDocInfo @@ -44,7 +44,7 @@ class BaseStructStoreIndex(BaseIndex[BST], Generic[BST]): nodes: Optional[Sequence[BaseNode]] = None, index_struct: Optional[BST] = None, service_context: Optional[ServiceContext] = None, - schema_extract_prompt: Optional[SchemaExtractPrompt] = None, + schema_extract_prompt: Optional[BasePromptTemplate] = None, output_parser: Optional[OUTPUT_PARSER_TYPE] = None, **kwargs: Any, ) -> None: diff --git a/llama_index/indices/struct_store/json_query.py b/llama_index/indices/struct_store/json_query.py index 9d1486037c9e851c2ed740c24c76f3273cc1a217..0a8162e43d56f541fcf3dd38a220a3ed652e27c7 100644 --- a/llama_index/indices/struct_store/json_query.py +++ b/llama_index/indices/struct_store/json_query.py @@ -7,7 +7,7 @@ from llama_index.bridge.langchain import print_text from llama_index.indices.query.base import BaseQueryEngine from llama_index.indices.query.schema import QueryBundle from llama_index.indices.service_context import ServiceContext -from llama_index.prompts.base import Prompt +from llama_index.prompts import PromptTemplate, BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_JSON_PATH_PROMPT from llama_index.prompts.prompt_type import PromptType from llama_index.response.schema import Response @@ -31,7 +31,7 @@ DEFAULT_RESPONSE_SYNTHESIS_PROMPT_TMPL = ( "Query: {query_str}\n" "Response: " ) -DEFAULT_RESPONSE_SYNTHESIS_PROMPT = Prompt( +DEFAULT_RESPONSE_SYNTHESIS_PROMPT = PromptTemplate( DEFAULT_RESPONSE_SYNTHESIS_PROMPT_TMPL, prompt_type=PromptType.SQL_RESPONSE_SYNTHESIS, ) @@ -58,7 +58,7 @@ class JSONQueryEngine(BaseQueryEngine): json_value (JSONType): JSON value json_schema (JSONType): JSON schema service_context (ServiceContext): ServiceContext - json_path_prompt (Prompt): The JSON Path prompt to use. + json_path_prompt (BasePromptTemplate): The JSON Path prompt to use. output_processor (Callable): The output processor that executes the JSON Path query. output_kwargs (dict): Additional output processor kwargs for the @@ -71,11 +71,11 @@ class JSONQueryEngine(BaseQueryEngine): json_value: JSONType, json_schema: JSONType, service_context: ServiceContext, - json_path_prompt: Optional[Prompt] = None, + json_path_prompt: Optional[BasePromptTemplate] = None, output_processor: Optional[Callable] = None, output_kwargs: Optional[dict] = None, synthesize_response: bool = True, - response_synthesis_prompt: Optional[Prompt] = None, + response_synthesis_prompt: Optional[BasePromptTemplate] = None, verbose: bool = False, **kwargs: Any, ) -> None: diff --git a/llama_index/indices/struct_store/sql_query.py b/llama_index/indices/struct_store/sql_query.py index 89e8acaf00ec614075838e26e02f169d51ae5d59..33aefc4e300fceae6fe3c1baa491c771c66e3cd4 100644 --- a/llama_index/indices/struct_store/sql_query.py +++ b/llama_index/indices/struct_store/sql_query.py @@ -15,7 +15,7 @@ from llama_index.indices.struct_store.sql import SQLStructStoreIndex from llama_index.langchain_helpers.sql_wrapper import SQLDatabase from llama_index.objects.base import ObjectRetriever from llama_index.objects.table_node_mapping import SQLTableSchema -from llama_index.prompts.base import Prompt +from llama_index.prompts import BasePromptTemplate, PromptTemplate from llama_index.prompts.default_prompts import DEFAULT_TEXT_TO_SQL_PROMPT from llama_index.prompts.prompt_type import PromptType from llama_index.response.schema import Response @@ -30,7 +30,7 @@ DEFAULT_RESPONSE_SYNTHESIS_PROMPT_TMPL = ( "SQL Response: {sql_response_str}\n" "Response: " ) -DEFAULT_RESPONSE_SYNTHESIS_PROMPT = Prompt( +DEFAULT_RESPONSE_SYNTHESIS_PROMPT = PromptTemplate( DEFAULT_RESPONSE_SYNTHESIS_PROMPT_TMPL, prompt_type=PromptType.SQL_RESPONSE_SYNTHESIS, ) @@ -85,24 +85,25 @@ class NLStructStoreQueryEngine(BaseQueryEngine): Args: index (SQLStructStoreIndex): A SQL Struct Store Index - text_to_sql_prompt (Optional[Prompt]): A Text to SQL Prompt - to use for the query. Defaults to DEFAULT_TEXT_TO_SQL_PROMPT. + text_to_sql_prompt (Optional[BasePromptTemplate]): A Text to SQL + BasePromptTemplate to use for the query. + Defaults to DEFAULT_TEXT_TO_SQL_PROMPT. context_query_kwargs (Optional[dict]): Keyword arguments for the context query. Defaults to {}. synthesize_response (bool): Whether to synthesize a response from the query results. Defaults to True. - response_synthesis_prompt (Optional[Prompt]): A - Response Synthesis Prompt to use for the query. Defaults to + response_synthesis_prompt (Optional[BasePromptTemplate]): A + Response Synthesis BasePromptTemplate to use for the query. Defaults to DEFAULT_RESPONSE_SYNTHESIS_PROMPT. """ def __init__( self, index: SQLStructStoreIndex, - text_to_sql_prompt: Optional[Prompt] = None, + text_to_sql_prompt: Optional[BasePromptTemplate] = None, context_query_kwargs: Optional[dict] = None, synthesize_response: bool = True, - response_synthesis_prompt: Optional[Prompt] = None, + response_synthesis_prompt: Optional[BasePromptTemplate] = None, **kwargs: Any, ) -> None: """Initialize params.""" @@ -215,10 +216,10 @@ class BaseSQLTableQueryEngine(BaseQueryEngine): def __init__( self, sql_database: SQLDatabase, - text_to_sql_prompt: Optional[Prompt] = None, + text_to_sql_prompt: Optional[BasePromptTemplate] = None, context_query_kwargs: Optional[dict] = None, synthesize_response: bool = True, - response_synthesis_prompt: Optional[Prompt] = None, + response_synthesis_prompt: Optional[BasePromptTemplate] = None, service_context: Optional[ServiceContext] = None, **kwargs: Any, ) -> None: @@ -321,10 +322,10 @@ class NLSQLTableQueryEngine(BaseSQLTableQueryEngine): def __init__( self, sql_database: SQLDatabase, - text_to_sql_prompt: Optional[Prompt] = None, + text_to_sql_prompt: Optional[BasePromptTemplate] = None, context_query_kwargs: Optional[dict] = None, synthesize_response: bool = True, - response_synthesis_prompt: Optional[Prompt] = None, + response_synthesis_prompt: Optional[BasePromptTemplate] = None, tables: Optional[Union[List[str], List[Table]]] = None, service_context: Optional[ServiceContext] = None, **kwargs: Any, @@ -389,10 +390,10 @@ class SQLTableRetrieverQueryEngine(BaseSQLTableQueryEngine): self, sql_database: SQLDatabase, table_retriever: ObjectRetriever[SQLTableSchema], - text_to_sql_prompt: Optional[Prompt] = None, + text_to_sql_prompt: Optional[BasePromptTemplate] = None, context_query_kwargs: Optional[dict] = None, synthesize_response: bool = True, - response_synthesis_prompt: Optional[Prompt] = None, + response_synthesis_prompt: Optional[BasePromptTemplate] = None, service_context: Optional[ServiceContext] = None, context_str_prefix: Optional[str] = None, **kwargs: Any, diff --git a/llama_index/indices/tree/all_leaf_retriever.py b/llama_index/indices/tree/all_leaf_retriever.py index 85083b0d908b7653376d47912bc9ca7e008a3efb..91d208c2d2755e9597a5ba465f629eea56238d27 100644 --- a/llama_index/indices/tree/all_leaf_retriever.py +++ b/llama_index/indices/tree/all_leaf_retriever.py @@ -23,7 +23,7 @@ class TreeAllLeafRetriever(BaseRetriever): when initialized, since we rebuild the tree for each query. Args: - text_qa_template (Optional[QuestionAnswerPrompt]): Question-Answer Prompt + text_qa_template (Optional[BasePromptTemplate]): Question-Answer Prompt (see :ref:`Prompt-Templates`). """ diff --git a/llama_index/indices/tree/base.py b/llama_index/indices/tree/base.py index 67ac0ffa6e5885a959332165e5eafdfa3d46299f..bfd8da7f4ea0b5579de77a589b6570ee87bbf56b 100644 --- a/llama_index/indices/tree/base.py +++ b/llama_index/indices/tree/base.py @@ -10,11 +10,11 @@ from llama_index.indices.base_retriever import BaseRetriever from llama_index.indices.common_tree.base import GPTTreeIndexBuilder from llama_index.indices.service_context import ServiceContext from llama_index.indices.tree.inserter import TreeIndexInserter +from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import ( DEFAULT_INSERT_PROMPT, DEFAULT_SUMMARY_PROMPT, ) -from llama_index.prompts.prompts import SummaryPrompt, TreeInsertPrompt from llama_index.schema import BaseNode from llama_index.storage.docstore.types import RefDocInfo @@ -45,9 +45,9 @@ class TreeIndex(BaseIndex[IndexGraph]): A secondary answer is to directly synthesize the answer from the root nodes. Args: - summary_template (Optional[SummaryPrompt]): A Summarization Prompt + summary_template (Optional[BasePromptTemplate]): A Summarization Prompt (see :ref:`Prompt-Templates`). - insert_prompt (Optional[TreeInsertPrompt]): An Tree Insertion Prompt + insert_prompt (Optional[BasePromptTemplate]): An Tree Insertion Prompt (see :ref:`Prompt-Templates`). num_children (int): The number of children each node should have. build_tree (bool): Whether to build the tree during index construction. @@ -62,8 +62,8 @@ class TreeIndex(BaseIndex[IndexGraph]): nodes: Optional[Sequence[BaseNode]] = None, index_struct: Optional[IndexGraph] = None, service_context: Optional[ServiceContext] = None, - summary_template: Optional[SummaryPrompt] = None, - insert_prompt: Optional[TreeInsertPrompt] = None, + summary_template: Optional[BasePromptTemplate] = None, + insert_prompt: Optional[BasePromptTemplate] = None, num_children: int = 10, build_tree: bool = True, use_async: bool = False, @@ -74,7 +74,7 @@ class TreeIndex(BaseIndex[IndexGraph]): # need to set parameters before building index in base class. self.num_children = num_children self.summary_template = summary_template or DEFAULT_SUMMARY_PROMPT - self.insert_prompt: TreeInsertPrompt = insert_prompt or DEFAULT_INSERT_PROMPT + self.insert_prompt: BasePromptTemplate = insert_prompt or DEFAULT_INSERT_PROMPT self.build_tree = build_tree self._use_async = use_async super().__init__( diff --git a/llama_index/indices/tree/inserter.py b/llama_index/indices/tree/inserter.py index 879061a7efec85d2c8060341e11296daeb5bf97e..0e2882395546d8089f515115c4dd2c5d381bb9c2 100644 --- a/llama_index/indices/tree/inserter.py +++ b/llama_index/indices/tree/inserter.py @@ -11,7 +11,7 @@ from llama_index.indices.utils import ( extract_numbers_given_response, get_sorted_node_list, ) -from llama_index.prompts.base import Prompt +from llama_index.prompts.base import BasePromptTemplate from llama_index.prompts.default_prompts import ( DEFAULT_INSERT_PROMPT, DEFAULT_SUMMARY_PROMPT, @@ -27,8 +27,8 @@ class TreeIndexInserter: index_graph: IndexGraph, service_context: ServiceContext, num_children: int = 10, - insert_prompt: Prompt = DEFAULT_INSERT_PROMPT, - summary_prompt: Prompt = DEFAULT_SUMMARY_PROMPT, + insert_prompt: BasePromptTemplate = DEFAULT_INSERT_PROMPT, + summary_prompt: BasePromptTemplate = DEFAULT_SUMMARY_PROMPT, docstore: Optional[BaseDocumentStore] = None, ) -> None: """Initialize with params.""" diff --git a/llama_index/indices/tree/select_leaf_embedding_retriever.py b/llama_index/indices/tree/select_leaf_embedding_retriever.py index 009cc3285efcac55bd8b159f79e74e2405d2f68b..94031f7f943d8745d9c59d5b0ad5ba16c2200f09 100644 --- a/llama_index/indices/tree/select_leaf_embedding_retriever.py +++ b/llama_index/indices/tree/select_leaf_embedding_retriever.py @@ -3,11 +3,8 @@ import logging from typing import Dict, List, Tuple, cast - from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.tree.select_leaf_retriever import ( - TreeSelectLeafRetriever, -) +from llama_index.indices.tree.select_leaf_retriever import TreeSelectLeafRetriever from llama_index.indices.utils import get_sorted_node_list from llama_index.schema import BaseNode, MetadataMode @@ -21,14 +18,14 @@ class TreeSelectLeafEmbeddingRetriever(TreeSelectLeafRetriever): query and the node text. Args: - query_template (Optional[TreeSelectPrompt]): Tree Select Query Prompt + query_template (Optional[BasePromptTemplate]): Tree Select Query Prompt (see :ref:`Prompt-Templates`). - query_template_multiple (Optional[TreeSelectMultiplePrompt]): Tree Select + query_template_multiple (Optional[BasePromptTemplate]): Tree Select Query Prompt (Multiple) (see :ref:`Prompt-Templates`). - text_qa_template (Optional[QuestionAnswerPrompt]): Question-Answer Prompt + text_qa_template (Optional[BasePromptTemplate]): Question-Answer Prompt (see :ref:`Prompt-Templates`). - refine_template (Optional[RefinePrompt]): Refinement Prompt + refine_template (Optional[BasePromptTemplate]): Refinement Prompt (see :ref:`Prompt-Templates`). child_branch_factor (int): Number of child nodes to consider at each level. If child_branch_factor is 1, then the query will only choose one child node diff --git a/llama_index/indices/tree/select_leaf_retriever.py b/llama_index/indices/tree/select_leaf_retriever.py index 6797728c9ec306e9e282a9f19627c8a908540e33..f921ec8d5c44e03ca86dfde4eb9a65b69eb3401e 100644 --- a/llama_index/indices/tree/select_leaf_retriever.py +++ b/llama_index/indices/tree/select_leaf_retriever.py @@ -4,7 +4,6 @@ import logging from typing import Any, Dict, List, Optional, cast from llama_index.bridge.langchain import print_text - from llama_index.indices.base_retriever import BaseRetriever from llama_index.indices.query.schema import QueryBundle from llama_index.indices.tree.base import TreeIndex @@ -13,21 +12,16 @@ from llama_index.indices.utils import ( extract_numbers_given_response, get_sorted_node_list, ) +from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompt_selectors import DEFAULT_REFINE_PROMPT_SEL from llama_index.prompts.default_prompts import ( DEFAULT_QUERY_PROMPT, DEFAULT_QUERY_PROMPT_MULTIPLE, DEFAULT_TEXT_QA_PROMPT, ) -from llama_index.prompts.prompts import ( - QuestionAnswerPrompt, - RefinePrompt, - TreeSelectMultiplePrompt, - TreeSelectPrompt, -) from llama_index.response.schema import Response from llama_index.response_synthesizers import get_response_synthesizer -from llama_index.schema import BaseNode, NodeWithScore, MetadataMode +from llama_index.schema import BaseNode, MetadataMode, NodeWithScore from llama_index.utils import truncate_text logger = logging.getLogger(__name__) @@ -57,9 +51,9 @@ class TreeSelectLeafRetriever(BaseRetriever): answer the query. Args: - query_template (Optional[TreeSelectPrompt]): Tree Select Query Prompt + query_template (Optional[BasePromptTemplate]): Tree Select Query Prompt (see :ref:`Prompt-Templates`). - query_template_multiple (Optional[TreeSelectMultiplePrompt]): Tree Select + query_template_multiple (Optional[BasePromptTemplate]): Tree Select Query Prompt (Multiple) (see :ref:`Prompt-Templates`). child_branch_factor (int): Number of child nodes to consider at each level. @@ -72,10 +66,10 @@ class TreeSelectLeafRetriever(BaseRetriever): def __init__( self, index: TreeIndex, - query_template: Optional[TreeSelectPrompt] = None, - text_qa_template: Optional[QuestionAnswerPrompt] = None, - refine_template: Optional[RefinePrompt] = None, - query_template_multiple: Optional[TreeSelectMultiplePrompt] = None, + query_template: Optional[BasePromptTemplate] = None, + text_qa_template: Optional[BasePromptTemplate] = None, + refine_template: Optional[BasePromptTemplate] = None, + query_template_multiple: Optional[BasePromptTemplate] = None, child_branch_factor: int = 1, verbose: bool = False, **kwargs: Any, diff --git a/llama_index/indices/vector_store/retrievers/auto_retriever/prompts.py b/llama_index/indices/vector_store/retrievers/auto_retriever/prompts.py index 7f3c7e2aad86afc0f2eb44a7acc954b433215a61..37f1ae00bbdc184efe853b1cc9b75241004836bf 100644 --- a/llama_index/indices/vector_store/retrievers/auto_retriever/prompts.py +++ b/llama_index/indices/vector_store/retrievers/auto_retriever/prompts.py @@ -1,7 +1,7 @@ """Autoretriever prompts.""" -from llama_index.prompts.base import Prompt +from llama_index.prompts.base import PromptTemplate from llama_index.prompts.prompt_type import PromptType from llama_index.vector_stores.types import ( ExactMatchFilter, @@ -103,9 +103,9 @@ DEFAULT_VECTOR_STORE_QUERY_PROMPT_TMPL = PREFIX + EXAMPLES + SUFFIX # deprecated, kept for backwards compatibility """Vector store query prompt.""" -VectorStoreQueryPrompt = Prompt +VectorStoreQueryPrompt = PromptTemplate -DEFAULT_VECTOR_STORE_QUERY_PROMPT = Prompt( +DEFAULT_VECTOR_STORE_QUERY_PROMPT = PromptTemplate( template=DEFAULT_VECTOR_STORE_QUERY_PROMPT_TMPL, prompt_type=PromptType.VECTOR_STORE_QUERY, ) diff --git a/llama_index/llm_predictor/base.py b/llama_index/llm_predictor/base.py index d65cbb7cb2c4a87dd0aec626d98e71771d463347..0949da7691be9641092e02b63c19aa2e6d284d2e 100644 --- a/llama_index/llm_predictor/base.py +++ b/llama_index/llm_predictor/base.py @@ -1,10 +1,11 @@ """Wrapper functions around an LLM chain.""" import logging -from abc import abstractmethod, ABC -from pydantic import BaseModel, PrivateAttr +from abc import ABC, abstractmethod from typing import Any, List, Optional +from pydantic import BaseModel, PrivateAttr + from llama_index.callbacks.base import CallbackManager from llama_index.llm_predictor.utils import ( astream_chat_response_to_tokens, @@ -15,8 +16,12 @@ from llama_index.llm_predictor.utils import ( from llama_index.llms.base import LLM, ChatMessage, LLMMetadata, MessageRole from llama_index.llms.generic_utils import messages_to_prompt from llama_index.llms.utils import LLMType, resolve_llm -from llama_index.prompts.base import Prompt -from llama_index.prompts.prompts import SimpleInputPrompt +from llama_index.prompts.base import ( + BasePromptTemplate, + ChatPromptTemplate, + PromptTemplate, + SelectorPromptTemplate, +) from llama_index.types import TokenAsyncGen, TokenGen logger = logging.getLogger(__name__) @@ -36,19 +41,21 @@ class BaseLLMPredictor(BaseModel, ABC): """Get LLM metadata.""" @abstractmethod - def predict(self, prompt: Prompt, **prompt_args: Any) -> str: + def predict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str: """Predict the answer to a query.""" @abstractmethod - def stream(self, prompt: Prompt, **prompt_args: Any) -> TokenGen: + def stream(self, prompt: BasePromptTemplate, **prompt_args: Any) -> TokenGen: """Stream the answer to a query.""" @abstractmethod - async def apredict(self, prompt: Prompt, **prompt_args: Any) -> str: + async def apredict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str: """Async predict the answer to a query.""" @abstractmethod - async def astream(self, prompt: Prompt, **prompt_args: Any) -> TokenAsyncGen: + async def astream( + self, prompt: BasePromptTemplate, **prompt_args: Any + ) -> TokenAsyncGen: """Async predict the answer to a query.""" @@ -67,7 +74,7 @@ class LLMPredictor(BaseLLMPredictor): arbitrary_types_allowed = True system_prompt: Optional[str] - query_wrapper_prompt: Optional[Prompt] + query_wrapper_prompt: Optional[BasePromptTemplate] _llm: LLM = PrivateAttr() def __init__( @@ -75,7 +82,7 @@ class LLMPredictor(BaseLLMPredictor): llm: Optional[LLMType] = None, callback_manager: Optional[CallbackManager] = None, system_prompt: Optional[str] = None, - query_wrapper_prompt: Optional[SimpleInputPrompt] = None, + query_wrapper_prompt: Optional[BasePromptTemplate] = None, ) -> None: """Initialize params.""" self._llm = resolve_llm(llm) @@ -97,7 +104,7 @@ class LLMPredictor(BaseLLMPredictor): """Get LLM metadata.""" return self._llm.metadata - def predict(self, prompt: Prompt, **prompt_args: Any) -> str: + def predict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str: """Predict.""" if self._llm.metadata.is_chat_model: messages = prompt.format_messages(llm=self._llm, **prompt_args) @@ -116,7 +123,7 @@ class LLMPredictor(BaseLLMPredictor): return output - def stream(self, prompt: Prompt, **prompt_args: Any) -> TokenGen: + def stream(self, prompt: BasePromptTemplate, **prompt_args: Any) -> TokenGen: """Stream.""" if self._llm.metadata.is_chat_model: messages = prompt.format_messages(llm=self._llm, **prompt_args) @@ -130,7 +137,7 @@ class LLMPredictor(BaseLLMPredictor): stream_tokens = stream_completion_response_to_tokens(stream_response) return stream_tokens - async def apredict(self, prompt: Prompt, **prompt_args: Any) -> str: + async def apredict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str: """Async predict.""" if self._llm.metadata.is_chat_model: messages = prompt.format_messages(llm=self._llm, **prompt_args) @@ -149,7 +156,9 @@ class LLMPredictor(BaseLLMPredictor): return output - async def astream(self, prompt: Prompt, **prompt_args: Any) -> TokenAsyncGen: + async def astream( + self, prompt: BasePromptTemplate, **prompt_args: Any + ) -> TokenAsyncGen: """Async stream.""" if self._llm.metadata.is_chat_model: messages = prompt.format_messages(llm=self._llm, **prompt_args) @@ -163,18 +172,40 @@ class LLMPredictor(BaseLLMPredictor): stream_tokens = await astream_completion_response_to_tokens(stream_response) return stream_tokens - def _extend_prompt(self, prompt: Prompt) -> Prompt: + def _extend_prompt(self, prompt: BasePromptTemplate) -> BasePromptTemplate: """Add system and query wrapper prompts to base prompt""" + # TODO: avoid mutating prompt attributes if self.system_prompt: - prompt.prompt_selector.default_prompt.template = ( - self.system_prompt - + "\n\n" - + prompt.prompt_selector.default_prompt.template - ) + if isinstance(prompt, SelectorPromptTemplate): + default_template = prompt.default_template + if isinstance(default_template, PromptTemplate): + default_template.template = ( + self.system_prompt + "\n\n" + default_template.template + ) + else: + raise ValueError("PromptTemplate expected as default_template") + elif isinstance(prompt, ChatPromptTemplate): + prompt.message_templates = [ + ChatMessage(role=MessageRole.SYSTEM, content=self.system_prompt) + ] + prompt.message_templates + elif isinstance(prompt, PromptTemplate): + prompt.template = self.system_prompt + "\n\n" + prompt.template + if self.query_wrapper_prompt: - prompt.partial_dict["query_str"] = self.query_wrapper_prompt.format( - query_str=prompt.partial_dict["query_str"] - ) + if isinstance(prompt, (PromptTemplate, ChatPromptTemplate)): + prompt.kwargs["query_str"] = self.query_wrapper_prompt.format( + query_str=prompt.kwargs["query_str"] + ) + elif isinstance(prompt, SelectorPromptTemplate): + if isinstance(default_template, PromptTemplate): + prompt.default_template.kwargs[ + "query_str" + ] = self.query_wrapper_prompt.format( + query_str=prompt.default_template.kwargs["query_str"] + ) + else: + raise ValueError("PromptTemplate expected as default_template") + return prompt def _extend_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]: diff --git a/llama_index/llm_predictor/mock.py b/llama_index/llm_predictor/mock.py index 37e442ce80f2aabbf47459015b79911f3ae3b828..42ff863c30fbecfd52bda0fe7e9e90c685fab931 100644 --- a/llama_index/llm_predictor/mock.py +++ b/llama_index/llm_predictor/mock.py @@ -1,11 +1,12 @@ """Mock LLM Predictor.""" -from pydantic import Field from typing import Any, Dict +from pydantic import Field + from llama_index.constants import DEFAULT_NUM_OUTPUTS from llama_index.llm_predictor.base import BaseLLMPredictor -from llama_index.llms.base import LLMMetadata, LLM -from llama_index.prompts.base import Prompt +from llama_index.llms.base import LLM, LLMMetadata +from llama_index.prompts.base import BasePromptTemplate from llama_index.prompts.prompt_type import PromptType from llama_index.token_counter.utils import ( mock_extract_keywords_response, @@ -49,13 +50,13 @@ def _mock_answer(max_tokens: int, prompt_args: Dict) -> str: return " ".join(["answer"] * token_limit) -def _mock_refine(max_tokens: int, prompt: Prompt, prompt_args: Dict) -> str: +def _mock_refine(max_tokens: int, prompt: BasePromptTemplate, prompt_args: Dict) -> str: """Mock refine.""" # tokens in response shouldn't be larger than tokens in # `existing_answer` + `context_msg` # NOTE: if existing_answer is not in prompt_args, we need to get it from the prompt if "existing_answer" not in prompt_args: - existing_answer = prompt.partial_dict["existing_answer"] + existing_answer = prompt.kwargs["existing_answer"] else: existing_answer = prompt_args["existing_answer"] num_ctx_tokens = len(globals_helper.tokenizer(prompt_args["context_msg"])) @@ -96,10 +97,10 @@ class MockLLMPredictor(BaseLLMPredictor): def llm(self) -> LLM: raise NotImplementedError("MockLLMPredictor does not have an LLM model.") - def predict(self, prompt: Prompt, **prompt_args: Any) -> str: + def predict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str: """Mock predict.""" - prompt_str = prompt.prompt_type + prompt_str = prompt.metadata["prompt_type"] if prompt_str == PromptType.SUMMARY: output = _mock_summary_predict(self.max_tokens, prompt_args) elif prompt_str == PromptType.TREE_INSERT: @@ -118,7 +119,7 @@ class MockLLMPredictor(BaseLLMPredictor): output = _mock_query_keyword_extract(prompt_args) elif prompt_str == PromptType.KNOWLEDGE_TRIPLET_EXTRACT: output = _mock_knowledge_graph_triplet_extract( - prompt_args, prompt.partial_dict.get("max_knowledge_triplets", 2) + prompt_args, int(prompt.kwargs.get("max_knowledge_triplets", 2)) ) elif prompt_str == PromptType.CUSTOM: # we don't know specific prompt type, return generic response @@ -128,11 +129,13 @@ class MockLLMPredictor(BaseLLMPredictor): return output - def stream(self, prompt: Prompt, **prompt_args: Any) -> TokenGen: + def stream(self, prompt: BasePromptTemplate, **prompt_args: Any) -> TokenGen: raise NotImplementedError - async def apredict(self, prompt: Prompt, **prompt_args: Any) -> str: + async def apredict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str: return self.predict(prompt, **prompt_args) - async def astream(self, prompt: Prompt, **prompt_args: Any) -> TokenAsyncGen: + async def astream( + self, prompt: BasePromptTemplate, **prompt_args: Any + ) -> TokenAsyncGen: raise NotImplementedError diff --git a/llama_index/llm_predictor/structured.py b/llama_index/llm_predictor/structured.py index 76eb862e91a870b9a9daa22f8ee7a55c4cc94ffc..9dfcdb2e3618a0868026a857cd1b793307dd5789 100644 --- a/llama_index/llm_predictor/structured.py +++ b/llama_index/llm_predictor/structured.py @@ -5,7 +5,7 @@ import logging from typing import Any from llama_index.llm_predictor.base import LLMPredictor -from llama_index.prompts.base import Prompt +from llama_index.prompts.base import BasePromptTemplate from llama_index.types import TokenGen logger = logging.getLogger(__name__) @@ -19,11 +19,11 @@ class StructuredLLMPredictor(LLMPredictor): """ - def predict(self, prompt: Prompt, **prompt_args: Any) -> str: + def predict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str: """Predict the answer to a query. Args: - prompt (Prompt): Prompt to use for prediction. + prompt (BasePromptTemplate): BasePromptTemplate to use for prediction. Returns: Tuple[str, str]: Tuple of the predicted answer and the formatted prompt. @@ -33,20 +33,21 @@ class StructuredLLMPredictor(LLMPredictor): # run output parser if prompt.output_parser is not None: # TODO: return other formats - parsed_llm_prediction = str(prompt.output_parser.parse(llm_prediction)) + output_parser = prompt.output_parser + parsed_llm_prediction = str(output_parser.parse(llm_prediction)) else: parsed_llm_prediction = llm_prediction return parsed_llm_prediction - def stream(self, prompt: Prompt, **prompt_args: Any) -> TokenGen: + def stream(self, prompt: BasePromptTemplate, **prompt_args: Any) -> TokenGen: """Stream the answer to a query. NOTE: this is a beta feature. Will try to build or use better abstractions about response handling. Args: - prompt (Prompt): Prompt to use for prediction. + prompt (BasePromptTemplate): BasePromptTemplate to use for prediction. Returns: str: The predicted answer. @@ -56,11 +57,11 @@ class StructuredLLMPredictor(LLMPredictor): "Streaming is not supported for structured LLM predictor." ) - async def apredict(self, prompt: Prompt, **prompt_args: Any) -> str: + async def apredict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str: """Async predict the answer to a query. Args: - prompt (Prompt): Prompt to use for prediction. + prompt (BasePromptTemplate): BasePromptTemplate to use for prediction. Returns: Tuple[str, str]: Tuple of the predicted answer and the formatted prompt. @@ -68,7 +69,8 @@ class StructuredLLMPredictor(LLMPredictor): """ llm_prediction = await super().apredict(prompt, **prompt_args) if prompt.output_parser is not None: - parsed_llm_prediction = str(prompt.output_parser.parse(llm_prediction)) + output_parser = prompt.output_parser + parsed_llm_prediction = str(output_parser.parse(llm_prediction)) else: parsed_llm_prediction = llm_prediction return parsed_llm_prediction diff --git a/llama_index/llm_predictor/vellum/predictor.py b/llama_index/llm_predictor/vellum/predictor.py index 1cf6a05f5460649ee63b55890cfbdde4c1efc7e7..b262c49606b9a00b49db478f90b7220c0fbc7453 100644 --- a/llama_index/llm_predictor/vellum/predictor.py +++ b/llama_index/llm_predictor/vellum/predictor.py @@ -3,7 +3,7 @@ from __future__ import annotations from pydantic import Field, PrivateAttr from typing import Any, Optional, Tuple, cast -from llama_index import Prompt +from llama_index.prompts import BasePromptTemplate from llama_index.callbacks import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload from llama_index.llm_predictor.base import BaseLLMPredictor, LLMMetadata, LLM @@ -64,7 +64,7 @@ class VellumPredictor(BaseLLMPredictor): """Get the LLM.""" raise NotImplementedError("Vellum does not expose the LLM.") - def predict(self, prompt: Prompt, **prompt_args: Any) -> str: + def predict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str: """Predict the answer to a query.""" from vellum import GenerateRequest @@ -73,11 +73,13 @@ class VellumPredictor(BaseLLMPredictor): prompt, **prompt_args ) + input_values = { + **prompt.kwargs, + **prompt_args, + } result = self._vellum_client.generate( deployment_id=registered_prompt.deployment_id, - requests=[ - GenerateRequest(input_values=prompt.get_full_format_args(prompt_args)) - ], + requests=[GenerateRequest(input_values=input_values)], ) completion_text = self._process_generate_response( @@ -86,7 +88,7 @@ class VellumPredictor(BaseLLMPredictor): return completion_text - def stream(self, prompt: Prompt, **prompt_args: Any) -> TokenGen: + def stream(self, prompt: BasePromptTemplate, **prompt_args: Any) -> TokenGen: """Stream the answer to a query.""" from vellum import GenerateRequest, GenerateStreamResult @@ -95,11 +97,13 @@ class VellumPredictor(BaseLLMPredictor): prompt, **prompt_args ) + input_values = { + **prompt.kwargs, + **prompt_args, + } responses = self._vellum_client.generate_stream( deployment_id=registered_prompt.deployment_id, - requests=[ - GenerateRequest(input_values=prompt.get_full_format_args(prompt_args)) - ], + requests=[GenerateRequest(input_values=input_values)], ) def text_generator() -> TokenGen: @@ -135,7 +139,7 @@ class VellumPredictor(BaseLLMPredictor): return text_generator() - async def apredict(self, prompt: Prompt, **prompt_args: Any) -> str: + async def apredict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str: """Asynchronously predict the answer to a query.""" from vellum import GenerateRequest @@ -144,11 +148,13 @@ class VellumPredictor(BaseLLMPredictor): prompt, **prompt_args ) + input_values = { + **prompt.kwargs, + **prompt_args, + } result = await self._async_vellum_client.generate( deployment_id=registered_prompt.deployment_id, - requests=[ - GenerateRequest(input_values=prompt.get_full_format_args(prompt_args)) - ], + requests=[GenerateRequest(input_values=input_values)], ) completion_text = self._process_generate_response( @@ -157,7 +163,9 @@ class VellumPredictor(BaseLLMPredictor): return completion_text - async def astream(self, prompt: Prompt, **prompt_args: Any) -> TokenAsyncGen: + async def astream( + self, prompt: BasePromptTemplate, **prompt_args: Any + ) -> TokenAsyncGen: async def gen() -> TokenAsyncGen: for token in self.stream(prompt, **prompt_args): yield token @@ -166,7 +174,7 @@ class VellumPredictor(BaseLLMPredictor): return gen() def _prepare_generate_call( - self, prompt: Prompt, **prompt_args: Any + self, prompt: BasePromptTemplate, **prompt_args: Any ) -> Tuple[VellumRegisteredPrompt, VellumCompiledPrompt, str]: """Prepare a generate call.""" diff --git a/llama_index/llm_predictor/vellum/prompt_registry.py b/llama_index/llm_predictor/vellum/prompt_registry.py index 4421c75a92f654867709f8eff98cfe480c012cfe..ed7b9cb1553ca658931aae2d50441366d0c7f808 100644 --- a/llama_index/llm_predictor/vellum/prompt_registry.py +++ b/llama_index/llm_predictor/vellum/prompt_registry.py @@ -3,12 +3,13 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, List, Tuple from uuid import uuid4 -from llama_index import Prompt from llama_index.llm_predictor.vellum.types import ( VellumCompiledPrompt, VellumRegisteredPrompt, ) from llama_index.llm_predictor.vellum.utils import convert_to_kebab_case +from llama_index.prompts import BasePromptTemplate +from llama_index.prompts.base import PromptTemplate if TYPE_CHECKING: import vellum # noqa: F401 @@ -33,7 +34,7 @@ class VellumPromptRegistry: self._vellum_client = Vellum(api_key=vellum_api_key) - def from_prompt(self, initial_prompt: Prompt) -> VellumRegisteredPrompt: + def from_prompt(self, initial_prompt: BasePromptTemplate) -> VellumRegisteredPrompt: """Accepts a LlamaIndex prompt and retrieves a corresponding registered prompt from Vellum. @@ -46,7 +47,7 @@ class VellumPromptRegistry: You can reference a previously registered prompt by providing either `vellum_deployment_id` or `vellum_deployment_name` as key/value pairs within - `Prompt.metadata`. + `BasePromptTemplate.metadata`. """ from vellum.core import ApiError @@ -112,7 +113,7 @@ class VellumPromptRegistry: prompt_id=prompt_id, ) - def _register_prompt(self, prompt: Prompt) -> VellumRegisteredPrompt: + def _register_prompt(self, prompt: BasePromptTemplate) -> VellumRegisteredPrompt: """Registers a prompt with Vellum. By registering a prompt, Vellum will: @@ -151,7 +152,7 @@ class VellumPromptRegistry: parameters=params, meta={ "source": "llamaindex", - "prompt_type": prompt.prompt_type, + "prompt_type": prompt.metadata["prompt_type"], }, ) @@ -164,22 +165,24 @@ class VellumPromptRegistry: prompt_id=resp.prompt.id, ) - def _generate_default_label(self, prompt: Prompt) -> str: - return f"LlamaIndex Demo: {prompt.prompt_type}" + def _generate_default_label(self, prompt: BasePromptTemplate) -> str: + prompt_type = prompt.metadata["prompt_type"] + return f"LlamaIndex Demo: {prompt_type}'" - def _generate_default_name(self, prompt: Prompt) -> str: + def _generate_default_name(self, prompt: BasePromptTemplate) -> str: default_label = self._generate_default_label(prompt) return convert_to_kebab_case(default_label) def _construct_prompt_info( - self, prompt: Prompt, for_chat_model: bool = True + self, prompt: BasePromptTemplate, for_chat_model: bool = True ) -> vellum.RegisterPromptPromptInfoRequest: """Converts a LlamaIndex prompt into Vellum's prompt representation.""" import vellum - prompt_template = prompt.original_template - for input_variable in prompt.get_langchain_prompt().input_variables: + assert isinstance(prompt, PromptTemplate) + prompt_template = prompt.template + for input_variable in prompt.template_vars: prompt_template = prompt_template.replace( input_variable, f"{{ {input_variable} }}" ) @@ -190,8 +193,8 @@ class VellumPromptRegistry: block_type=vellum.BlockTypeEnum.JINJA, properties=vellum.PromptTemplateBlockPropertiesRequest( template=self._prepare_prompt_jinja_template( - prompt.original_template, - prompt.get_langchain_prompt().input_variables, + prompt.template, + prompt.template_vars, ), ), ) @@ -213,10 +216,7 @@ class VellumPromptRegistry: version=1, blocks=[block], ), - input_variables=[ - {"key": input_var} - for input_var in prompt.get_langchain_prompt().input_variables - ], + input_variables=[{"key": input_var} for input_var in prompt.template_vars], ) def _prepare_prompt_jinja_template( diff --git a/llama_index/llms/generic_utils.py b/llama_index/llms/generic_utils.py index 6ee3ab508687d9549f36fb1f4121e8cd613a4b4b..0d09257a4b3741d678b9451a982623e69cc7087b 100644 --- a/llama_index/llms/generic_utils.py +++ b/llama_index/llms/generic_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Awaitable, Callable, Sequence +from typing import Any, Awaitable, Callable, List, Sequence from llama_index.llms.base import ( ChatMessage, @@ -44,7 +44,7 @@ def messages_to_prompt(messages: Sequence[ChatMessage]) -> str: return "\n".join(string_messages) -def prompt_to_messages(prompt: str) -> Sequence[ChatMessage]: +def prompt_to_messages(prompt: str) -> List[ChatMessage]: """Convert a string prompt to a sequence of messages.""" return [ChatMessage(role=MessageRole.USER, content=prompt)] diff --git a/llama_index/llms/palm.py b/llama_index/llms/palm.py index 567ca29ae185d17d7be737b66869f6483ef55952..68ba5703558cfb50b15c2e0666d5ee486872bcc6 100644 --- a/llama_index/llms/palm.py +++ b/llama_index/llms/palm.py @@ -80,7 +80,7 @@ class PaLM(CustomLLM): """Predict the answer to a query. Args: - prompt (Prompt): Prompt to use for prediction. + prompt (str): Prompt to use for prediction. Returns: Tuple[str, str]: Tuple of the predicted answer and the formatted prompt. @@ -104,7 +104,7 @@ class PaLM(CustomLLM): better abstractions about response handling. Args: - prompt (Prompt): Prompt to use for prediction. + prompt (str): Prompt to use for prediction. Returns: str: The predicted answer. diff --git a/llama_index/node_parser/extractors/metadata_extractors.py b/llama_index/node_parser/extractors/metadata_extractors.py index 06c8f400730b95f45fabdb22a0a498e8ff7ce71c..3b23744880f58f5cb1295d253be7201ac1ae5011 100644 --- a/llama_index/node_parser/extractors/metadata_extractors.py +++ b/llama_index/node_parser/extractors/metadata_extractors.py @@ -21,14 +21,15 @@ disambiguate the document or subsection from other similar documents or subsecti """ import json from abc import abstractmethod -from pydantic import Field, PrivateAttr -from typing import Any, List, Optional, Sequence, cast, Dict, Callable from functools import reduce +from typing import Any, Callable, Dict, List, Optional, Sequence, cast + +from pydantic import Field, PrivateAttr -from llama_index.llms.base import LLM from llama_index.llm_predictor.base import BaseLLMPredictor, LLMPredictor +from llama_index.llms.base import LLM from llama_index.node_parser.interface import BaseExtractor -from llama_index.prompts.base import Prompt +from llama_index.prompts import PromptTemplate from llama_index.schema import BaseNode, TextNode @@ -199,7 +200,7 @@ class TitleExtractor(MetadataFeatureExtractor): title_candidates = [ self.llm_predictor.predict( - Prompt(template=self.node_template), + PromptTemplate(template=self.node_template), context_str=cast(TextNode, node).text, ) for node in nodes_to_extract_title @@ -210,7 +211,7 @@ class TitleExtractor(MetadataFeatureExtractor): ) title = self.llm_predictor.predict( - Prompt(template=self.combine_template), + PromptTemplate(template=self.combine_template), context_str=titles, ) else: @@ -257,7 +258,7 @@ class KeywordExtractor(MetadataFeatureExtractor): # TODO: figure out a good way to allow users to customize keyword template keywords = self.llm_predictor.predict( - Prompt( + PromptTemplate( template=f"""\ {{context_str}}. Give {self.keywords} unique keywords for this \ document. Format as comma separated. Keywords: """ @@ -322,7 +323,7 @@ class QuestionsAnsweredExtractor(MetadataFeatureExtractor): # Extract the title from the first node # TODO: figure out a good way to allow users to customize template questions = self.llm_predictor.predict( - Prompt( + PromptTemplate( template=self.prompt_template or f"""\ {{context_str}}. Given the contextual information, \ @@ -399,7 +400,7 @@ class SummaryExtractor(MetadataFeatureExtractor): raise ValueError("Only `TextNode` is allowed for `Summary` extractor") node_summaries = [ self.llm_predictor.predict( - Prompt(template=self.prompt_template), + PromptTemplate(template=self.prompt_template), context_str=cast(TextNode, node).text, ).strip() for node in nodes diff --git a/llama_index/program/llm_program.py b/llama_index/program/llm_program.py index 96c71f8c10e57d3a2697aacdbcbe56f8b286de26..ca19072dfd5f166f46a62a458eb50c5c2c2dcdcf 100644 --- a/llama_index/program/llm_program.py +++ b/llama_index/program/llm_program.py @@ -1,12 +1,12 @@ from typing import Any, Dict, Optional, Type, Union - from pydantic import BaseModel + from llama_index.llms.base import LLM from llama_index.llms.openai import OpenAI -from llama_index.program.base_program import BasePydanticProgram -from llama_index.prompts.base import Prompt from llama_index.output_parsers.pydantic import PydanticOutputParser +from llama_index.program.base_program import BasePydanticProgram +from llama_index.prompts.base import PromptTemplate class LLMTextCompletionProgram(BasePydanticProgram[BaseModel]): @@ -20,7 +20,7 @@ class LLMTextCompletionProgram(BasePydanticProgram[BaseModel]): def __init__( self, output_parser: PydanticOutputParser, - prompt: Prompt, + prompt: PromptTemplate, llm: LLM, function_call: Union[str, Dict[str, Any]], verbose: bool = False, @@ -42,7 +42,7 @@ class LLMTextCompletionProgram(BasePydanticProgram[BaseModel]): **kwargs: Any, ) -> "LLMTextCompletionProgram": llm = llm or OpenAI(temperature=0, model="gpt-3.5-turbo-0613") - prompt = Prompt(prompt_template_str) + prompt = PromptTemplate(prompt_template_str) function_call = function_call or { "name": output_parser.output_cls.schema()["title"] } @@ -64,9 +64,9 @@ class LLMTextCompletionProgram(BasePydanticProgram[BaseModel]): **kwargs: Any, ) -> BaseModel: prompt_with_parse_instrs_tmpl = self._output_parser.format( - self._prompt.original_template + self._prompt.template ) - prompt_with_parse_instrs = Prompt(prompt_with_parse_instrs_tmpl) + prompt_with_parse_instrs = PromptTemplate(prompt_with_parse_instrs_tmpl) formatted_prompt = prompt_with_parse_instrs.format(**kwargs) diff --git a/llama_index/program/openai_program.py b/llama_index/program/openai_program.py index 3e06c5541095119caea419332f9f4f00f34b1d1b..5dee0e81607b4c7df0739145a41fac05bc648c68 100644 --- a/llama_index/program/openai_program.py +++ b/llama_index/program/openai_program.py @@ -6,7 +6,7 @@ from llama_index.llms.base import LLM, ChatMessage, MessageRole from llama_index.llms.openai import OpenAI from llama_index.llms.openai_utils import to_openai_function from llama_index.program.llm_prompt_program import BaseLLMFunctionProgram -from llama_index.prompts.base import Prompt +from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.types import Model from llama_index.program.utils import create_list_model from typing import Tuple @@ -46,7 +46,7 @@ class OpenAIPydanticProgram(BaseLLMFunctionProgram[LLM]): self, output_cls: Type[Model], llm: LLM, - prompt: Prompt, + prompt: BasePromptTemplate, function_call: Union[str, Dict[str, Any]], verbose: bool = False, ) -> None: @@ -75,7 +75,7 @@ class OpenAIPydanticProgram(BaseLLMFunctionProgram[LLM]): "function calling API. " ) - prompt = Prompt(prompt_template_str) + prompt = PromptTemplate(prompt_template_str) function_call = function_call or _default_function_call(output_cls) return cls( output_cls=output_cls, diff --git a/llama_index/program/predefined/df.py b/llama_index/program/predefined/df.py index 90418288305b55d4285e7b9a26efd52c5fae64ad..eb6f4a6402b37db56a50c4914f9683fce751f575 100644 --- a/llama_index/program/predefined/df.py +++ b/llama_index/program/predefined/df.py @@ -3,7 +3,6 @@ from llama_index.program.base_program import BasePydanticProgram from typing import Optional, List, Any, Type, cast from pydantic import BaseModel, Field from llama_index.program.openai_program import OpenAIPydanticProgram -from llama_index.prompts.prompts import Prompt from llama_index.program.llm_prompt_program import BaseLLMFunctionProgram import pandas as pd @@ -162,16 +161,12 @@ class DFRowsProgram(BasePydanticProgram[DataFrameRowsOnly]): ) -> None: """Init params.""" # partial format df parser template string with column schema - # NOTE: hack where we use prompt class to partial format - orig_prompt = Prompt(df_parser_template_str) - new_prompt = Prompt.from_prompt( - orig_prompt.partial_format( - column_schema=column_schema, - ) + prompt_template_str = df_parser_template_str.replace( + "{column_schema}", column_schema or "" ) pydantic_program = pydantic_program_cls.from_defaults( - DataFrameRowsOnly, new_prompt.original_template, **program_kwargs + DataFrameRowsOnly, prompt_template_str, **program_kwargs ) self._validate_program(pydantic_program) self._pydantic_program = pydantic_program diff --git a/llama_index/program/predefined/evaporate/extractor.py b/llama_index/program/predefined/evaporate/extractor.py index 62702e173c2300dbe24660f71b991bded6019230..a42ae8d2e1b481edf8571961f50b98f0712efb7c 100644 --- a/llama_index/program/predefined/evaporate/extractor.py +++ b/llama_index/program/predefined/evaporate/extractor.py @@ -15,7 +15,6 @@ from llama_index.schema import BaseNode, MetadataMode, NodeWithScore from llama_index.indices.query.schema import QueryBundle -from llama_index.prompts.prompts import QuestionAnswerPrompt from llama_index.program.predefined.evaporate.prompts import ( @@ -177,12 +176,11 @@ class EvaporateExtractor: else: expected_output_str = "" - new_prompt = self._fn_generate_prompt.partial_format( + qa_prompt = self._fn_generate_prompt.partial_format( attribute=field, function_field=function_field, expected_output_str=expected_output_str, ) - qa_prompt = QuestionAnswerPrompt.from_prompt(new_prompt) response_synthesizer = get_response_synthesizer( service_context=self._service_context, diff --git a/llama_index/program/predefined/evaporate/prompts.py b/llama_index/program/predefined/evaporate/prompts.py index 8aa7bdb403a84b3365b5572aa3621fff6a2394a8..7e7351c8e5cc8075efa506a200e7b0c96d68b99d 100644 --- a/llama_index/program/predefined/evaporate/prompts.py +++ b/llama_index/program/predefined/evaporate/prompts.py @@ -6,32 +6,32 @@ Full credits go to: https://github.com/HazyResearch/evaporate """ -from llama_index.prompts.prompts import Prompt +from llama_index.prompts import PromptTemplate # deprecated, kept for backward compatibility -"""Pandas prompt. Convert query to python code. +"""Pandas PromptTemplate. Convert query to python code. Required template variables: `chunk`, `topic`. Args: - template (str): Template for the prompt. - **prompt_kwargs: Keyword arguments for the prompt. + template (str): Template for the PromptTemplate. + **prompt_kwargs: Keyword arguments for the PromptTemplate. """ -SchemaIDPrompt = Prompt +SchemaIDPrompt = PromptTemplate -"""Function generation prompt. Generate a function from existing text. +"""Function generation PromptTemplate. Generate a function from existing text. Required template variables: `context_str`, `query_str`, `attribute`, `function_field`. Args: - template (str): Template for the prompt. - **prompt_kwargs: Keyword arguments for the prompt. + template (str): Template for the PromptTemplate. + **prompt_kwargs: Keyword arguments for the PromptTemplate. """ -FnGeneratePrompt = Prompt +FnGeneratePrompt = PromptTemplate # used for schema identification SCHEMA_ID_PROMPT_TMPL = f"""Sample text: @@ -81,7 +81,7 @@ Sample text: Question: List all relevant attributes about '{{topic:}}' that are exactly mentioned in this sample text if any. Answer:""" # noqa: E501, F541 -SCHEMA_ID_PROMPT = Prompt(SCHEMA_ID_PROMPT_TMPL) +SCHEMA_ID_PROMPT = PromptTemplate(SCHEMA_ID_PROMPT_TMPL) # used for function generation @@ -108,7 +108,7 @@ def get_{{function_field:}}_field(text: str): \""" """ # noqa: E501, F541 -FN_GENERATION_PROMPT = Prompt(FN_GENERATION_PROMPT_TMPL) +FN_GENERATION_PROMPT = PromptTemplate(FN_GENERATION_PROMPT_TMPL) FN_GENERATION_LIST_PROMPT_TMPL = f"""Here is a sample of text: @@ -134,7 +134,7 @@ def get_{{function_field:}}_field(text: str) -> List: \""" """ # noqa: E501, F541 -FN_GENERATION_LIST_PROMPT = Prompt(FN_GENERATION_LIST_PROMPT_TMPL) +FN_GENERATION_LIST_PROMPT = PromptTemplate(FN_GENERATION_LIST_PROMPT_TMPL) DEFAULT_EXPECTED_OUTPUT_PREFIX_TMPL = ( "Here is the expected output on the text after running the function. " diff --git a/llama_index/prompts/__init__.py b/llama_index/prompts/__init__.py index 872453c72eace5e61fcf2811a1307a5faa67a285..780c2b7cd93f99a43d19d08c8334e27ed7e823a7 100644 --- a/llama_index/prompts/__init__.py +++ b/llama_index/prompts/__init__.py @@ -1,5 +1,25 @@ """Prompt class.""" -from llama_index.prompts.base import Prompt +from llama_index.prompts.base import ( + BasePromptTemplate, + ChatPromptTemplate, + LangchainPromptTemplate, + Prompt, + PromptTemplate, + PromptType, + SelectorPromptTemplate, +) -__all__ = ["Prompt"] +from llama_index.llms.base import ChatMessage, MessageRole + +__all__ = [ + "Prompt", + "PromptTemplate", + "SelectorPromptTemplate", + "ChatPromptTemplate", + "LangchainPromptTemplate", + "BasePromptTemplate", + "PromptType", + "ChatMessage", + "MessageRole", +] diff --git a/llama_index/prompts/base.py b/llama_index/prompts/base.py index 3911404ea5580b4aeb69c8ea8976ef40ef65992d..3c4d5f1e46b6501db7ac7104a608053966cca076 100644 --- a/llama_index/prompts/base.py +++ b/llama_index/prompts/base.py @@ -1,173 +1,295 @@ -"""Base module for prompts.""" +"""Prompts.""" + + +from abc import ABC, abstractmethod from copy import deepcopy -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple + +from pydantic import BaseModel -from llama_index.bridge.langchain import BasePromptTemplate as BaseLangchainPrompt -from llama_index.bridge.langchain import PromptTemplate as LangchainPrompt +from llama_index.bridge.langchain import BasePromptTemplate as LangchainTemplate +from llama_index.bridge.langchain import ConditionalPromptSelector as LangchainSelector from llama_index.llms.base import LLM, ChatMessage +from llama_index.llms.generic_utils import messages_to_prompt, prompt_to_messages +from llama_index.llms.langchain import LangChainLLM from llama_index.llms.langchain_utils import from_lc_messages -from llama_index.prompts.prompt_selector import PromptSelector from llama_index.prompts.prompt_type import PromptType +from llama_index.prompts.utils import get_template_vars from llama_index.types import BaseOutputParser -class Prompt: - """Prompt class for LlamaIndex. +class BasePromptTemplate(BaseModel, ABC): + metadata: Dict[str, Any] + template_vars: List[str] + kwargs: Dict[str, str] + output_parser: Optional[BaseOutputParser] + + class Config: + arbitrary_types_allowed = True + + @abstractmethod + def partial_format(self, **kwargs: Any) -> "BasePromptTemplate": + ... + + @abstractmethod + def format(self, llm: Optional[LLM] = None, **kwargs: Any) -> str: + ... + + @abstractmethod + def format_messages( + self, llm: Optional[LLM] = None, **kwargs: Any + ) -> List[ChatMessage]: + ... + + +class PromptTemplate(BasePromptTemplate): + template: str + + def __init__( + self, + template: str, + prompt_type: str = PromptType.CUSTOM, + output_parser: Optional[BaseOutputParser] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + if metadata is None: + metadata = {} + metadata["prompt_type"] = prompt_type + + template_vars = get_template_vars(template) + + super().__init__( + template=template, + template_vars=template_vars, + kwargs=kwargs, + metadata=metadata, + output_parser=output_parser, + ) + + def partial_format(self, **kwargs: Any) -> "PromptTemplate": + """Partially format the prompt.""" + prompt = deepcopy(self) + prompt.kwargs.update(kwargs) + return prompt + + def format(self, llm: Optional[LLM] = None, **kwargs: Any) -> str: + """Format the prompt into a string.""" + del llm # unused + all_kwargs = { + **self.kwargs, + **kwargs, + } + return self.template.format(**all_kwargs) + + def format_messages( + self, llm: Optional[LLM] = None, **kwargs: Any + ) -> List[ChatMessage]: + """Format the prompt into a list of chat messages.""" + del llm # unused + prompt = self.format(**kwargs) + return prompt_to_messages(prompt) + + +class ChatPromptTemplate(BasePromptTemplate): + message_templates: List[ChatMessage] + + def __init__( + self, + message_templates: List[ChatMessage], + prompt_type: str = PromptType.CUSTOM, + output_parser: Optional[BaseOutputParser] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ): + if metadata is None: + metadata = {} + metadata["prompt_type"] = prompt_type + + template_vars = [] + for message_template in message_templates: + template_vars.extend(get_template_vars(message_template.content or "")) + + super().__init__( + message_templates=message_templates, + kwargs=kwargs, + metadata=metadata, + output_parser=output_parser, + template_vars=template_vars, + ) + + def partial_format(self, **kwargs: Any) -> "ChatPromptTemplate": + prompt = deepcopy(self) + prompt.kwargs.update(kwargs) + return prompt + + def format(self, llm: Optional[LLM] = None, **kwargs: Any) -> str: + del llm # unused + messages = self.format_messages(**kwargs) + prompt = messages_to_prompt(messages) + return prompt + + def format_messages( + self, llm: Optional[LLM] = None, **kwargs: Any + ) -> List[ChatMessage]: + del llm # unused + """Format the prompt into a list of chat messages.""" + all_kwargs = { + **self.kwargs, + **kwargs, + } + + messages = [] + for message_template in self.message_templates: + template_vars = get_template_vars(message_template.content or "") + relevant_kwargs = { + k: v for k, v in all_kwargs.items() if k in template_vars + } + content_template = message_template.content or "" + content = content_template.format(**relevant_kwargs) - Wrapper around langchain's prompt class. Adds ability to: - - enforce certain prompt types - - partially fill values - """ + message = message_template.copy() + message.content = content + messages.append(message) + + return messages + + +class SelectorPromptTemplate(BasePromptTemplate): + default_template: BasePromptTemplate + conditionals: Optional[ + List[Tuple[Callable[[LLM], bool], BasePromptTemplate]] + ] = None def __init__( self, - template: Optional[str] = None, - langchain_prompt: Optional[BaseLangchainPrompt] = None, - langchain_prompt_selector: Optional[PromptSelector] = None, + default_template: BasePromptTemplate, + conditionals: Optional[ + List[Tuple[Callable[[LLM], bool], BasePromptTemplate]] + ] = None, + ): + metadata = default_template.metadata + kwargs = default_template.kwargs + template_vars = default_template.template_vars + output_parser = default_template.output_parser + super().__init__( + default_template=default_template, + conditionals=conditionals, + metadata=metadata, + kwargs=kwargs, + template_vars=template_vars, + output_parser=output_parser, + ) + + def _select(self, llm: Optional[LLM] = None) -> BasePromptTemplate: + if llm is None: + return self.default_template + + if self.conditionals is not None: + for condition, prompt in self.conditionals: + if condition(llm): + return prompt + return self.default_template + + def partial_format(self, **kwargs: Any) -> "SelectorPromptTemplate": + default_template = self.default_template.partial_format(**kwargs) + if self.conditionals is None: + conditionals = None + else: + conditionals = [ + (condition, prompt.partial_format(**kwargs)) + for condition, prompt in self.conditionals + ] + return SelectorPromptTemplate( + default_template=default_template, conditionals=conditionals + ) + + def format(self, llm: Optional[LLM] = None, **kwargs: Any) -> str: + """Format the prompt into a string.""" + prompt = self._select(llm=llm) + return prompt.format(**kwargs) + + def format_messages( + self, llm: Optional[LLM] = None, **kwargs: Any + ) -> List[ChatMessage]: + """Format the prompt into a list of chat messages.""" + prompt = self._select(llm=llm) + return prompt.format_messages(**kwargs) + + +class LangchainPromptTemplate(BasePromptTemplate): + selector: LangchainSelector + + def __init__( + self, + template: Optional[LangchainTemplate] = None, + selector: Optional[LangchainSelector] = None, output_parser: Optional[BaseOutputParser] = None, prompt_type: str = PromptType.CUSTOM, metadata: Optional[Dict[str, Any]] = None, - **prompt_kwargs: Any, ) -> None: - """Init params.""" - # first check if langchain_prompt_selector is provided - # TODO: self.prompt is deprecated, switch to prompt_selector under the hood - if langchain_prompt_selector is not None: - self.prompt_selector = langchain_prompt_selector - self.prompt: BaseLangchainPrompt = self.prompt_selector.default_prompt - # then check if template is provided - elif langchain_prompt is None: + if selector is None: if template is None: - raise ValueError( - "`template` must be specified if `langchain_prompt` is None" - ) - - self.prompt = LangchainPrompt.from_template( - template=template, **prompt_kwargs - ) - self.prompt_selector = PromptSelector(default_prompt=self.prompt) - # finally, check if langchain_prompt is provided + raise ValueError("Must provide either template or selector.") + selector = LangchainSelector(default_prompt=template) else: - if template: - raise ValueError( - f"Both template ({template}) and langchain_prompt " - f"({langchain_prompt}) are provided, only one should be." - ) - self.prompt = langchain_prompt - self.prompt_selector = PromptSelector(default_prompt=self.prompt) - - self.partial_dict: Dict[str, Any] = {} - self.prompt_kwargs = prompt_kwargs - # NOTE: this is only used for token counting and testing - self.prompt_type = prompt_type - - self.output_parser = output_parser - - self._original_template = template - - # Metadata is used to pass arbitrary information to other consumers of the - # prompt. For example, VellumPromptRegistry uses this to access vellum-specific - # identifiers that users can pass along with the prompt. - self.metadata = metadata or {} - - @property - def original_template(self) -> str: - """Return the originally specified template, if supplied.""" - - if not self._original_template: - raise ValueError("No original template specified.") - - return self._original_template - - @classmethod - def from_langchain_prompt( - cls, prompt: BaseLangchainPrompt, **kwargs: Any - ) -> "Prompt": - """Load prompt from LangChain prompt.""" - return cls(langchain_prompt=prompt, **kwargs) - - @classmethod - def from_langchain_prompt_selector( - cls, prompt_selector: PromptSelector, **kwargs: Any - ) -> "Prompt": - """Load prompt from LangChain prompt.""" - return cls(langchain_prompt_selector=prompt_selector, **kwargs) - - def partial_format(self, **kwargs: Any) -> "Prompt": - """Format the prompt partially. - - Return an instance of itself. - - """ - try: - # NOTE: this is a hack to get around deepcopy failing on output parser - output_parser = self.output_parser - self.output_parser = None - - copy_obj = deepcopy(self) - copy_obj.output_parser = output_parser - copy_obj.partial_dict.update(kwargs) - self.output_parser = output_parser - except Exception as e: - raise e - - return copy_obj - - @classmethod - def from_prompt( - cls, - prompt: "Prompt", - llm: Optional[LLM] = None, - prompt_type: Optional[PromptType] = None, - ) -> "Prompt": - """Create a prompt from an existing prompt. - - Use case: If the existing prompt is already partially filled, - and the remaining fields satisfy the requirements of the - prompt class, then we can create a new prompt from the existing - partially filled prompt. - - """ - lc_prompt = prompt.get_langchain_prompt(llm=llm) - tmpl_vars = lc_prompt.input_variables - format_dict = {} - for var in tmpl_vars: - if var not in prompt.partial_dict: - format_dict[var] = f"{{{var}}}" - - template_str = prompt.format(llm=llm, **format_dict) - cls_obj = cls( - template_str, - prompt_type=prompt_type or PromptType.CUSTOM, - **prompt.prompt_kwargs, + if template is not None: + raise ValueError("Must provide either template or selector.") + selector = selector + + kwargs = selector.default_prompt.partial_variables + template_vars = selector.default_prompt.input_variables + + if metadata is None: + metadata = {} + metadata["prompt_type"] = prompt_type + + super().__init__( + selector=selector, + metadata=metadata, + kwargs=kwargs, + template_vars=template_vars, + output_parser=output_parser, ) - return cls_obj - def get_langchain_prompt(self, llm: Optional[LLM] = None) -> BaseLangchainPrompt: - """Get langchain prompt.""" - return self.prompt_selector.select(llm=llm) + def partial_format(self, **kwargs: Any) -> "BasePromptTemplate": + """Partially format the prompt.""" + default_prompt = self.selector.default_prompt.partial(**kwargs) + conditionals = [ + (condition, prompt.partial(**kwargs)) + for condition, prompt in self.selector.conditionals + ] + lc_selector = LangchainSelector( + default_prompt=default_prompt, conditionals=conditionals + ) + return LangchainPromptTemplate(selector=lc_selector) def format(self, llm: Optional[LLM] = None, **kwargs: Any) -> str: """Format the prompt into a string.""" - kwargs.update(self.partial_dict) - lc_prompt = self.get_langchain_prompt(llm=llm) - return lc_prompt.format(**kwargs) + if llm is not None: + if not isinstance(llm, LangChainLLM): + raise ValueError("Must provide a LangChainLLM.") + lc_template = self.selector.get_prompt(llm=llm.llm) + else: + lc_template = self.selector.default_prompt + + return lc_template.format(**kwargs) def format_messages( self, llm: Optional[LLM] = None, **kwargs: Any ) -> List[ChatMessage]: """Format the prompt into a list of chat messages.""" - kwargs.update(self.partial_dict) - lc_template = self.get_langchain_prompt(llm=llm) - lc_value = lc_template.format_prompt(**kwargs) - lc_messages = lc_value.to_messages() - return from_lc_messages(lc_messages) - - def get_full_format_args(self, kwargs: Dict) -> Dict[str, Any]: - """Get dict of all format args. + if llm is not None: + if not isinstance(llm, LangChainLLM): + raise ValueError("Must provide a LangChainLLM.") + lc_template = self.selector.get_prompt(llm=llm.llm) + else: + lc_template = self.selector.default_prompt + lc_prompt_value = lc_template.format_prompt(**kwargs) + lc_messages = lc_prompt_value.to_messages() + messages = from_lc_messages(lc_messages) + return messages - Hack to pass into Langchain to pass validation. - """ - kwargs.update(self.partial_dict) - return kwargs +# NOTE: only for backwards compatibility +Prompt = PromptTemplate diff --git a/llama_index/prompts/chat_prompts.py b/llama_index/prompts/chat_prompts.py index 51a3c0ca564d3bb1e787856530f7fe45510415df..97d673294144b3ae18511cdff5d9701e00184257 100644 --- a/llama_index/prompts/chat_prompts.py +++ b/llama_index/prompts/chat_prompts.py @@ -1,114 +1,109 @@ """Prompts for ChatGPT.""" -from llama_index.bridge.langchain import ( - AIMessagePromptTemplate, - ChatPromptTemplate, - HumanMessagePromptTemplate, - SystemMessagePromptTemplate, -) - -from llama_index.prompts.prompts import ( - QuestionAnswerPrompt, - SummaryPrompt, - RefinePrompt, - RefineTableContextPrompt, -) +from llama_index.llms.base import ChatMessage, MessageRole +from llama_index.prompts.base import ChatPromptTemplate # text qa prompt -TEXT_QA_SYSTEM_PROMPT = SystemMessagePromptTemplate.from_template( - "You are an expert Q&A system that is trusted around the world.\n" - "Always answer the query using the provided context information, " - "and not prior knowledge.\n" - "Some rules to follow:\n" - "1. Never directly reference the given context in your answer.\n" - "2. Avoid statements like 'Based on the context, ...' or " - "'The context information ...' or anything along " - "those lines." +TEXT_QA_SYSTEM_PROMPT = ChatMessage( + content=( + "You are an expert Q&A system that is trusted around the world.\n" + "Always answer the query using the provided context information, " + "and not prior knowledge.\n" + "Some rules to follow:\n" + "1. Never directly reference the given context in your answer.\n" + "2. Avoid statements like 'Based on the context, ...' or " + "'The context information ...' or anything along " + "those lines." + ), + role=MessageRole.SYSTEM, ) TEXT_QA_PROMPT_TMPL_MSGS = [ TEXT_QA_SYSTEM_PROMPT, - HumanMessagePromptTemplate.from_template( - "Context information is below.\n" - "---------------------\n" - "{context_str}\n" - "---------------------\n" - "Given the context information and not prior knowledge, " - "answer the query.\n" - "Query: {query_str}\n" - "Answer: " + ChatMessage( + content=( + "Context information is below.\n" + "---------------------\n" + "{context_str}\n" + "---------------------\n" + "Given the context information and not prior knowledge, " + "answer the query.\n" + "Query: {query_str}\n" + "Answer: " + ), + role=MessageRole.USER, ), ] -CHAT_TEXT_QA_PROMPT_LC = ChatPromptTemplate.from_messages(TEXT_QA_PROMPT_TMPL_MSGS) -CHAT_TEXT_QA_PROMPT = QuestionAnswerPrompt.from_langchain_prompt(CHAT_TEXT_QA_PROMPT_LC) - +CHAT_TEXT_QA_PROMPT = ChatPromptTemplate(message_templates=TEXT_QA_PROMPT_TMPL_MSGS) # Tree Summarize TREE_SUMMARIZE_PROMPT_TMPL_MSGS = [ TEXT_QA_SYSTEM_PROMPT, - HumanMessagePromptTemplate.from_template( - "Context information from multiple sources is below.\n" - "---------------------\n" - "{context_str}\n" - "---------------------\n" - "Given the information from multiple sources and not prior knowledge, " - "answer the query.\n" - "Query: {query_str}\n" - "Answer: " + ChatMessage( + content=( + "Context information from multiple sources is below.\n" + "---------------------\n" + "{context_str}\n" + "---------------------\n" + "Given the information from multiple sources and not prior knowledge, " + "answer the query.\n" + "Query: {query_str}\n" + "Answer: " + ), + role=MessageRole.USER, ), ] -CHAT_TREE_SUMMARIZE_PROMPT_LC = ChatPromptTemplate.from_messages( - TREE_SUMMARIZE_PROMPT_TMPL_MSGS -) -CHAT_TREE_SUMMARIZE_PROMPT = SummaryPrompt.from_langchain_prompt( - CHAT_TREE_SUMMARIZE_PROMPT_LC +CHAT_TREE_SUMMARIZE_PROMPT = ChatPromptTemplate( + message_templates=TREE_SUMMARIZE_PROMPT_TMPL_MSGS ) # Refine Prompt CHAT_REFINE_PROMPT_TMPL_MSGS = [ - HumanMessagePromptTemplate.from_template( - "You are an expert Q&A system that stricly operates in two modes" - "when refining existing answers:\n" - "1. **Rewrite** an original answer using the new context.\n" - "2. **Repeat** the original answer if the new context isn't useful.\n" - "Never reference the original answer or context directly in your answer.\n" - "When in doubt, just repeat the original answer." - "New Context: {context_msg}\n" - "Query: {query_str}\n" - "Original Answer: {existing_answer}\n" - "New Answer: " - ), + ChatMessage( + content=( + "You are an expert Q&A system that stricly operates in two modes" + "when refining existing answers:\n" + "1. **Rewrite** an original answer using the new context.\n" + "2. **Repeat** the original answer if the new context isn't useful.\n" + "Never reference the original answer or context directly in your answer.\n" + "When in doubt, just repeat the original answer." + "New Context: {context_msg}\n" + "Query: {query_str}\n" + "Original Answer: {existing_answer}\n" + "New Answer: " + ), + role=MessageRole.USER, + ) ] -CHAT_REFINE_PROMPT_LC = ChatPromptTemplate.from_messages(CHAT_REFINE_PROMPT_TMPL_MSGS) -CHAT_REFINE_PROMPT = RefinePrompt.from_langchain_prompt(CHAT_REFINE_PROMPT_LC) +CHAT_REFINE_PROMPT = ChatPromptTemplate(message_templates=CHAT_REFINE_PROMPT_TMPL_MSGS) # Table Context Refine Prompt CHAT_REFINE_TABLE_CONTEXT_TMPL_MSGS = [ - HumanMessagePromptTemplate.from_template("{query_str}"), - AIMessagePromptTemplate.from_template("{existing_answer}"), - HumanMessagePromptTemplate.from_template( - "We have provided a table schema below. " - "---------------------\n" - "{schema}\n" - "---------------------\n" - "We have also provided some context information below. " - "{context_msg}\n" - "---------------------\n" - "Given the context information and the table schema, " - "refine the original answer to better " - "answer the question. " - "If the context isn't useful, return the original answer." + ChatMessage(content="{query_str}", role=MessageRole.USER), + ChatMessage(content="{existing_answer}", role=MessageRole.ASSISTANT), + ChatMessage( + content=( + "We have provided a table schema below. " + "---------------------\n" + "{schema}\n" + "---------------------\n" + "We have also provided some context information below. " + "{context_msg}\n" + "---------------------\n" + "Given the context information and the table schema, " + "refine the original answer to better " + "answer the question. " + "If the context isn't useful, return the original answer." + ), + role=MessageRole.USER, ), ] -CHAT_REFINE_TABLE_CONTEXT_PROMPT_LC = ChatPromptTemplate.from_messages( - CHAT_REFINE_TABLE_CONTEXT_TMPL_MSGS -) -CHAT_REFINE_TABLE_CONTEXT_PROMPT = RefineTableContextPrompt.from_langchain_prompt( - CHAT_REFINE_TABLE_CONTEXT_PROMPT_LC +CHAT_REFINE_TABLE_CONTEXT_PROMPT = ChatPromptTemplate( + message_templates=CHAT_REFINE_TABLE_CONTEXT_TMPL_MSGS ) diff --git a/llama_index/prompts/choice_select.py b/llama_index/prompts/choice_select.py deleted file mode 100644 index b03af6c0213762e9f722d3fbb668e1094c67ce2a..0000000000000000000000000000000000000000 --- a/llama_index/prompts/choice_select.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Default choice select prompt.""" - -from llama_index.prompts.base import Prompt -from llama_index.prompts.prompt_type import PromptType - -# deprecated, kept for backward compatibility -ChoiceSelectPrompt = Prompt - -DEFAULT_CHOICE_SELECT_PROMPT_TMPL = ( - "A list of documents is shown below. Each document has a number next to it along " - "with a summary of the document. A question is also provided. \n" - "Respond with the numbers of the documents " - "you should consult to answer the question, in order of relevance, as well \n" - "as the relevance score. The relevance score is a number from 1-10 based on " - "how relevant you think the document is to the question.\n" - "Do not include any documents that are not relevant to the question. \n" - "Example format: \n" - "Document 1:\n<summary of document 1>\n\n" - "Document 2:\n<summary of document 2>\n\n" - "...\n\n" - "Document 10:\n<summary of document 10>\n\n" - "Question: <question>\n" - "Answer:\n" - "Doc: 9, Relevance: 7\n" - "Doc: 3, Relevance: 4\n" - "Doc: 7, Relevance: 3\n\n" - "Let's try this now: \n\n" - "{context_str}\n" - "Question: {query_str}\n" - "Answer:\n" -) -DEFAULT_CHOICE_SELECT_PROMPT = Prompt( - DEFAULT_CHOICE_SELECT_PROMPT_TMPL, prompt_type=PromptType.CHOICE_SELECT -) diff --git a/llama_index/prompts/default_prompt_selectors.py b/llama_index/prompts/default_prompt_selectors.py index 17c0307f2490490ce0e63c29f062f65a9189cc57..d11b026d3dd0511f1a1b2d519c5dcb59d54955c8 100644 --- a/llama_index/prompts/default_prompt_selectors.py +++ b/llama_index/prompts/default_prompt_selectors.py @@ -1,59 +1,35 @@ -"""Prompt selectors.""" +"""Default prompt selectors.""" +from llama_index.prompts import SelectorPromptTemplate from llama_index.prompts.chat_prompts import ( - CHAT_TEXT_QA_PROMPT, - CHAT_TREE_SUMMARIZE_PROMPT, CHAT_REFINE_PROMPT, CHAT_REFINE_TABLE_CONTEXT_PROMPT, + CHAT_TEXT_QA_PROMPT, + CHAT_TREE_SUMMARIZE_PROMPT, ) from llama_index.prompts.default_prompts import ( - DEFAULT_TEXT_QA_PROMPT, - DEFAULT_TREE_SUMMARIZE_PROMPT, DEFAULT_REFINE_PROMPT, DEFAULT_REFINE_TABLE_CONTEXT_PROMPT, + DEFAULT_TEXT_QA_PROMPT, + DEFAULT_TREE_SUMMARIZE_PROMPT, ) -from llama_index.prompts.prompt_selector import PromptSelector, is_chat_model -from llama_index.prompts.prompt_type import PromptType -from llama_index.prompts.prompts import ( - QuestionAnswerPrompt, - RefinePrompt, - RefineTableContextPrompt, -) - -DEFAULT_TEXT_QA_PROMPT_SEL_LC = PromptSelector( - default_prompt=DEFAULT_TEXT_QA_PROMPT.get_langchain_prompt(), - conditionals=[(is_chat_model, CHAT_TEXT_QA_PROMPT.get_langchain_prompt())], -) -DEFAULT_TEXT_QA_PROMPT_SEL = QuestionAnswerPrompt( - langchain_prompt_selector=DEFAULT_TEXT_QA_PROMPT_SEL_LC, - prompt_type=PromptType.QUESTION_ANSWER, -) +from llama_index.prompts.utils import is_chat_model -DEFAULT_TREE_SUMMARIZE_PROMPT_SEL_LC = PromptSelector( - default_prompt=DEFAULT_TREE_SUMMARIZE_PROMPT.get_langchain_prompt(), - conditionals=[(is_chat_model, CHAT_TREE_SUMMARIZE_PROMPT.get_langchain_prompt())], -) -DEFAULT_TREE_SUMMARIZE_PROMPT_SEL = QuestionAnswerPrompt( - langchain_prompt_selector=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL_LC, - prompt_type=PromptType.SUMMARY, +DEFAULT_TEXT_QA_PROMPT_SEL = SelectorPromptTemplate( + default_template=DEFAULT_TEXT_QA_PROMPT, + conditionals=[(is_chat_model, CHAT_TEXT_QA_PROMPT)], ) -DEFAULT_REFINE_PROMPT_SEL_LC = PromptSelector( - default_prompt=DEFAULT_REFINE_PROMPT.get_langchain_prompt(), - conditionals=[(is_chat_model, CHAT_REFINE_PROMPT.get_langchain_prompt())], -) -DEFAULT_REFINE_PROMPT_SEL = RefinePrompt( - langchain_prompt_selector=DEFAULT_REFINE_PROMPT_SEL_LC, - prompt_type=PromptType.REFINE, +DEFAULT_TREE_SUMMARIZE_PROMPT_SEL = SelectorPromptTemplate( + default_template=DEFAULT_TREE_SUMMARIZE_PROMPT, + conditionals=[(is_chat_model, CHAT_TREE_SUMMARIZE_PROMPT)], ) -DEFAULT_REFINE_TABLE_CONTEXT_PROMPT_SEL_LC = PromptSelector( - default_prompt=DEFAULT_REFINE_TABLE_CONTEXT_PROMPT.get_langchain_prompt(), - conditionals=[ - (is_chat_model, CHAT_REFINE_TABLE_CONTEXT_PROMPT.get_langchain_prompt()) - ], +DEFAULT_REFINE_PROMPT_SEL = SelectorPromptTemplate( + default_template=DEFAULT_REFINE_PROMPT, + conditionals=[(is_chat_model, CHAT_REFINE_PROMPT)], ) -DEFAULT_REFINE_TABLE_CONTEXT_PROMPT_SEL = RefineTableContextPrompt( - langchain_prompt_selector=DEFAULT_REFINE_TABLE_CONTEXT_PROMPT_SEL_LC, - prompt_type=PromptType.TABLE_CONTEXT, +DEFAULT_REFINE_TABLE_CONTEXT_PROMPT_SEL = SelectorPromptTemplate( + default_template=DEFAULT_REFINE_TABLE_CONTEXT_PROMPT, + conditionals=[(is_chat_model, CHAT_REFINE_TABLE_CONTEXT_PROMPT)], ) diff --git a/llama_index/prompts/default_prompts.py b/llama_index/prompts/default_prompts.py index 56ea34a48fd155d56e12827ac4b0be833439ab55..41e583f23005efc88ba9c86df052b43312f142df 100644 --- a/llama_index/prompts/default_prompts.py +++ b/llama_index/prompts/default_prompts.py @@ -1,6 +1,6 @@ """Set of default prompts.""" -from llama_index.prompts.base import Prompt +from llama_index.prompts.base import PromptTemplate from llama_index.prompts.prompt_type import PromptType ############################################ @@ -19,7 +19,7 @@ DEFAULT_SUMMARY_PROMPT_TMPL = ( 'SUMMARY:"""\n' ) -DEFAULT_SUMMARY_PROMPT = Prompt( +DEFAULT_SUMMARY_PROMPT = PromptTemplate( DEFAULT_SUMMARY_PROMPT_TMPL, prompt_type=PromptType.SUMMARY ) @@ -37,7 +37,7 @@ DEFAULT_INSERT_PROMPT_TMPL = ( "The answer should be the number corresponding to the " "summary that is most relevant to the question.\n" ) -DEFAULT_INSERT_PROMPT = Prompt( +DEFAULT_INSERT_PROMPT = PromptTemplate( DEFAULT_INSERT_PROMPT_TMPL, prompt_type=PromptType.TREE_INSERT ) @@ -55,7 +55,7 @@ DEFAULT_QUERY_PROMPT_TMPL = ( "Provide choice in the following format: 'ANSWER: <number>' and explain why " "this summary was selected in relation to the question.\n" ) -DEFAULT_QUERY_PROMPT = Prompt( +DEFAULT_QUERY_PROMPT = PromptTemplate( DEFAULT_QUERY_PROMPT_TMPL, prompt_type=PromptType.TREE_SELECT ) @@ -73,7 +73,7 @@ DEFAULT_QUERY_PROMPT_MULTIPLE_TMPL = ( "Provide choices in the following format: 'ANSWER: <numbers>' and explain why " "these summaries were selected in relation to the question.\n" ) -DEFAULT_QUERY_PROMPT_MULTIPLE = Prompt( +DEFAULT_QUERY_PROMPT_MULTIPLE = PromptTemplate( DEFAULT_QUERY_PROMPT_MULTIPLE_TMPL, prompt_type=PromptType.TREE_SELECT_MULTIPLE ) @@ -91,7 +91,7 @@ DEFAULT_REFINE_PROMPT_TMPL = ( "If the context isn't useful, return the original answer.\n" "Refined Answer: " ) -DEFAULT_REFINE_PROMPT = Prompt( +DEFAULT_REFINE_PROMPT = PromptTemplate( DEFAULT_REFINE_PROMPT_TMPL, prompt_type=PromptType.REFINE ) @@ -106,7 +106,7 @@ DEFAULT_TEXT_QA_PROMPT_TMPL = ( "Query: {query_str}\n" "Answer: " ) -DEFAULT_TEXT_QA_PROMPT = Prompt( +DEFAULT_TEXT_QA_PROMPT = PromptTemplate( DEFAULT_TEXT_QA_PROMPT_TMPL, prompt_type=PromptType.QUESTION_ANSWER ) @@ -120,7 +120,7 @@ DEFAULT_TREE_SUMMARIZE_TMPL = ( "Query: {query_str}\n" "Answer: " ) -DEFAULT_TREE_SUMMARIZE_PROMPT = Prompt( +DEFAULT_TREE_SUMMARIZE_PROMPT = PromptTemplate( DEFAULT_TREE_SUMMARIZE_TMPL, prompt_type=PromptType.SUMMARY ) @@ -137,7 +137,7 @@ DEFAULT_KEYWORD_EXTRACT_TEMPLATE_TMPL = ( "---------------------\n" "Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n" ) -DEFAULT_KEYWORD_EXTRACT_TEMPLATE = Prompt( +DEFAULT_KEYWORD_EXTRACT_TEMPLATE = PromptTemplate( DEFAULT_KEYWORD_EXTRACT_TEMPLATE_TMPL, prompt_type=PromptType.KEYWORD_EXTRACT ) @@ -153,7 +153,7 @@ DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = ( "---------------------\n" "Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n" ) -DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE = Prompt( +DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE = PromptTemplate( DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL, prompt_type=PromptType.QUERY_KEYWORD_EXTRACT, ) @@ -179,7 +179,7 @@ DEFAULT_SCHEMA_EXTRACT_TMPL = ( "If no fields are present in the text, return a blank string.\n" "Fields: " ) -DEFAULT_SCHEMA_EXTRACT_PROMPT = Prompt( +DEFAULT_SCHEMA_EXTRACT_PROMPT = PromptTemplate( DEFAULT_SCHEMA_EXTRACT_TMPL, prompt_type=PromptType.SCHEMA_EXTRACT ) @@ -208,7 +208,7 @@ DEFAULT_TEXT_TO_SQL_TMPL = ( "SQLQuery: " ) -DEFAULT_TEXT_TO_SQL_PROMPT = Prompt( +DEFAULT_TEXT_TO_SQL_PROMPT = PromptTemplate( DEFAULT_TEXT_TO_SQL_TMPL, prompt_type=PromptType.TEXT_TO_SQL, ) @@ -238,11 +238,11 @@ DEFAULT_TABLE_CONTEXT_QUERY = ( "...\n\n" ) -DEFAULT_TABLE_CONTEXT_PROMPT = Prompt( +DEFAULT_TABLE_CONTEXT_PROMPT = PromptTemplate( DEFAULT_TABLE_CONTEXT_TMPL, prompt_type=PromptType.TABLE_CONTEXT ) -# NOTE: by partially filling schema, we can reduce to a RefinePrompt +# NOTE: by partially filling schema, we can reduce to a refine prompt # that we can feed to ur table DEFAULT_REFINE_TABLE_CONTEXT_TMPL = ( "We have provided a table schema below. " @@ -259,7 +259,7 @@ DEFAULT_REFINE_TABLE_CONTEXT_TMPL = ( "answer the question. " "If the context isn't useful, return the original answer." ) -DEFAULT_REFINE_TABLE_CONTEXT_PROMPT = Prompt( +DEFAULT_REFINE_TABLE_CONTEXT_PROMPT = PromptTemplate( DEFAULT_REFINE_TABLE_CONTEXT_TMPL, prompt_type=PromptType.TABLE_CONTEXT ) @@ -285,7 +285,7 @@ DEFAULT_KG_TRIPLET_EXTRACT_TMPL = ( "Text: {text}\n" "Triplets:\n" ) -DEFAULT_KG_TRIPLET_EXTRACT_PROMPT = Prompt( +DEFAULT_KG_TRIPLET_EXTRACT_PROMPT = PromptTemplate( DEFAULT_KG_TRIPLET_EXTRACT_TMPL, prompt_type=PromptType.KNOWLEDGE_TRIPLET_EXTRACT ) @@ -304,7 +304,7 @@ HYDE_TMPL = ( 'Passage:"""\n' ) -DEFAULT_HYDE_PROMPT = Prompt(HYDE_TMPL, prompt_type=PromptType.SUMMARY) +DEFAULT_HYDE_PROMPT = PromptTemplate(HYDE_TMPL, prompt_type=PromptType.SUMMARY) ############################################ @@ -312,7 +312,7 @@ DEFAULT_HYDE_PROMPT = Prompt(HYDE_TMPL, prompt_type=PromptType.SUMMARY) ############################################ DEFAULT_SIMPLE_INPUT_TMPL = "{query_str}" -DEFAULT_SIMPLE_INPUT_PROMPT = Prompt( +DEFAULT_SIMPLE_INPUT_PROMPT = PromptTemplate( DEFAULT_SIMPLE_INPUT_TMPL, prompt_type=PromptType.SIMPLE_INPUT ) @@ -333,7 +333,9 @@ DEFAULT_PANDAS_TMPL = ( "Output:\n" ) -DEFAULT_PANDAS_PROMPT = Prompt(DEFAULT_PANDAS_TMPL, prompt_type=PromptType.PANDAS) +DEFAULT_PANDAS_PROMPT = PromptTemplate( + DEFAULT_PANDAS_TMPL, prompt_type=PromptType.PANDAS +) ############################################ @@ -349,6 +351,38 @@ DEFAULT_JSON_PATH_TMPL = ( "JSONPath: " ) -DEFAULT_JSON_PATH_PROMPT = Prompt( +DEFAULT_JSON_PATH_PROMPT = PromptTemplate( DEFAULT_JSON_PATH_TMPL, prompt_type=PromptType.JSON_PATH ) + + +############################################ +# Choice Select +############################################ + +DEFAULT_CHOICE_SELECT_PROMPT_TMPL = ( + "A list of documents is shown below. Each document has a number next to it along " + "with a summary of the document. A question is also provided. \n" + "Respond with the numbers of the documents " + "you should consult to answer the question, in order of relevance, as well \n" + "as the relevance score. The relevance score is a number from 1-10 based on " + "how relevant you think the document is to the question.\n" + "Do not include any documents that are not relevant to the question. \n" + "Example format: \n" + "Document 1:\n<summary of document 1>\n\n" + "Document 2:\n<summary of document 2>\n\n" + "...\n\n" + "Document 10:\n<summary of document 10>\n\n" + "Question: <question>\n" + "Answer:\n" + "Doc: 9, Relevance: 7\n" + "Doc: 3, Relevance: 4\n" + "Doc: 7, Relevance: 3\n\n" + "Let's try this now: \n\n" + "{context_str}\n" + "Question: {query_str}\n" + "Answer:\n" +) +DEFAULT_CHOICE_SELECT_PROMPT = PromptTemplate( + DEFAULT_CHOICE_SELECT_PROMPT_TMPL, prompt_type=PromptType.CHOICE_SELECT +) diff --git a/llama_index/prompts/prompt_selector.py b/llama_index/prompts/prompt_selector.py deleted file mode 100644 index 36b02fe9d2c1ad7c24da143ff053c15199715a7f..0000000000000000000000000000000000000000 --- a/llama_index/prompts/prompt_selector.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Callable, List, Optional, Tuple - -from pydantic import BaseModel, Field - -from llama_index.bridge.langchain import BasePromptTemplate -from llama_index.llms.base import LLM - - -class PromptSelector(BaseModel): - default_prompt: BasePromptTemplate - conditionals: List[Tuple[Callable[[LLM], bool], BasePromptTemplate]] = Field( - default_factory=list - ) - - def select(self, llm: Optional[LLM] = None) -> BasePromptTemplate: - if llm is None: - return self.default_prompt - - for condition, prompt in self.conditionals: - if condition(llm): - return prompt - return self.default_prompt - - -def is_chat_model(llm: LLM) -> bool: - return llm.metadata.is_chat_model diff --git a/llama_index/prompts/prompt_utils.py b/llama_index/prompts/prompt_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ca121ed1ba2871734336a960acbe3e18df80a18c --- /dev/null +++ b/llama_index/prompts/prompt_utils.py @@ -0,0 +1,31 @@ +from typing import List +from llama_index.prompts.base import BasePromptTemplate + + +def get_empty_prompt_txt(prompt: BasePromptTemplate) -> str: + """Get empty prompt text. + + Substitute empty strings in parts of the prompt that have + not yet been filled out. Skip variables that have already + been partially formatted. This is used to compute the initial tokens. + + """ + partial_kargs = prompt.kwargs + empty_kwargs = {v: "" for v in prompt.template_vars if v not in partial_kargs} + all_kwargs = {**partial_kargs, **empty_kwargs} + empty_prompt_txt = prompt.format(llm=None, **all_kwargs) + return empty_prompt_txt + + +def get_biggest_prompt(prompts: List[BasePromptTemplate]) -> BasePromptTemplate: + """Get biggest prompt. + + Oftentimes we need to fetch the biggest prompt, in order to + be the most conservative about chunking text. This + is a helper utility for that. + + """ + empty_prompt_txts = [get_empty_prompt_txt(prompt) for prompt in prompts] + empty_prompt_txt_lens = [len(txt) for txt in empty_prompt_txts] + biggest_prompt = prompts[empty_prompt_txt_lens.index(max(empty_prompt_txt_lens))] + return biggest_prompt diff --git a/llama_index/prompts/prompts.py b/llama_index/prompts/prompts.py index 8cfab47b6cd8f6fa24e758c7c5eea9ffde580934..fc1d542276ba9c690a8fc06f4d1de1e8266404cb 100644 --- a/llama_index/prompts/prompts.py +++ b/llama_index/prompts/prompts.py @@ -1,41 +1,41 @@ """Subclasses from base prompt.""" -from llama_index.prompts.base import Prompt +from llama_index.prompts.base import PromptTemplate # deprecated, kept for backward compatibility """Summary prompt. -Prompt to summarize the provided `context_str`. +PromptTemplate to summarize the provided `context_str`. Required template variables: `context_str` """ -SummaryPrompt = Prompt +SummaryPrompt = PromptTemplate """Tree Insert prompt. -Prompt to insert a new chunk of text `new_chunk_text` into the tree index. +PromptTemplate to insert a new chunk of text `new_chunk_text` into the tree index. More specifically, this prompt has the LLM select the relevant candidate child node to continue tree traversal. Required template variables: `num_chunks`, `context_list`, `new_chunk_text` """ -TreeInsertPrompt = Prompt +TreeInsertPrompt = PromptTemplate """Tree select prompt. -Prompt to select a candidate child node out of all child nodes +PromptTemplate to select a candidate child node out of all child nodes provided in `context_list`, given a query `query_str`. `num_chunks` is the number of child nodes in `context_list`. Required template variables: `num_chunks`, `context_list`, `query_str` """ -TreeSelectPrompt = Prompt +TreeSelectPrompt = PromptTemplate """Tree select multiple prompt. -Prompt to select multiple candidate child nodes out of all +PromptTemplate to select multiple candidate child nodes out of all child nodes provided in `context_list`, given a query `query_str`. `branching_factor` refers to the number of child nodes to select, and `num_chunks` is the number of child nodes in `context_list`. @@ -43,91 +43,98 @@ child nodes provided in `context_list`, given a query `query_str`. Required template variables: `num_chunks`, `context_list`, `query_str`, `branching_factor` """ -TreeSelectMultiplePrompt = Prompt +TreeSelectMultiplePrompt = PromptTemplate """Refine prompt. -Prompt to refine an existing answer `existing_answer` given a context `context_msg`, -and a query `query_str`. +PromptTemplate to refine an existing answer `existing_answer` +given a context `context_msg`, and a query `query_str`. Required template variables: `query_str`, `existing_answer`, `context_msg` """ -RefinePrompt = Prompt +RefinePrompt = PromptTemplate """Question Answer prompt. -Prompt to answer a question `query_str` given a context `context_str`. +PromptTemplate to answer a question `query_str` given a context `context_str`. Required template variables: `context_str`, `query_str` """ -QuestionAnswerPrompt = Prompt +QuestionAnswerPrompt = PromptTemplate """Keyword extract prompt. -Prompt to extract keywords from a text `text` with a maximum of +PromptTemplate to extract keywords from a text `text` with a maximum of `max_keywords` keywords. Required template variables: `text`, `max_keywords` """ -KeywordExtractPrompt = Prompt +KeywordExtractPrompt = PromptTemplate """Query keyword extract prompt. -Prompt to extract keywords from a query `query_str` with a maximum +PromptTemplate to extract keywords from a query `query_str` with a maximum of `max_keywords` keywords. Required template variables: `query_str`, `max_keywords` """ -QueryKeywordExtractPrompt = Prompt +QueryKeywordExtractPrompt = PromptTemplate """Schema extract prompt. -Prompt to extract schema from unstructured text `text`. +PromptTemplate to extract schema from unstructured text `text`. Required template variables: `text`, `schema` """ -SchemaExtractPrompt = Prompt +SchemaExtractPrompt = PromptTemplate """Text to SQL prompt. -Prompt to translate a natural language query into SQL in the dialect +PromptTemplate to translate a natural language query into SQL in the dialect `dialect` given a schema `schema`. Required template variables: `query_str`, `schema`, `dialect` """ -TextToSQLPrompt = Prompt +TextToSQLPrompt = PromptTemplate """Table context prompt. -Prompt to generate a table context given a table schema `schema`, +PromptTemplate to generate a table context given a table schema `schema`, as well as unstructured text context `context_str`, and a task `query_str`. This includes both a high-level description of the table as well as a description of each column in the table. """ -TableContextPrompt = Prompt +TableContextPrompt = PromptTemplate """Refine Table context prompt. -Prompt to refine a table context given a table schema `schema`, +PromptTemplate to refine a table context given a table schema `schema`, as well as unstructured text context `context_msg`, and a task `query_str`. This includes both a high-level description of the table as well as a description of each column in the table. """ -RefineTableContextPrompt = Prompt +RefineTableContextPrompt = PromptTemplate """Define the knowledge graph triplet extraction prompt.""" -KnowledgeGraphPrompt = Prompt +KnowledgeGraphPrompt = PromptTemplate """Simple Input prompt. Required template variables: `query_str`. """ -SimpleInputPrompt = Prompt +SimpleInputPrompt = PromptTemplate """Pandas prompt. Convert query to python code. Required template variables: `query_str`, `df_str`, `instruction_str`. """ -PandasPrompt = Prompt +PandasPrompt = PromptTemplate + + +"""Choice select prompt. Select from a list of choices. + +Required template variables: `context_str`, `query_str`. +""" +ChoiceSelectPrompt = PromptTemplate diff --git a/llama_index/prompts/utils.py b/llama_index/prompts/utils.py index 90d4a8ae9179ad278f55adacab8b50ec207968a6..b81c8c6496259935f203ecd6420810a11a8487cb 100644 --- a/llama_index/prompts/utils.py +++ b/llama_index/prompts/utils.py @@ -1,35 +1,20 @@ +from string import Formatter from typing import List -from llama_index.prompts.base import Prompt +from llama_index.llms.base import LLM -def get_empty_prompt_txt(prompt: Prompt) -> str: - """Get empty prompt text. +def get_template_vars(template_str: str) -> List[str]: + """Get template variables from a template string.""" + variables = [] + formatter = Formatter() - Substitute empty strings in parts of the prompt that have - not yet been filled out. Skip variables that have already - been partially formatted. This is used to compute the initial tokens. + for _, variable_name, _, _ in formatter.parse(template_str): + if variable_name: + variables.append(variable_name) - """ - fmt_dict = { - v: "" - for v in prompt.get_langchain_prompt().input_variables - if v not in prompt.partial_dict - } - # TODO: change later from llm=None - empty_prompt_txt = prompt.format(llm=None, **fmt_dict) - return empty_prompt_txt + return variables -def get_biggest_prompt(prompts: List[Prompt]) -> Prompt: - """Get biggest prompt. - - Oftentimes we need to fetch the biggest prompt, in order to - be the most conservative about chunking text. This - is a helper utility for that. - - """ - empty_prompt_txts = [get_empty_prompt_txt(prompt) for prompt in prompts] - empty_prompt_txt_lens = [len(txt) for txt in empty_prompt_txts] - biggest_prompt = prompts[empty_prompt_txt_lens.index(max(empty_prompt_txt_lens))] - return biggest_prompt +def is_chat_model(llm: LLM) -> bool: + return llm.metadata.is_chat_model diff --git a/llama_index/query_engine/citation_query_engine.py b/llama_index/query_engine/citation_query_engine.py index 7cc3e7a42742546ded71b3f2bc9b6c2c78eb2c19..508582ed26d6c183062c7b99f5f5b680c093f11b 100644 --- a/llama_index/query_engine/citation_query_engine.py +++ b/llama_index/query_engine/citation_query_engine.py @@ -7,7 +7,8 @@ from llama_index.indices.base_retriever import BaseRetriever from llama_index.indices.postprocessor.types import BaseNodePostprocessor from llama_index.indices.query.base import BaseQueryEngine from llama_index.indices.query.schema import QueryBundle -from llama_index.prompts.base import Prompt +from llama_index.prompts import PromptTemplate +from llama_index.prompts.base import BasePromptTemplate from llama_index.response.schema import RESPONSE_TYPE from llama_index.response_synthesizers import ( BaseSynthesizer, @@ -18,7 +19,7 @@ from llama_index.schema import NodeWithScore, TextNode from llama_index.text_splitter import get_default_text_splitter from llama_index.text_splitter.types import TextSplitter -CITATION_QA_TEMPLATE = Prompt( +CITATION_QA_TEMPLATE = PromptTemplate( "Please provide an answer based solely on the provided sources. " "When referencing information from a source, " "cite the appropriate source(s) using their corresponding numbers. " @@ -41,7 +42,7 @@ CITATION_QA_TEMPLATE = Prompt( "Answer: " ) -CITATION_REFINE_TEMPLATE = Prompt( +CITATION_REFINE_TEMPLATE = PromptTemplate( "Please provide an answer based solely on the provided sources. " "When referencing information from a source, " "cite the appropriate source(s) using their corresponding numbers. " @@ -120,8 +121,8 @@ class CitationQueryEngine(BaseQueryEngine): citation_chunk_size: int = DEFAULT_CITATION_CHUNK_SIZE, citation_chunk_overlap: int = DEFAULT_CITATION_CHUNK_OVERLAP, text_splitter: Optional[TextSplitter] = None, - citation_qa_template: Prompt = CITATION_QA_TEMPLATE, - citation_refine_template: Prompt = CITATION_REFINE_TEMPLATE, + citation_qa_template: BasePromptTemplate = CITATION_QA_TEMPLATE, + citation_refine_template: BasePromptTemplate = CITATION_REFINE_TEMPLATE, retriever: Optional[BaseRetriever] = None, node_postprocessors: Optional[List[BaseNodePostprocessor]] = None, # response synthesizer args @@ -142,8 +143,9 @@ class CitationQueryEngine(BaseQueryEngine): text_splitter (Optional[TextSplitter]): A text splitter for creating citation source nodes. Default is a SentenceSplitter. - citation_qa_template (Prompt): Template for initial citation QA - citation_refine_template (Prompt): Template for citation refinement. + citation_qa_template (BasePromptTemplate): Template for initial citation QA + citation_refine_template (BasePromptTemplate): + Template for citation refinement. retriever (BaseRetriever): A retriever object. service_context (Optional[ServiceContext]): A ServiceContext object. node_postprocessors (Optional[List[BaseNodePostprocessor]]): A list of diff --git a/llama_index/query_engine/flare/answer_inserter.py b/llama_index/query_engine/flare/answer_inserter.py index f0fb8ba032bb9d4d86214f408ff20947c099181a..abcc5fd6cad17ac1ea1042f6459c550bf1d6468d 100644 --- a/llama_index/query_engine/flare/answer_inserter.py +++ b/llama_index/query_engine/flare/answer_inserter.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from typing import List, Optional from llama_index.query_engine.flare.schema import QueryTask -from llama_index.prompts.base import Prompt +from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.indices.service_context import ServiceContext @@ -117,7 +117,7 @@ Query-Answer Pairs: Synthesized Response: """ -DEFAULT_ANSWER_INSERT_PROMPT = Prompt(DEFAULT_ANSWER_INSERT_PROMPT_TMPL) +DEFAULT_ANSWER_INSERT_PROMPT = PromptTemplate(DEFAULT_ANSWER_INSERT_PROMPT_TMPL) class LLMLookaheadAnswerInserter(BaseLookaheadAnswerInserter): @@ -134,7 +134,7 @@ class LLMLookaheadAnswerInserter(BaseLookaheadAnswerInserter): def __init__( self, service_context: Optional[ServiceContext] = None, - answer_insert_prompt: Optional[Prompt] = None, + answer_insert_prompt: Optional[BasePromptTemplate] = None, ) -> None: """Init params.""" self._service_context = service_context or ServiceContext.from_defaults() diff --git a/llama_index/query_engine/flare/base.py b/llama_index/query_engine/flare/base.py index a9b008c56202a131179eb78d97e0f56c28aee76e..9d214ef92734af2aaea42f41c196bf1343d9be88 100644 --- a/llama_index/query_engine/flare/base.py +++ b/llama_index/query_engine/flare/base.py @@ -10,7 +10,7 @@ from llama_index.indices.query.base import BaseQueryEngine from llama_index.indices.service_context import ServiceContext from llama_index.indices.query.schema import QueryBundle from llama_index.response.schema import RESPONSE_TYPE, Response -from llama_index.prompts.base import Prompt +from llama_index.prompts.base import PromptTemplate, BasePromptTemplate from llama_index.callbacks.base import CallbackManager from llama_index.query_engine.flare.output_parser import ( IsDoneOutputParser, @@ -87,7 +87,7 @@ Answer: """ ) ) -DEFAULT_INSTRUCT_PROMPT = Prompt(DEFAULT_INSTRUCT_PROMPT_TMPL) +DEFAULT_INSTRUCT_PROMPT = PromptTemplate(DEFAULT_INSTRUCT_PROMPT_TMPL) class FLAREInstructQueryEngine(BaseQueryEngine): @@ -102,7 +102,7 @@ class FLAREInstructQueryEngine(BaseQueryEngine): query_engine (BaseQueryEngine): query engine to use service_context (Optional[ServiceContext]): service context. Defaults to None. - instruct_prompt (Optional[Prompt]): instruct prompt. Defaults to None. + instruct_prompt (Optional[PromptTemplate]): instruct prompt. Defaults to None. lookahead_answer_inserter (Optional[BaseLookaheadAnswerInserter]): lookahead answer inserter. Defaults to None. done_output_parser (Optional[IsDoneOutputParser]): done output parser. @@ -121,7 +121,7 @@ class FLAREInstructQueryEngine(BaseQueryEngine): self, query_engine: BaseQueryEngine, service_context: Optional[ServiceContext] = None, - instruct_prompt: Optional[Prompt] = None, + instruct_prompt: Optional[BasePromptTemplate] = None, lookahead_answer_inserter: Optional[BaseLookaheadAnswerInserter] = None, done_output_parser: Optional[IsDoneOutputParser] = None, query_task_output_parser: Optional[QueryTaskOutputParser] = None, diff --git a/llama_index/query_engine/knowledge_graph_query_engine.py b/llama_index/query_engine/knowledge_graph_query_engine.py index f8bef189ff4220569177486421e14a88b9e42c4d..36a18f5f190fcfdf0e4489fe8c0dd1c2a7a8951c 100644 --- a/llama_index/query_engine/knowledge_graph_query_engine.py +++ b/llama_index/query_engine/knowledge_graph_query_engine.py @@ -12,7 +12,7 @@ from llama_index.graph_stores.registery import ( from llama_index.indices.query.base import BaseQueryEngine from llama_index.indices.query.schema import QueryBundle from llama_index.indices.service_context import ServiceContext -from llama_index.prompts.base import Prompt, PromptType +from llama_index.prompts.base import BasePromptTemplate, PromptTemplate, PromptType from llama_index.response.schema import RESPONSE_TYPE from llama_index.response_synthesizers import BaseSynthesizer, get_response_synthesizer from llama_index.schema import NodeWithScore, TextNode @@ -49,7 +49,7 @@ Question: {query_str} NebulaGraph Cypher dialect query: """ -DEFAULT_NEBULAGRAPH_NL2CYPHER_PROMPT = Prompt( +DEFAULT_NEBULAGRAPH_NL2CYPHER_PROMPT = PromptTemplate( DEFAULT_NEBULAGRAPH_NL2CYPHER_PROMPT_TMPL, prompt_type=PromptType.TEXT_TO_GRAPH_QUERY, ) @@ -71,7 +71,7 @@ DEFAULT_NEO4J_NL2CYPHER_PROMPT_TMPL = ( "{query_str}\n" ) -DEFAULT_NEO4J_NL2CYPHER_PROMPT = Prompt( +DEFAULT_NEO4J_NL2CYPHER_PROMPT = PromptTemplate( DEFAULT_NEO4J_NL2CYPHER_PROMPT_TMPL, prompt_type=PromptType.TEXT_TO_GRAPH_QUERY, ) @@ -93,7 +93,7 @@ Graph response: {kg_response_str} Response: """ -DEFAULT_KG_RESPONSE_ANSWER_PROMPT = Prompt( +DEFAULT_KG_RESPONSE_ANSWER_PROMPT = PromptTemplate( DEFAULT_KG_RESPONSE_ANSWER_PROMPT_TMPL, prompt_type=PromptType.QUESTION_ANSWER, ) @@ -119,8 +119,8 @@ class KnowledgeGraphQueryEngine(BaseQueryEngine): self, service_context: Optional[ServiceContext] = None, storage_context: Optional[StorageContext] = None, - graph_query_synthesis_prompt: Optional[Prompt] = None, - graph_response_answer_prompt: Optional[Prompt] = None, + graph_query_synthesis_prompt: Optional[BasePromptTemplate] = None, + graph_response_answer_prompt: Optional[BasePromptTemplate] = None, refresh_schema: bool = False, verbose: bool = False, response_synthesizer: Optional[BaseSynthesizer] = None, diff --git a/llama_index/query_engine/pandas_query_engine.py b/llama_index/query_engine/pandas_query_engine.py index 95749eec1f5e1a5dd7b3e487f095e1b9ebdc2b05..63533409dcfa0b33fb0b74d1dad3a86641db52ba 100644 --- a/llama_index/query_engine/pandas_query_engine.py +++ b/llama_index/query_engine/pandas_query_engine.py @@ -10,16 +10,16 @@ require heavy sandboxing or virtual machines import logging from typing import Any, Callable, Optional -import pandas as pd import numpy as np -from llama_index.bridge.langchain import print_text +import pandas as pd +from llama_index.bridge.langchain import print_text from llama_index.indices.query.base import BaseQueryEngine from llama_index.indices.query.schema import QueryBundle from llama_index.indices.service_context import ServiceContext from llama_index.indices.struct_store.pandas import PandasIndex +from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_PANDAS_PROMPT -from llama_index.prompts.prompts import PandasPrompt from llama_index.response.schema import Response logger = logging.getLogger(__name__) @@ -91,7 +91,7 @@ class PandasQueryEngine(BaseQueryEngine): output_processor (Optional[Callable[[str], str]]): Output processor. A callable that takes in the output string, pandas DataFrame, and any output kwargs and returns a string. - pandas_prompt (Optional[PandasPrompt]): Pandas prompt to use. + pandas_prompt (Optional[BasePromptTemplate]): Pandas prompt to use. head (int): Number of rows to show in the table context. """ @@ -101,7 +101,7 @@ class PandasQueryEngine(BaseQueryEngine): df: pd.DataFrame, instruction_str: Optional[str] = None, output_processor: Optional[Callable] = None, - pandas_prompt: Optional[PandasPrompt] = None, + pandas_prompt: Optional[BasePromptTemplate] = None, output_kwargs: Optional[dict] = None, head: int = 5, verbose: bool = False, diff --git a/llama_index/query_engine/retriever_query_engine.py b/llama_index/query_engine/retriever_query_engine.py index ca8dc0f86ddf45adef19619b452612e5ded0cbaa..7fcc0ada4a4c670854ccf8629b1275167e03ec88 100644 --- a/llama_index/query_engine/retriever_query_engine.py +++ b/llama_index/query_engine/retriever_query_engine.py @@ -7,17 +7,13 @@ from llama_index.indices.postprocessor.types import BaseNodePostprocessor from llama_index.indices.query.base import BaseQueryEngine from llama_index.indices.query.schema import QueryBundle from llama_index.indices.service_context import ServiceContext +from llama_index.prompts import BasePromptTemplate +from llama_index.response.schema import RESPONSE_TYPE from llama_index.response_synthesizers import ( BaseSynthesizer, ResponseMode, get_response_synthesizer, ) -from llama_index.prompts.prompts import ( - QuestionAnswerPrompt, - RefinePrompt, - SimpleInputPrompt, -) -from llama_index.response.schema import RESPONSE_TYPE from llama_index.schema import NodeWithScore @@ -55,9 +51,9 @@ class RetrieverQueryEngine(BaseQueryEngine): node_postprocessors: Optional[List[BaseNodePostprocessor]] = None, # response synthesizer args response_mode: ResponseMode = ResponseMode.COMPACT, - text_qa_template: Optional[QuestionAnswerPrompt] = None, - refine_template: Optional[RefinePrompt] = None, - simple_template: Optional[SimpleInputPrompt] = None, + text_qa_template: Optional[BasePromptTemplate] = None, + refine_template: Optional[BasePromptTemplate] = None, + simple_template: Optional[BasePromptTemplate] = None, use_async: bool = False, streaming: bool = False, # class-specific args @@ -72,10 +68,10 @@ class RetrieverQueryEngine(BaseQueryEngine): node postprocessors. verbose (bool): Whether to print out debug info. response_mode (ResponseMode): A ResponseMode object. - text_qa_template (Optional[QuestionAnswerPrompt]): A QuestionAnswerPrompt + text_qa_template (Optional[BasePromptTemplate]): A BasePromptTemplate object. - refine_template (Optional[RefinePrompt]): A RefinePrompt object. - simple_template (Optional[SimpleInputPrompt]): A SimpleInputPrompt object. + refine_template (Optional[BasePromptTemplate]): A BasePromptTemplate object. + simple_template (Optional[BasePromptTemplate]): A BasePromptTemplate object. use_async (bool): Whether to use async. streaming (bool): Whether to use streaming. diff --git a/llama_index/query_engine/sql_join_query_engine.py b/llama_index/query_engine/sql_join_query_engine.py index 69d8d2288377457f31f4a161a86d7ac280c6aad0..592ce061e347796b48465fea6572991f09603001 100644 --- a/llama_index/query_engine/sql_join_query_engine.py +++ b/llama_index/query_engine/sql_join_query_engine.py @@ -1,36 +1,40 @@ """SQL Join query engine.""" +import logging +from typing import Callable, Dict, Optional, Union + from llama_index.bridge.langchain import print_text -from typing import Optional, Dict, Callable, Union +from llama_index.callbacks.base import CallbackManager from llama_index.indices.query.base import BaseQueryEngine +from llama_index.indices.query.query_transform.base import BaseQueryTransform +from llama_index.indices.query.schema import QueryBundle +from llama_index.indices.service_context import ServiceContext from llama_index.indices.struct_store.sql_query import ( BaseSQLTableQueryEngine, NLSQLTableQueryEngine, ) -from llama_index.indices.query.schema import QueryBundle +from llama_index.llm_predictor import LLMPredictor +from llama_index.llm_predictor.base import BaseLLMPredictor +from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.response.schema import RESPONSE_TYPE, Response -from llama_index.tools.query_engine import QueryEngineTool -from llama_index.indices.service_context import ServiceContext -from llama_index.selectors.utils import get_selector_from_context from llama_index.selectors.llm_selectors import LLMSingleSelector from llama_index.selectors.pydantic_selectors import PydanticSingleSelector -from llama_index.prompts.base import Prompt -from llama_index.indices.query.query_transform.base import BaseQueryTransform -import logging -from llama_index.llm_predictor import LLMPredictor -from llama_index.llm_predictor.base import BaseLLMPredictor -from llama_index.callbacks.base import CallbackManager +from llama_index.selectors.utils import get_selector_from_context +from llama_index.tools.query_engine import QueryEngineTool logger = logging.getLogger(__name__) DEFAULT_SQL_JOIN_SYNTHESIS_PROMPT_TMPL = """ The original question is given below. -This question has been translated into a SQL query. Both the SQL query and the response are given below. -Given the SQL response, the question has also been transformed into a more detailed query, +This question has been translated into a SQL query. Both the SQL query and \ +the response are given below. +Given the SQL response, the question has also been transformed into a more \ +detailed query, and executed against another query engine. The transformed query and query engine response are also given below. -Given SQL query, SQL response, transformed query, and query engine response, please synthesize a response to the original question. +Given SQL query, SQL response, transformed query, and query engine response, \ +please synthesize a response to the original question. Original question: {query_str} SQL query: {sql_query_str} @@ -39,18 +43,26 @@ Transformed query: {query_engine_query_str} Query engine response: {query_engine_response_str} Response: """ # noqa -DEFAULT_SQL_JOIN_SYNTHESIS_PROMPT = Prompt(DEFAULT_SQL_JOIN_SYNTHESIS_PROMPT_TMPL) +DEFAULT_SQL_JOIN_SYNTHESIS_PROMPT = PromptTemplate( + DEFAULT_SQL_JOIN_SYNTHESIS_PROMPT_TMPL +) DEFAULT_SQL_AUGMENT_TRANSFORM_PROMPT_TMPL = """ "The original question is given below. -This question has been translated into a SQL query. Both the SQL query and the response are given below. -The SQL response either answers the question, or should provide additional context that can be used to make the question more specific. -Your job is to come up with a more specific question that needs to be answered to fully answer the original question, or 'None' if the original question has already been fully answered from the SQL response. Do not create a new question that is irrelevant to the original question; in that case return None instead. +This question has been translated into a SQL query. Both the SQL query and the \ +response are given below. +The SQL response either answers the question, or should provide additional context \ +that can be used to make the question more specific. +Your job is to come up with a more specific question that needs to be answered to \ +fully answer the original question, or 'None' if the original question has already \ +been fully answered from the SQL response. Do not create a new question that is \ +irrelevant to the original question; in that case return None instead. Examples: -Original question: Please give more details about the demographics of the city with the highest population. +Original question: Please give more details about the demographics of the city with \ +the highest population. SQL query: SELECT city, population FROM cities ORDER BY population DESC LIMIT 1 SQL response: The city with the highest population is New York City. New question: Can you tell me more about the demographics of New York City? @@ -75,7 +87,9 @@ SQL query: {sql_query_str} SQL response: {sql_response_str} New question: " """ # noqa -DEFAULT_SQL_AUGMENT_TRANSFORM_PROMPT = Prompt(DEFAULT_SQL_AUGMENT_TRANSFORM_PROMPT_TMPL) +DEFAULT_SQL_AUGMENT_TRANSFORM_PROMPT = PromptTemplate( + DEFAULT_SQL_AUGMENT_TRANSFORM_PROMPT_TMPL +) def _default_check_stop(query_bundle: QueryBundle) -> bool: @@ -96,7 +110,8 @@ class SQLAugmentQueryTransform(BaseQueryTransform): Args: llm_predictor (LLMPredictor): LLM predictor to use for query transformation. - sql_augment_transform_prompt (Prompt): Prompt to use for query transformation. + sql_augment_transform_prompt (BasePromptTemplate): PromptTemplate to use + for query transformation. check_stop_parser (Optional[Callable[[str], bool]]): Check stop function. """ @@ -104,7 +119,7 @@ class SQLAugmentQueryTransform(BaseQueryTransform): def __init__( self, llm_predictor: Optional[BaseLLMPredictor] = None, - sql_augment_transform_prompt: Optional[Prompt] = None, + sql_augment_transform_prompt: Optional[BasePromptTemplate] = None, check_stop_parser: Optional[Callable[[QueryBundle], bool]] = None, ) -> None: """Initialize params.""" @@ -150,8 +165,8 @@ class SQLJoinQueryEngine(BaseQueryEngine): selector (Optional[Union[LLMSingleSelector, PydanticSingleSelector]]): Selector to use. service_context (Optional[ServiceContext]): Service context to use. - sql_join_synthesis_prompt (Optional[Prompt]): Prompt to use for SQL join - synthesis. + sql_join_synthesis_prompt (Optional[BasePromptTemplate]): + PromptTemplate to use for SQL join synthesis. sql_augment_query_transform (Optional[SQLAugmentQueryTransform]): Query transform to use for SQL augmentation. use_sql_join_synthesis (bool): Whether to use SQL join synthesis. @@ -166,7 +181,7 @@ class SQLJoinQueryEngine(BaseQueryEngine): other_query_tool: QueryEngineTool, selector: Optional[Union[LLMSingleSelector, PydanticSingleSelector]] = None, service_context: Optional[ServiceContext] = None, - sql_join_synthesis_prompt: Optional[Prompt] = None, + sql_join_synthesis_prompt: Optional[BasePromptTemplate] = None, sql_augment_query_transform: Optional[SQLAugmentQueryTransform] = None, use_sql_join_synthesis: bool = True, callback_manager: Optional[CallbackManager] = None, diff --git a/llama_index/query_engine/sql_vector_query_engine.py b/llama_index/query_engine/sql_vector_query_engine.py index 91120375b20a218923e44837ef7fdbd26581f2b3..1fd7c81bff564803a014e474d235c9643a9c08de 100644 --- a/llama_index/query_engine/sql_vector_query_engine.py +++ b/llama_index/query_engine/sql_vector_query_engine.py @@ -1,6 +1,10 @@ """SQL Vector query engine.""" -from typing import Optional, Any, Union +import logging +from typing import Any, Optional, Union + +from llama_index.callbacks.base import CallbackManager +from llama_index.indices.service_context import ServiceContext from llama_index.indices.struct_store.sql_query import ( BaseSQLTableQueryEngine, NLSQLTableQueryEngine, @@ -8,28 +12,27 @@ from llama_index.indices.struct_store.sql_query import ( from llama_index.indices.vector_store.retrievers.auto_retriever import ( VectorIndexAutoRetriever, ) -from llama_index.tools.query_engine import QueryEngineTool +from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine -from llama_index.indices.service_context import ServiceContext -from llama_index.selectors.llm_selectors import LLMSingleSelector -from llama_index.selectors.pydantic_selectors import PydanticSingleSelector -from llama_index.prompts.base import Prompt -import logging -from llama_index.callbacks.base import CallbackManager from llama_index.query_engine.sql_join_query_engine import ( - SQLJoinQueryEngine, SQLAugmentQueryTransform, + SQLJoinQueryEngine, ) +from llama_index.selectors.llm_selectors import LLMSingleSelector +from llama_index.selectors.pydantic_selectors import PydanticSingleSelector +from llama_index.tools.query_engine import QueryEngineTool logger = logging.getLogger(__name__) DEFAULT_SQL_VECTOR_SYNTHESIS_PROMPT_TMPL = """ The original question is given below. -This question has been translated into a SQL query. Both the SQL query and the response are given below. +This question has been translated into a SQL query. \ +Both the SQL query and the response are given below. Given the SQL response, the question has also been translated into a vector store query. The vector store query and response is given below. -Given SQL query, SQL response, transformed vector store query, and vector store response, please synthesize a response to the original question. +Given SQL query, SQL response, transformed vector store query, and vector store \ +response, please synthesize a response to the original question. Original question: {query_str} SQL query: {sql_query_str} @@ -38,7 +41,9 @@ Transformed vector store query: {query_engine_query_str} Vector store response: {query_engine_response_str} Response: """ # noqa -DEFAULT_SQL_VECTOR_SYNTHESIS_PROMPT = Prompt(DEFAULT_SQL_VECTOR_SYNTHESIS_PROMPT_TMPL) +DEFAULT_SQL_VECTOR_SYNTHESIS_PROMPT = PromptTemplate( + DEFAULT_SQL_VECTOR_SYNTHESIS_PROMPT_TMPL +) # NOTE: maintain for backwards compatibility @@ -58,8 +63,8 @@ class SQLAutoVectorQueryEngine(SQLJoinQueryEngine): selector (Optional[Union[LLMSingleSelector, PydanticSingleSelector]]): Selector to use. service_context (Optional[ServiceContext]): Service context to use. - sql_vector_synthesis_prompt (Optional[Prompt]): Prompt to use for SQL vector - synthesis. + sql_vector_synthesis_prompt (Optional[BasePromptTemplate]): + Prompt to use for SQL vector synthesis. sql_augment_query_transform (Optional[SQLAugmentQueryTransform]): Query transform to use for SQL augmentation. use_sql_vector_synthesis (bool): Whether to use SQL vector synthesis. @@ -74,7 +79,7 @@ class SQLAutoVectorQueryEngine(SQLJoinQueryEngine): vector_query_tool: QueryEngineTool, selector: Optional[Union[LLMSingleSelector, PydanticSingleSelector]] = None, service_context: Optional[ServiceContext] = None, - sql_vector_synthesis_prompt: Optional[Prompt] = None, + sql_vector_synthesis_prompt: Optional[BasePromptTemplate] = None, sql_augment_query_transform: Optional[SQLAugmentQueryTransform] = None, use_sql_vector_synthesis: bool = True, callback_manager: Optional[CallbackManager] = None, diff --git a/llama_index/question_gen/llm_generators.py b/llama_index/question_gen/llm_generators.py index f61e027db98463e7753b65a4fae2709c1e38cd2f..ce8836732bfde2fab72f1a4a736954cc9234d53f 100644 --- a/llama_index/question_gen/llm_generators.py +++ b/llama_index/question_gen/llm_generators.py @@ -5,7 +5,7 @@ from llama_index.indices.service_context import ServiceContext from llama_index.llm_predictor.base import BaseLLMPredictor from llama_index.output_parsers.base import StructuredOutput from llama_index.types import BaseOutputParser -from llama_index.prompts.base import Prompt +from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.prompts.prompt_type import PromptType from llama_index.question_gen.output_parser import SubQuestionOutputParser from llama_index.question_gen.prompts import ( @@ -20,7 +20,7 @@ class LLMQuestionGenerator(BaseQuestionGenerator): def __init__( self, llm_predictor: BaseLLMPredictor, - prompt: Prompt, + prompt: BasePromptTemplate, ) -> None: self._llm_predictor = llm_predictor self._prompt = prompt @@ -41,7 +41,7 @@ class LLMQuestionGenerator(BaseQuestionGenerator): output_parser = output_parser or SubQuestionOutputParser() # construct prompt - prompt = Prompt( + prompt = PromptTemplate( template=prompt_template_str, output_parser=output_parser, prompt_type=PromptType.SUB_QUESTION, diff --git a/llama_index/question_gen/prompts.py b/llama_index/question_gen/prompts.py index daa7738fa6e6e02b315b8445e508ab66116f9dc4..aed8f382af513ba031530546313d07d032649ae2 100644 --- a/llama_index/question_gen/prompts.py +++ b/llama_index/question_gen/prompts.py @@ -1,12 +1,12 @@ import json from typing import Sequence -from llama_index.prompts.base import Prompt +from llama_index.prompts.base import PromptTemplate from llama_index.question_gen.types import SubQuestion from llama_index.tools.types import ToolMetadata # deprecated, kept for backward compatibility -SubQuestionPrompt = Prompt +SubQuestionPrompt = PromptTemplate def build_tools_text(tools: Sequence[ToolMetadata]) -> str: diff --git a/llama_index/response_synthesizers/accumulate.py b/llama_index/response_synthesizers/accumulate.py index ebce6f9c2187d58b8156705baa0b4077a8df0296..5f00979e0146dc2d6af4bbd22844d53f681d3d2c 100644 --- a/llama_index/response_synthesizers/accumulate.py +++ b/llama_index/response_synthesizers/accumulate.py @@ -1,12 +1,10 @@ import asyncio -from typing import Any, List, Sequence, Optional +from typing import Any, List, Optional, Sequence from llama_index.async_utils import run_async_tasks from llama_index.indices.service_context import ServiceContext -from llama_index.prompts.default_prompts import ( - DEFAULT_TEXT_QA_PROMPT, -) -from llama_index.prompts.prompts import QuestionAnswerPrompt +from llama_index.prompts import BasePromptTemplate +from llama_index.prompts.default_prompts import DEFAULT_TEXT_QA_PROMPT from llama_index.response_synthesizers.base import BaseSynthesizer from llama_index.types import RESPONSE_TEXT_TYPE @@ -16,7 +14,7 @@ class Accumulate(BaseSynthesizer): def __init__( self, - text_qa_template: Optional[QuestionAnswerPrompt] = None, + text_qa_template: Optional[BasePromptTemplate] = None, service_context: Optional[ServiceContext] = None, streaming: bool = False, use_async: bool = False, diff --git a/llama_index/response_synthesizers/compact_and_refine.py b/llama_index/response_synthesizers/compact_and_refine.py index c99617f9306935fb356c5e72ae5e1adb2ec3b6e5..ee49ff6d7f41f2a758aa4f0e942e1e46061d7d4c 100644 --- a/llama_index/response_synthesizers/compact_and_refine.py +++ b/llama_index/response_synthesizers/compact_and_refine.py @@ -1,6 +1,6 @@ from typing import Any, List, Sequence -from llama_index.prompts.utils import get_biggest_prompt +from llama_index.prompts.prompt_utils import get_biggest_prompt from llama_index.response_synthesizers.refine import Refine from llama_index.types import RESPONSE_TEXT_TYPE diff --git a/llama_index/response_synthesizers/factory.py b/llama_index/response_synthesizers/factory.py index 2ef1f015f6cefeac4e56f30514428ed68cd2871a..0405da137f11d01a487a1e2451eb41f90cfc7b87 100644 --- a/llama_index/response_synthesizers/factory.py +++ b/llama_index/response_synthesizers/factory.py @@ -2,18 +2,13 @@ from typing import Optional from llama_index.callbacks.base import CallbackManager from llama_index.indices.service_context import ServiceContext +from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompt_selectors import ( + DEFAULT_REFINE_PROMPT_SEL, DEFAULT_TEXT_QA_PROMPT_SEL, DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, - DEFAULT_REFINE_PROMPT_SEL, ) from llama_index.prompts.default_prompts import DEFAULT_SIMPLE_INPUT_PROMPT -from llama_index.prompts.prompts import ( - QuestionAnswerPrompt, - RefinePrompt, - SimpleInputPrompt, - SummaryPrompt, -) from llama_index.response_synthesizers.accumulate import Accumulate from llama_index.response_synthesizers.base import BaseSynthesizer from llama_index.response_synthesizers.compact_and_accumulate import ( @@ -21,19 +16,19 @@ from llama_index.response_synthesizers.compact_and_accumulate import ( ) from llama_index.response_synthesizers.compact_and_refine import CompactAndRefine from llama_index.response_synthesizers.generation import Generation +from llama_index.response_synthesizers.no_text import NoText from llama_index.response_synthesizers.refine import Refine from llama_index.response_synthesizers.simple_summarize import SimpleSummarize from llama_index.response_synthesizers.tree_summarize import TreeSummarize from llama_index.response_synthesizers.type import ResponseMode -from llama_index.response_synthesizers.no_text import NoText def get_response_synthesizer( service_context: Optional[ServiceContext] = None, - text_qa_template: Optional[QuestionAnswerPrompt] = None, - refine_template: Optional[RefinePrompt] = None, - summary_template: Optional[SummaryPrompt] = None, - simple_template: Optional[SimpleInputPrompt] = None, + text_qa_template: Optional[BasePromptTemplate] = None, + refine_template: Optional[BasePromptTemplate] = None, + summary_template: Optional[BasePromptTemplate] = None, + simple_template: Optional[BasePromptTemplate] = None, response_mode: ResponseMode = ResponseMode.COMPACT, callback_manager: Optional[CallbackManager] = None, use_async: bool = False, diff --git a/llama_index/response_synthesizers/generation.py b/llama_index/response_synthesizers/generation.py index 3524e85d22c70bcecc90b23e65c226686609fd04..6bd4617034632c00015b733b5846a61c7b55d307 100644 --- a/llama_index/response_synthesizers/generation.py +++ b/llama_index/response_synthesizers/generation.py @@ -1,8 +1,8 @@ from typing import Any, Optional, Sequence from llama_index.indices.service_context import ServiceContext +from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_SIMPLE_INPUT_PROMPT -from llama_index.prompts.prompts import SimpleInputPrompt from llama_index.response_synthesizers.base import BaseSynthesizer from llama_index.types import RESPONSE_TEXT_TYPE @@ -10,7 +10,7 @@ from llama_index.types import RESPONSE_TEXT_TYPE class Generation(BaseSynthesizer): def __init__( self, - simple_template: Optional[SimpleInputPrompt] = None, + simple_template: Optional[BasePromptTemplate] = None, service_context: Optional[ServiceContext] = None, streaming: bool = False, ) -> None: diff --git a/llama_index/response_synthesizers/refine.py b/llama_index/response_synthesizers/refine.py index 0a9ea51b5c3112fe9d935dc2258e7e999d516199..a686908aab3504f724fc2ee6a73509b77016f1ff 100644 --- a/llama_index/response_synthesizers/refine.py +++ b/llama_index/response_synthesizers/refine.py @@ -3,11 +3,11 @@ from typing import Any, Generator, Optional, Sequence, cast from llama_index.indices.service_context import ServiceContext from llama_index.indices.utils import truncate_text +from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompt_selectors import ( - DEFAULT_TEXT_QA_PROMPT_SEL, DEFAULT_REFINE_PROMPT_SEL, + DEFAULT_TEXT_QA_PROMPT_SEL, ) -from llama_index.prompts.prompts import QuestionAnswerPrompt, RefinePrompt from llama_index.response.utils import get_response_text from llama_index.response_synthesizers.base import BaseSynthesizer from llama_index.types import RESPONSE_TEXT_TYPE @@ -21,8 +21,8 @@ class Refine(BaseSynthesizer): def __init__( self, service_context: Optional[ServiceContext] = None, - text_qa_template: Optional[QuestionAnswerPrompt] = None, - refine_template: Optional[RefinePrompt] = None, + text_qa_template: Optional[BasePromptTemplate] = None, + refine_template: Optional[BasePromptTemplate] = None, streaming: bool = False, verbose: bool = False, ) -> None: diff --git a/llama_index/response_synthesizers/simple_summarize.py b/llama_index/response_synthesizers/simple_summarize.py index 47352a554ca9a34287b48c0120438db11efa99b4..3a983d5b385af7e41ae4d01cc27fbc3bc8e25fd3 100644 --- a/llama_index/response_synthesizers/simple_summarize.py +++ b/llama_index/response_synthesizers/simple_summarize.py @@ -1,8 +1,8 @@ from typing import Any, Generator, Optional, Sequence, cast from llama_index.indices.service_context import ServiceContext +from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompt_selectors import DEFAULT_TEXT_QA_PROMPT_SEL -from llama_index.prompts.prompts import QuestionAnswerPrompt from llama_index.response_synthesizers.base import BaseSynthesizer from llama_index.types import RESPONSE_TEXT_TYPE @@ -10,7 +10,7 @@ from llama_index.types import RESPONSE_TEXT_TYPE class SimpleSummarize(BaseSynthesizer): def __init__( self, - text_qa_template: Optional[QuestionAnswerPrompt] = None, + text_qa_template: Optional[BasePromptTemplate] = None, service_context: Optional[ServiceContext] = None, streaming: bool = False, ) -> None: diff --git a/llama_index/response_synthesizers/tree_summarize.py b/llama_index/response_synthesizers/tree_summarize.py index 3b037317f7f805eabb369f84aeed2c5755a23dcd..c45fd7b501c448eeaa849933790ad622671e9e20 100644 --- a/llama_index/response_synthesizers/tree_summarize.py +++ b/llama_index/response_synthesizers/tree_summarize.py @@ -3,10 +3,10 @@ from typing import Any, List, Optional, Sequence from llama_index.async_utils import run_async_tasks from llama_index.indices.service_context import ServiceContext +from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompt_selectors import ( DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, ) -from llama_index.prompts.prompts import SummaryPrompt from llama_index.response_synthesizers.base import BaseSynthesizer from llama_index.types import RESPONSE_TEXT_TYPE @@ -26,7 +26,7 @@ class TreeSummarize(BaseSynthesizer): def __init__( self, - summary_template: Optional[SummaryPrompt] = None, + summary_template: Optional[BasePromptTemplate] = None, service_context: Optional[ServiceContext] = None, streaming: bool = False, use_async: bool = False, diff --git a/llama_index/selectors/prompts.py b/llama_index/selectors/prompts.py index 88dd0186a57c846cdb2afbe76b9619d8d58fd93c..b6331babb0101685b6ffb6d54853c1dcebc22439 100644 --- a/llama_index/selectors/prompts.py +++ b/llama_index/selectors/prompts.py @@ -1,25 +1,25 @@ -from llama_index.prompts.base import Prompt +from llama_index.prompts.base import PromptTemplate from llama_index.prompts.prompt_type import PromptType """Single select prompt. -Prompt to select one out of `num_choices` options provided in `context_list`, +PromptTemplate to select one out of `num_choices` options provided in `context_list`, given a query `query_str`. Required template variables: `num_chunks`, `context_list`, `query_str` """ -SingleSelectPrompt = Prompt +SingleSelectPrompt = PromptTemplate """Multiple select prompt. -Prompt to select multiple candidates (up to `max_outputs`) out of `num_choices` +PromptTemplate to select multiple candidates (up to `max_outputs`) out of `num_choices` options provided in `context_list`, given a query `query_str`. Required template variables: `num_chunks`, `context_list`, `query_str`, `max_outputs` """ -MultiSelectPrompt = Prompt +MultiSelectPrompt = PromptTemplate # single select @@ -35,7 +35,7 @@ DEFAULT_SINGLE_SELECT_PROMPT_TMPL = ( ) -DEFAULT_SINGLE_SELECT_PROMPT = Prompt( +DEFAULT_SINGLE_SELECT_PROMPT = PromptTemplate( template=DEFAULT_SINGLE_SELECT_PROMPT_TMPL, prompt_type=PromptType.SINGLE_SELECT ) @@ -54,7 +54,7 @@ DEFAULT_MULTI_SELECT_PROMPT_TMPL = ( ) -DEFAULT_MULTIPLE_SELECT_PROMPT = Prompt( +DEFAULT_MULTIPLE_SELECT_PROMPT = PromptTemplate( template=DEFAULT_MULTI_SELECT_PROMPT_TMPL, prompt_type=PromptType.MULTI_SELECT ) diff --git a/llama_index/types.py b/llama_index/types.py index b4cf82bb75859e2e649285bc9fa2696e03108c02..2d6d2ba72055f3320fe6aad704468610cf9a238a 100644 --- a/llama_index/types.py +++ b/llama_index/types.py @@ -1,5 +1,14 @@ -from typing import AsyncGenerator, Generator, Union, Protocol, Any, TypeVar from abc import abstractmethod +from typing import ( + Any, + AsyncGenerator, + Generator, + Protocol, + TypeVar, + Union, + runtime_checkable, +) + from pydantic import BaseModel Model = TypeVar("Model", bound=BaseModel) @@ -10,6 +19,8 @@ RESPONSE_TEXT_TYPE = Union[str, TokenGen] # TODO: move into a `core` folder +# NOTE: this is necessary to make it compatible with pydantic +@runtime_checkable class BaseOutputParser(Protocol): """Output parser class.""" diff --git a/tests/indices/list/test_retrievers.py b/tests/indices/list/test_retrievers.py index f65378fae6ee2a405585f4e659ecedf5ea3dfcf4..c830dbc4a1fbe4d1e6cca866ccc6eafeda0f433d 100644 --- a/tests/indices/list/test_retrievers.py +++ b/tests/indices/list/test_retrievers.py @@ -5,8 +5,7 @@ from llama_index.indices.list.base import ListIndex from llama_index.indices.list.retrievers import ListIndexEmbeddingRetriever from llama_index.indices.service_context import ServiceContext from llama_index.llm_predictor.base import LLMPredictor -from llama_index.prompts.choice_select import ChoiceSelectPrompt -from llama_index.prompts.prompts import Prompt +from llama_index.prompts import BasePromptTemplate from llama_index.schema import Document from tests.indices.list.test_index import _get_embeddings @@ -47,9 +46,10 @@ def test_embedding_query( assert nodes[0].node.get_content() == "Hello world." -def mock_llmpredictor_predict(self: Any, prompt: Prompt, **prompt_args: Any) -> str: +def mock_llmpredictor_predict( + self: Any, prompt: BasePromptTemplate, **prompt_args: Any +) -> str: """Patch llm predictor predict.""" - assert isinstance(prompt, ChoiceSelectPrompt) return "Doc: 2, Relevance: 5" diff --git a/tests/indices/postprocessor/test_llm_rerank.py b/tests/indices/postprocessor/test_llm_rerank.py index 3e06fcfc493c7124f9120e4b6bbe6942f3db2373..8a26cb86d29d376b571373d3c0f85aa7949994dd 100644 --- a/tests/indices/postprocessor/test_llm_rerank.py +++ b/tests/indices/postprocessor/test_llm_rerank.py @@ -1,19 +1,20 @@ """Test LLM reranker.""" -from llama_index.indices.query.schema import QueryBundle -from llama_index.prompts.prompts import Prompt -from llama_index.llm_predictor import LLMPredictor +from typing import Any, List from unittest.mock import patch -from typing import List, Any -from llama_index.prompts.prompts import QuestionAnswerPrompt + from llama_index.indices.postprocessor.llm_rerank import LLMRerank +from llama_index.indices.query.schema import QueryBundle from llama_index.indices.service_context import ServiceContext -from llama_index.schema import BaseNode, TextNode, NodeWithScore +from llama_index.llm_predictor import LLMPredictor +from llama_index.prompts import BasePromptTemplate +from llama_index.schema import BaseNode, NodeWithScore, TextNode -def mock_llmpredictor_predict(self: Any, prompt: Prompt, **prompt_args: Any) -> str: +def mock_llmpredictor_predict( + self: Any, prompt: BasePromptTemplate, **prompt_args: Any +) -> str: """Patch llm predictor predict.""" - assert isinstance(prompt, QuestionAnswerPrompt) context_str = prompt_args["context_str"] node_strs = context_str.split("\n") node_to_choice_and_score = { diff --git a/tests/indices/response/test_response_builder.py b/tests/indices/response/test_response_builder.py index bb18561dcc63656a5d7083bbf61d2cdd7c36045a..7f5223a5f3dc2f8e7d1121a20150ee61565ca0ec 100644 --- a/tests/indices/response/test_response_builder.py +++ b/tests/indices/response/test_response_builder.py @@ -6,7 +6,7 @@ from typing import List from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS from llama_index.indices.prompt_helper import PromptHelper from llama_index.indices.service_context import ServiceContext -from llama_index.prompts.base import Prompt +from llama_index.prompts.base import PromptTemplate from llama_index.prompts.prompt_type import PromptType from llama_index.response_synthesizers import ResponseMode, get_response_synthesizer from llama_index.schema import Document @@ -58,10 +58,14 @@ def test_compact_response(mock_service_context: ServiceContext) -> None: # test response with ResponseMode.COMPACT # NOTE: here we want to guarante that prompts have 0 extra tokens mock_refine_prompt_tmpl = "{query_str}{existing_answer}{context_msg}" - mock_refine_prompt = Prompt(mock_refine_prompt_tmpl, prompt_type=PromptType.REFINE) + mock_refine_prompt = PromptTemplate( + mock_refine_prompt_tmpl, prompt_type=PromptType.REFINE + ) mock_qa_prompt_tmpl = "{context_str}{query_str}" - mock_qa_prompt = Prompt(mock_qa_prompt_tmpl, prompt_type=PromptType.QUESTION_ANSWER) + mock_qa_prompt = PromptTemplate( + mock_qa_prompt_tmpl, prompt_type=PromptType.QUESTION_ANSWER + ) # max input size is 11, prompt is two tokens (the query) --> 9 tokens # --> padding is 1 --> 8 tokens @@ -106,7 +110,9 @@ def test_accumulate_response( # test response with ResponseMode.ACCUMULATE # NOTE: here we want to guarante that prompts have 0 extra tokens mock_qa_prompt_tmpl = "{context_str}{query_str}" - mock_qa_prompt = Prompt(mock_qa_prompt_tmpl, prompt_type=PromptType.QUESTION_ANSWER) + mock_qa_prompt = PromptTemplate( + mock_qa_prompt_tmpl, prompt_type=PromptType.QUESTION_ANSWER + ) # max input size is 11, prompt is two tokens (the query) --> 9 tokens # --> padding is 1 --> 8 tokens @@ -163,7 +169,9 @@ def test_accumulate_response_async( # test response with ResponseMode.ACCUMULATE # NOTE: here we want to guarante that prompts have 0 extra tokens mock_qa_prompt_tmpl = "{context_str}{query_str}" - mock_qa_prompt = Prompt(mock_qa_prompt_tmpl, prompt_type=PromptType.QUESTION_ANSWER) + mock_qa_prompt = PromptTemplate( + mock_qa_prompt_tmpl, prompt_type=PromptType.QUESTION_ANSWER + ) # max input size is 11, prompt is two tokens (the query) --> 9 tokens # --> padding is 1 --> 8 tokens @@ -221,7 +229,9 @@ def test_accumulate_response_aget( # test response with ResponseMode.ACCUMULATE # NOTE: here we want to guarante that prompts have 0 extra tokens mock_qa_prompt_tmpl = "{context_str}{query_str}" - mock_qa_prompt = Prompt(mock_qa_prompt_tmpl, prompt_type=PromptType.QUESTION_ANSWER) + mock_qa_prompt = PromptTemplate( + mock_qa_prompt_tmpl, prompt_type=PromptType.QUESTION_ANSWER + ) # max input size is 11, prompt is two tokens (the query) --> 9 tokens # --> padding is 1 --> 8 tokens @@ -281,7 +291,9 @@ def test_accumulate_compact_response(patch_llm_predictor: None) -> None: # test response with ResponseMode.ACCUMULATE # NOTE: here we want to guarante that prompts have 0 extra tokens mock_qa_prompt_tmpl = "{context_str}{query_str}" - mock_qa_prompt = Prompt(mock_qa_prompt_tmpl, prompt_type=PromptType.QUESTION_ANSWER) + mock_qa_prompt = PromptTemplate( + mock_qa_prompt_tmpl, prompt_type=PromptType.QUESTION_ANSWER + ) # max input size is 11, prompt is two tokens (the query) --> 9 tokens # --> padding is 1 --> 8 tokens diff --git a/tests/indices/response/test_tree_summarize.py b/tests/indices/response/test_tree_summarize.py index 5f451dcd79ea9ebaba7afab9554c06a89bf9781a..017f4e4d72f649fd58470e8aa1c3636fbee342d4 100644 --- a/tests/indices/response/test_tree_summarize.py +++ b/tests/indices/response/test_tree_summarize.py @@ -8,7 +8,7 @@ import pytest from llama_index.indices.prompt_helper import PromptHelper from llama_index.response_synthesizers import TreeSummarize from llama_index.indices.service_context import ServiceContext -from llama_index.prompts.base import Prompt +from llama_index.prompts.base import PromptTemplate from llama_index.prompts.prompt_type import PromptType @@ -16,7 +16,9 @@ from llama_index.prompts.prompt_type import PromptType def mock_service_context_merge_chunks( mock_service_context: ServiceContext, ) -> ServiceContext: - def mock_repack(prompt_template: Prompt, text_chunks: Sequence[str]) -> List[str]: + def mock_repack( + prompt_template: PromptTemplate, text_chunks: Sequence[str] + ) -> List[str]: merged_chunks = [] for chunks in zip(*[iter(text_chunks)] * 2): merged_chunks.append("\n".join(chunks)) @@ -30,7 +32,7 @@ def mock_service_context_merge_chunks( def test_tree_summarize(mock_service_context_merge_chunks: ServiceContext) -> None: mock_summary_prompt_tmpl = "{context_str}{query_str}" - mock_summary_prompt = Prompt( + mock_summary_prompt = PromptTemplate( mock_summary_prompt_tmpl, prompt_type=PromptType.SUMMARY ) @@ -55,7 +57,7 @@ def test_tree_summarize_use_async( mock_service_context_merge_chunks: ServiceContext, ) -> None: mock_summary_prompt_tmpl = "{context_str}{query_str}" - mock_summary_prompt = Prompt( + mock_summary_prompt = PromptTemplate( mock_summary_prompt_tmpl, prompt_type=PromptType.SUMMARY ) @@ -82,7 +84,7 @@ async def test_tree_summarize_async( mock_service_context_merge_chunks: ServiceContext, ) -> None: mock_summary_prompt_tmpl = "{context_str}{query_str}" - mock_summary_prompt = Prompt( + mock_summary_prompt = PromptTemplate( mock_summary_prompt_tmpl, prompt_type=PromptType.SUMMARY ) diff --git a/tests/indices/test_prompt_helper.py b/tests/indices/test_prompt_helper.py index 2946ee2fb03e0ce354867076d3b37ba17daffae2..6ec3da14107af3b83f008c67b7d09f9c4eca94d2 100644 --- a/tests/indices/test_prompt_helper.py +++ b/tests/indices/test_prompt_helper.py @@ -1,11 +1,9 @@ """Test PromptHelper.""" -from typing import cast -from llama_index.bridge.langchain import PromptTemplate as LangchainPrompt from llama_index.indices.prompt_helper import PromptHelper from llama_index.indices.tree.utils import get_numbered_text_from_nodes -from llama_index.prompts.base import Prompt -from llama_index.prompts.utils import get_biggest_prompt, get_empty_prompt_txt +from llama_index.prompts.base import PromptTemplate +from llama_index.prompts.prompt_utils import get_biggest_prompt, get_empty_prompt_txt from llama_index.schema import TextNode from llama_index.text_splitter.utils import truncate_text from tests.mock_utils.mock_utils import mock_tokenizer @@ -14,7 +12,7 @@ from tests.mock_utils.mock_utils import mock_tokenizer def test_get_chunk_size() -> None: """Test get chunk size given prompt.""" # test with 1 chunk - prompt = Prompt("This is the prompt") + prompt = PromptTemplate("This is the prompt") prompt_helper = PromptHelper( context_window=11, num_output=1, chunk_overlap_ratio=0, tokenizer=mock_tokenizer ) @@ -50,7 +48,7 @@ def test_get_chunk_size() -> None: def test_get_text_splitter() -> None: """Test get text splitter.""" test_prompt_text = "This is the prompt{text}" - test_prompt = Prompt(test_prompt_text) + test_prompt = PromptTemplate(test_prompt_text) prompt_helper = PromptHelper( context_window=11, num_output=1, chunk_overlap_ratio=0, tokenizer=mock_tokenizer ) @@ -84,7 +82,7 @@ def test_get_text_splitter_partial() -> None: # test without partially formatting test_prompt_text = "This is the {foo} prompt{text}" - test_prompt = Prompt(test_prompt_text) + test_prompt = PromptTemplate(test_prompt_text) prompt_helper = PromptHelper( context_window=11, num_output=1, chunk_overlap_ratio=0, tokenizer=mock_tokenizer ) @@ -98,7 +96,7 @@ def test_get_text_splitter_partial() -> None: assert truncated_text == "Hello world" # test with partially formatting - test_prompt = Prompt(test_prompt_text) + test_prompt = PromptTemplate(test_prompt_text) test_prompt = test_prompt.partial_format(foo="bar") prompt_helper = PromptHelper( context_window=12, num_output=1, chunk_overlap_ratio=0, tokenizer=mock_tokenizer @@ -118,7 +116,7 @@ def test_truncate() -> None: """Test truncate.""" # test prompt uses up one token test_prompt_txt = "test{text}" - test_prompt = Prompt(test_prompt_txt) + test_prompt = PromptTemplate(test_prompt_txt) # set context_window=19 # For each text chunk, there's 4 tokens for text + 5 for the padding prompt_helper = PromptHelper( @@ -139,7 +137,7 @@ def test_get_numbered_text_from_nodes() -> None: """Test get_text_from_nodes.""" # test prompt uses up one token test_prompt_txt = "test{text}" - test_prompt = Prompt(test_prompt_txt) + test_prompt = PromptTemplate(test_prompt_txt) # set context_window=17 # For each text chunk, there's 3 for text, 5 for padding (including number) prompt_helper = PromptHelper( @@ -159,7 +157,7 @@ def test_get_numbered_text_from_nodes() -> None: def test_repack() -> None: """Test repack.""" test_prompt_text = "This is the prompt{text}" - test_prompt = Prompt(test_prompt_text) + test_prompt = PromptTemplate(test_prompt_text) prompt_helper = PromptHelper( context_window=13, num_output=1, @@ -174,11 +172,9 @@ def test_repack() -> None: def test_get_biggest_prompt() -> None: """Test get_biggest_prompt from PromptHelper.""" - prompt1 = Prompt("This is the prompt{text}") - prompt2 = Prompt("This is the longer prompt{text}") - prompt3 = Prompt("This is the {text}") + prompt1 = PromptTemplate("This is the prompt{text}") + prompt2 = PromptTemplate("This is the longer prompt{text}") + prompt3 = PromptTemplate("This is the {text}") biggest_prompt = get_biggest_prompt([prompt1, prompt2, prompt3]) - lc_biggest_template = cast(LangchainPrompt, biggest_prompt.prompt).template - prompt2_template = cast(LangchainPrompt, prompt2.prompt).template - assert lc_biggest_template == prompt2_template + assert biggest_prompt == prompt2 diff --git a/tests/llm_predictor/test_base.py b/tests/llm_predictor/test_base.py index 01ed7dca006d80669155e3aa8d0f62588284b903..c1a07a33a6ba167d893fb345af57eb0c585d15dd 100644 --- a/tests/llm_predictor/test_base.py +++ b/tests/llm_predictor/test_base.py @@ -2,10 +2,10 @@ from typing import Any from unittest.mock import patch - from llama_index.llm_predictor.structured import LLMPredictor, StructuredLLMPredictor +from llama_index.prompts import BasePromptTemplate +from llama_index.prompts.base import PromptTemplate from llama_index.types import BaseOutputParser -from llama_index.prompts.prompts import Prompt, SimpleInputPrompt try: gptcache_installed = True @@ -25,7 +25,7 @@ class MockOutputParser(BaseOutputParser): return output -def mock_llmpredictor_predict(prompt: Prompt, **prompt_args: Any) -> str: +def mock_llmpredictor_predict(prompt: BasePromptTemplate, **prompt_args: Any) -> str: """Mock LLMPredictor predict.""" return prompt_args["query_str"] @@ -36,12 +36,12 @@ def test_struct_llm_predictor(mock_init: Any, mock_predict: Any) -> None: """Test LLM predictor.""" llm_predictor = StructuredLLMPredictor() output_parser = MockOutputParser() - prompt = SimpleInputPrompt("{query_str}", output_parser=output_parser) + prompt = PromptTemplate("{query_str}", output_parser=output_parser) llm_prediction = llm_predictor.predict(prompt, query_str="hello world") assert llm_prediction == "hello world\nhello world" # no change - prompt = SimpleInputPrompt("{query_str}") + prompt = PromptTemplate("{query_str}") llm_prediction = llm_predictor.predict(prompt, query_str="hello world") assert llm_prediction == "hello world" diff --git a/tests/llm_predictor/vellum/conftest.py b/tests/llm_predictor/vellum/conftest.py index 29cfb14d7eebafffe3af15524ff5b77f5694d30a..94faebbed50ef0b8fd4fcbec0d40f3d47bf1237d 100644 --- a/tests/llm_predictor/vellum/conftest.py +++ b/tests/llm_predictor/vellum/conftest.py @@ -1,26 +1,16 @@ -from typing import Optional, Type, Callable +from typing import Optional, Callable from unittest import mock import pytest -from llama_index import Prompt from llama_index.callbacks import CallbackManager from llama_index.llm_predictor.vellum import VellumPredictor, VellumPromptRegistry -from llama_index.prompts.prompt_type import PromptType +from llama_index.prompts.base import PromptTemplate @pytest.fixture -def dummy_prompt_class() -> Type[Prompt]: - class DummyPrompt(Prompt): - prompt_type = PromptType.CUSTOM - input_variables = ["thing"] - - return DummyPrompt - - -@pytest.fixture -def dummy_prompt(dummy_prompt_class: Type[Prompt]) -> Prompt: - return dummy_prompt_class(template="What's your favorite {thing}?") +def dummy_prompt() -> PromptTemplate: + return PromptTemplate(template="What's your favorite {thing}?") @pytest.fixture diff --git a/tests/llm_predictor/vellum/test_predictor.py b/tests/llm_predictor/vellum/test_predictor.py index 785c4518a78db3af68c6ca5df33428c9ddd05d3c..0f2ff4042f1b3a407d6c57a9243296d37c954076 100644 --- a/tests/llm_predictor/vellum/test_predictor.py +++ b/tests/llm_predictor/vellum/test_predictor.py @@ -3,16 +3,14 @@ from unittest import mock import pytest -from llama_index import Prompt -from llama_index.llm_predictor.vellum import ( - VellumPredictor, -) +from llama_index.llm_predictor.vellum import VellumPredictor +from llama_index.prompts import BasePromptTemplate def test_predict__basic( mock_vellum_client_factory: Callable[..., mock.MagicMock], vellum_predictor_factory: Callable[..., VellumPredictor], - dummy_prompt: Prompt, + dummy_prompt: BasePromptTemplate, ) -> None: """When the Vellum API returns expected values, so should our predictor""" @@ -31,7 +29,7 @@ def test_predict__basic( def test_stream__basic( mock_vellum_client_factory: Callable[..., mock.MagicMock], vellum_predictor_factory: Callable[..., VellumPredictor], - dummy_prompt: Prompt, + dummy_prompt: BasePromptTemplate, ) -> None: """When the Vellum API streams expected values, so should our predictor""" diff --git a/tests/llm_predictor/vellum/test_prompt_registry.py b/tests/llm_predictor/vellum/test_prompt_registry.py index 2f75715e5bedb0b2cdc6a5f391de20673fcbbed1..7350d8f5001573a71a02423d6482a8ee6650d561 100644 --- a/tests/llm_predictor/vellum/test_prompt_registry.py +++ b/tests/llm_predictor/vellum/test_prompt_registry.py @@ -1,16 +1,15 @@ -from typing import Callable, Type +from typing import Callable from unittest import mock -from llama_index import Prompt from llama_index.llm_predictor.vellum import ( VellumRegisteredPrompt, VellumCompiledPrompt, VellumPromptRegistry, ) +from llama_index.prompts.base import PromptTemplate def test_from_prompt__new( - dummy_prompt_class: Type[Prompt], mock_vellum_client_factory: Callable[..., mock.MagicMock], vellum_prompt_registry_factory: Callable[..., VellumPromptRegistry], ) -> None: @@ -18,7 +17,7 @@ def test_from_prompt__new( from vellum.core import ApiError - dummy_prompt = dummy_prompt_class(template="What's your favorite {thing}?") + dummy_prompt = PromptTemplate(template="What's your favorite {thing}?") vellum_client = mock_vellum_client_factory() @@ -31,13 +30,12 @@ def test_from_prompt__new( def test_from_prompt__existing( - dummy_prompt_class: Type[Prompt], mock_vellum_client_factory: Callable[..., mock.MagicMock], vellum_prompt_registry_factory: Callable[..., VellumPromptRegistry], ) -> None: """We shouldn't register a new prompt if a deployment id or name is provided""" - dummy_prompt = dummy_prompt_class( + dummy_prompt = PromptTemplate( template="What's your favorite {thing}?", metadata={"vellum_deployment_id": "abc"}, ) diff --git a/tests/mock_utils/mock_predict.py b/tests/mock_utils/mock_predict.py index 32cc806b2ff3583d5a2c0a2a595fbe446eb0e8b1..f3b5ac82eb40cac83ee507212638940382387938 100644 --- a/tests/mock_utils/mock_predict.py +++ b/tests/mock_utils/mock_predict.py @@ -3,7 +3,9 @@ import json from typing import Any, Dict -from llama_index.prompts.base import Prompt +from llama_index.prompts.base import ( + BasePromptTemplate, +) from llama_index.prompts.prompt_type import PromptType from llama_index.token_counter.utils import mock_extract_keywords_response @@ -150,50 +152,54 @@ def _mock_conversation(prompt_args: Dict) -> str: return prompt_args["history"] + ":" + prompt_args["message"] -def mock_llmpredictor_predict(prompt: Prompt, **prompt_args: Any) -> str: +def mock_llmpredictor_predict(prompt: BasePromptTemplate, **prompt_args: Any) -> str: """Mock predict method of LLMPredictor. Depending on the prompt, return response. """ - full_prompt_args = prompt.get_full_format_args(prompt_args) - if prompt.prompt_type == PromptType.SUMMARY: + full_prompt_args = { + **prompt.kwargs, + **prompt_args, + } + prompt_type = prompt.metadata["prompt_type"] + if prompt_type == PromptType.SUMMARY: response = _mock_summary_predict(full_prompt_args) - elif prompt.prompt_type == PromptType.TREE_INSERT: + elif prompt_type == PromptType.TREE_INSERT: response = _mock_insert_predict() - elif prompt.prompt_type == PromptType.TREE_SELECT: + elif prompt_type == PromptType.TREE_SELECT: response = _mock_query_select() - elif prompt.prompt_type == PromptType.REFINE: + elif prompt_type == PromptType.REFINE: response = _mock_refine(full_prompt_args) - elif prompt.prompt_type == PromptType.QUESTION_ANSWER: + elif prompt_type == PromptType.QUESTION_ANSWER: response = _mock_answer(full_prompt_args) - elif prompt.prompt_type == PromptType.KEYWORD_EXTRACT: + elif prompt_type == PromptType.KEYWORD_EXTRACT: response = _mock_keyword_extract(full_prompt_args) - elif prompt.prompt_type == PromptType.QUERY_KEYWORD_EXTRACT: + elif prompt_type == PromptType.QUERY_KEYWORD_EXTRACT: response = _mock_query_keyword_extract(full_prompt_args) - elif prompt.prompt_type == PromptType.SCHEMA_EXTRACT: + elif prompt_type == PromptType.SCHEMA_EXTRACT: response = _mock_schema_extract(full_prompt_args) - elif prompt.prompt_type == PromptType.TEXT_TO_SQL: + elif prompt_type == PromptType.TEXT_TO_SQL: response = _mock_text_to_sql(full_prompt_args) - elif prompt.prompt_type == PromptType.KNOWLEDGE_TRIPLET_EXTRACT: + elif prompt_type == PromptType.KNOWLEDGE_TRIPLET_EXTRACT: response = _mock_kg_triplet_extract(full_prompt_args) - elif prompt.prompt_type == PromptType.SIMPLE_INPUT: + elif prompt_type == PromptType.SIMPLE_INPUT: response = _mock_input(full_prompt_args) - elif prompt.prompt_type == PromptType.SINGLE_SELECT: + elif prompt_type == PromptType.SINGLE_SELECT: response = _mock_single_select() - elif prompt.prompt_type == PromptType.MULTI_SELECT: + elif prompt_type == PromptType.MULTI_SELECT: response = _mock_multi_select(full_prompt_args) - elif prompt.prompt_type == PromptType.SUB_QUESTION: + elif prompt_type == PromptType.SUB_QUESTION: response = _mock_sub_questions() - elif prompt.prompt_type == PromptType.PANDAS: + elif prompt_type == PromptType.PANDAS: response = _mock_pandas(full_prompt_args) - elif prompt.prompt_type == PromptType.SQL_RESPONSE_SYNTHESIS: + elif prompt_type == PromptType.SQL_RESPONSE_SYNTHESIS: response = _mock_sql_response_synthesis(full_prompt_args) - elif prompt.prompt_type == PromptType.DECOMPOSE: + elif prompt_type == PromptType.DECOMPOSE: response = _mock_decompose_query(full_prompt_args) - elif prompt.prompt_type == PromptType.CHOICE_SELECT: + elif prompt_type == PromptType.CHOICE_SELECT: response = _mock_choice_select(full_prompt_args) - elif prompt.prompt_type == PromptType.CONVERSATION: + elif prompt_type == PromptType.CONVERSATION: response = _mock_conversation(full_prompt_args) else: response = str(full_prompt_args) @@ -201,7 +207,9 @@ def mock_llmpredictor_predict(prompt: Prompt, **prompt_args: Any) -> str: return response -def patch_llmpredictor_predict(self: Any, prompt: Prompt, **prompt_args: Any) -> str: +def patch_llmpredictor_predict( + self: Any, prompt: BasePromptTemplate, **prompt_args: Any +) -> str: """Mock predict method of LLMPredictor. Depending on the prompt, return response. @@ -211,12 +219,14 @@ def patch_llmpredictor_predict(self: Any, prompt: Prompt, **prompt_args: Any) -> async def patch_llmpredictor_apredict( - self: Any, prompt: Prompt, **prompt_args: Any + self: Any, prompt: BasePromptTemplate, **prompt_args: Any ) -> str: """Mock apredict method of LLMPredictor.""" return patch_llmpredictor_predict(self, prompt, **prompt_args) -async def mock_llmpredictor_apredict(prompt: Prompt, **prompt_args: Any) -> str: +async def mock_llmpredictor_apredict( + prompt: BasePromptTemplate, **prompt_args: Any +) -> str: """Mock apredict method of LLMPredictor.""" return mock_llmpredictor_predict(prompt, **prompt_args) diff --git a/tests/mock_utils/mock_prompts.py b/tests/mock_utils/mock_prompts.py index 5835c90293495f9526007c15b9ede8f2616d7bbe..390359c9e722c27d3a516091ff3d7e3284be11ea 100644 --- a/tests/mock_utils/mock_prompts.py +++ b/tests/mock_utils/mock_prompts.py @@ -1,65 +1,77 @@ """Mock prompt utils.""" -from llama_index.prompts.base import Prompt +from llama_index.prompts.base import PromptTemplate from llama_index.prompts.prompt_type import PromptType MOCK_SUMMARY_PROMPT_TMPL = "{context_str}\n" -MOCK_SUMMARY_PROMPT = Prompt(MOCK_SUMMARY_PROMPT_TMPL, prompt_type=PromptType.SUMMARY) +MOCK_SUMMARY_PROMPT = PromptTemplate( + MOCK_SUMMARY_PROMPT_TMPL, prompt_type=PromptType.SUMMARY +) MOCK_INSERT_PROMPT_TMPL = "{num_chunks}\n{context_list}{new_chunk_text}\n" -MOCK_INSERT_PROMPT = Prompt(MOCK_INSERT_PROMPT_TMPL, prompt_type=PromptType.TREE_INSERT) +MOCK_INSERT_PROMPT = PromptTemplate( + MOCK_INSERT_PROMPT_TMPL, prompt_type=PromptType.TREE_INSERT +) # # single choice MOCK_QUERY_PROMPT_TMPL = "{num_chunks}\n" "{context_list}\n" "{query_str}'\n" -MOCK_QUERY_PROMPT = Prompt(MOCK_QUERY_PROMPT_TMPL, prompt_type=PromptType.TREE_SELECT) +MOCK_QUERY_PROMPT = PromptTemplate( + MOCK_QUERY_PROMPT_TMPL, prompt_type=PromptType.TREE_SELECT +) MOCK_REFINE_PROMPT_TMPL = "{query_str}\n" "{existing_answer}\n" "{context_msg}\n" -MOCK_REFINE_PROMPT = Prompt(MOCK_REFINE_PROMPT_TMPL, prompt_type=PromptType.REFINE) +MOCK_REFINE_PROMPT = PromptTemplate( + MOCK_REFINE_PROMPT_TMPL, prompt_type=PromptType.REFINE +) MOCK_TEXT_QA_PROMPT_TMPL = "{context_str}\n" "{query_str}\n" -MOCK_TEXT_QA_PROMPT = Prompt( +MOCK_TEXT_QA_PROMPT = PromptTemplate( MOCK_TEXT_QA_PROMPT_TMPL, prompt_type=PromptType.QUESTION_ANSWER ) MOCK_KEYWORD_EXTRACT_PROMPT_TMPL = "{max_keywords}\n{text}\n" -MOCK_KEYWORD_EXTRACT_PROMPT = Prompt( +MOCK_KEYWORD_EXTRACT_PROMPT = PromptTemplate( MOCK_KEYWORD_EXTRACT_PROMPT_TMPL, prompt_type=PromptType.KEYWORD_EXTRACT ) # TODO: consolidate with keyword extract MOCK_QUERY_KEYWORD_EXTRACT_PROMPT_TMPL = "{max_keywords}\n{question}\n" -MOCK_QUERY_KEYWORD_EXTRACT_PROMPT = Prompt( +MOCK_QUERY_KEYWORD_EXTRACT_PROMPT = PromptTemplate( MOCK_QUERY_KEYWORD_EXTRACT_PROMPT_TMPL, prompt_type=PromptType.QUERY_KEYWORD_EXTRACT ) MOCK_SCHEMA_EXTRACT_PROMPT_TMPL = "{text}\n{schema}" -MOCK_SCHEMA_EXTRACT_PROMPT = Prompt( +MOCK_SCHEMA_EXTRACT_PROMPT = PromptTemplate( MOCK_SCHEMA_EXTRACT_PROMPT_TMPL, prompt_type=PromptType.SCHEMA_EXTRACT ) MOCK_TEXT_TO_SQL_PROMPT_TMPL = "{dialect}\n{schema}\n{query_str}" -MOCK_TEXT_TO_SQL_PROMPT = Prompt( +MOCK_TEXT_TO_SQL_PROMPT = PromptTemplate( MOCK_TEXT_TO_SQL_PROMPT_TMPL, prompt_type=PromptType.TEXT_TO_SQL ) MOCK_TABLE_CONTEXT_PROMPT_TMPL = "{schema}\n{context_str}\n{query_str}" -MOCK_TABLE_CONTEXT_PROMPT = Prompt( +MOCK_TABLE_CONTEXT_PROMPT = PromptTemplate( MOCK_TABLE_CONTEXT_PROMPT_TMPL, prompt_type=PromptType.TABLE_CONTEXT ) MOCK_KG_TRIPLET_EXTRACT_PROMPT_TMPL = "{max_knowledge_triplets}\n{text}" -MOCK_KG_TRIPLET_EXTRACT_PROMPT = Prompt( +MOCK_KG_TRIPLET_EXTRACT_PROMPT = PromptTemplate( MOCK_KG_TRIPLET_EXTRACT_PROMPT_TMPL, prompt_type=PromptType.KNOWLEDGE_TRIPLET_EXTRACT, ) MOCK_INPUT_PROMPT_TMPL = "{query_str}" -MOCK_INPUT_PROMPT = Prompt(MOCK_INPUT_PROMPT_TMPL, prompt_type=PromptType.SIMPLE_INPUT) +MOCK_INPUT_PROMPT = PromptTemplate( + MOCK_INPUT_PROMPT_TMPL, prompt_type=PromptType.SIMPLE_INPUT +) MOCK_PANDAS_PROMPT_TMPL = "{query_str}\n{df_str}\n{instruction_str}" -MOCK_PANDAS_PROMPT = Prompt(MOCK_PANDAS_PROMPT_TMPL, prompt_type=PromptType.PANDAS) +MOCK_PANDAS_PROMPT = PromptTemplate( + MOCK_PANDAS_PROMPT_TMPL, prompt_type=PromptType.PANDAS +) diff --git a/tests/prompts/test_base.py b/tests/prompts/test_base.py index aa1507420bf51414dc1fffb105c5100e1bf6b48e..8fef4f49dda478e98ca23348da32488a5f9d266c 100644 --- a/tests/prompts/test_base.py +++ b/tests/prompts/test_base.py @@ -1,97 +1,126 @@ """Test prompts.""" -from unittest.mock import MagicMock -import pytest +from llama_index.bridge.langchain import BaseLanguageModel +from llama_index.bridge.langchain import ConditionalPromptSelector as LangchainSelector +from llama_index.bridge.langchain import FakeListLLM +from llama_index.bridge.langchain import PromptTemplate as LangchainTemplate +from llama_index.llms import MockLLM +from llama_index.llms.base import ChatMessage, MessageRole +from llama_index.llms.langchain import LangChainLLM +from llama_index.prompts import ( + ChatPromptTemplate, + LangchainPromptTemplate, + PromptTemplate, + SelectorPromptTemplate, +) +from llama_index.prompts.prompt_type import PromptType + + +def test_template() -> None: + """Test partial format.""" + prompt_txt = "hello {text} {foo}" + prompt = PromptTemplate(prompt_txt) -from llama_index.bridge.langchain import PromptTemplate -from llama_index.llms.base import LLM -from llama_index.llms.openai import OpenAI -from llama_index.prompts.base import Prompt -from llama_index.prompts.prompt_selector import PromptSelector + prompt_fmt = prompt.partial_format(foo="bar") + assert isinstance(prompt_fmt, PromptTemplate) + assert prompt_fmt.format(text="world") == "hello world bar" -def is_openai(llm: LLM) -> bool: - """Test condition.""" - return isinstance(llm, OpenAI) + assert prompt_fmt.format_messages(text="world") == [ + ChatMessage(content="hello world bar", role=MessageRole.USER) + ] + + +def test_chat_template() -> None: + chat_template = ChatPromptTemplate( + message_templates=[ + ChatMessage( + content="This is a system message with a {sys_param}", + role=MessageRole.SYSTEM, + ), + ChatMessage(content="hello {text} {foo}", role=MessageRole.USER), + ], + prompt_type=PromptType.CONVERSATION, + ) + partial_template = chat_template.partial_format(sys_param="sys_arg") + messages = partial_template.format_messages(text="world", foo="bar") -def test_partial_format() -> None: - """Test partial format.""" - prompt_txt = "hello {text} {foo}" - prompt = Prompt(prompt_txt) + assert messages[0] == ChatMessage( + content="This is a system message with a sys_arg", role=MessageRole.SYSTEM + ) - prompt_fmt = prompt.partial_format(foo="bar") + assert partial_template.format(text="world", foo="bar") == ( + "system: This is a system message with a sys_arg\n" + "user: hello world bar\n" + "assistant: " + ) - assert isinstance(prompt_fmt, Prompt) - assert prompt_fmt.format(text="world") == "hello world bar" +def test_selector_template() -> None: + default_template = PromptTemplate("hello {text} {foo}") + chat_template = ChatPromptTemplate( + message_templates=[ + ChatMessage( + content="This is a system message with a {sys_param}", + role=MessageRole.SYSTEM, + ), + ChatMessage(content="hello {text} {foo}", role=MessageRole.USER), + ], + prompt_type=PromptType.CONVERSATION, + ) -def test_from_prompt() -> None: - """Test new prompt from a partially formatted prompt.""" - prompt_txt = "hello {text} {foo}" - prompt = Prompt(prompt_txt) - prompt_fmt = prompt.partial_format(foo="bar") + selector_template = SelectorPromptTemplate( + default_template=default_template, + conditionals=[ + (lambda llm: isinstance(llm, MockLLM), chat_template), + ], + ) - prompt_new = Prompt.from_prompt(prompt_fmt) - assert isinstance(prompt_new, Prompt) + partial_template = selector_template.partial_format(text="world", foo="bar") - assert prompt_new.format(text="world2") == "hello world2 bar" + prompt = partial_template.format() + assert prompt == "hello world bar" + messages = partial_template.format_messages(llm=MockLLM(), sys_param="sys_arg") + assert messages[0] == ChatMessage( + content="This is a system message with a sys_arg", role=MessageRole.SYSTEM + ) -def test_from_langchain_prompt() -> None: - """Test from langchain prompt.""" - prompt_txt = "hello {text} {foo}" - prompt = PromptTemplate(input_variables=["text", "foo"], template=prompt_txt) - prompt_new = Prompt.from_langchain_prompt(prompt) - assert isinstance(prompt_new, Prompt) - assert prompt_new.prompt == prompt - assert prompt_new.format(text="world2", foo="bar") == "hello world2 bar" +def test_langchain_template() -> None: + lc_template = LangchainTemplate.from_template("hello {text} {foo}") + template = LangchainPromptTemplate(lc_template) - # test errors if we specify both template and langchain prompt - with pytest.raises(ValueError): - prompt_txt = "hello {text} {foo}" - prompt = PromptTemplate(input_variables=["text", "foo"], template=prompt_txt) - Prompt(template=prompt_txt, langchain_prompt=prompt) + template_fmt = template.partial_format(foo="bar") + assert isinstance(template, LangchainPromptTemplate) + assert template_fmt.format(text="world") == "hello world bar" -def test_from_langchain_prompt_selector() -> None: - """Test from langchain prompt selector.""" - prompt_txt = "hello {text} {foo}" - prompt_txt_2 = "world {text} {foo}" - prompt = PromptTemplate(input_variables=["text", "foo"], template=prompt_txt) - prompt_2 = PromptTemplate(input_variables=["text", "foo"], template=prompt_txt_2) + assert template_fmt.format_messages(text="world") == [ + ChatMessage(content="hello world bar", role=MessageRole.USER) + ] - test_prompt_selector = PromptSelector( - default_prompt=prompt, conditionals=[(is_openai, prompt_2)] - ) - test_llm = MagicMock(spec=OpenAI) +def test_langchain_selector_template() -> None: + lc_llm = FakeListLLM(responses=["test"]) + mock_llm = LangChainLLM(llm=lc_llm) + + def is_mock(llm: BaseLanguageModel) -> bool: + return llm == lc_llm - prompt_new = Prompt.from_langchain_prompt_selector(test_prompt_selector) - assert isinstance(prompt_new, Prompt) - assert prompt_new.prompt == prompt - assert prompt_new.format(text="world2", foo="bar") == "hello world2 bar" - assert ( - prompt_new.format(llm=test_llm, text="world2", foo="bar") == "world world2 bar" + default_lc_template = LangchainTemplate.from_template("hello {text} {foo}") + conditionals = [ + (is_mock, LangchainTemplate.from_template("hello {text} {foo} mock")), + ] + + lc_selector = LangchainSelector( + default_prompt=default_lc_template, conditionals=conditionals ) + template = LangchainPromptTemplate(selector=lc_selector) + + template_fmt = template.partial_format(foo="bar") + assert isinstance(template, LangchainPromptTemplate) - test_lc_prompt = prompt_new.get_langchain_prompt(llm=test_llm) - assert test_lc_prompt == prompt_2 - test_lc_prompt = prompt_new.get_langchain_prompt(llm=None) - assert test_lc_prompt == prompt - - # test errors if langchain prompt input var doesn't match - with pytest.raises(ValueError): - prompt_txt = "hello {text} {foo}" - prompt_txt_2 = "world {text} {foo} {tmp}" - prompt = PromptTemplate(input_variables=["text", "foo"], template=prompt_txt) - prompt_2 = PromptTemplate( - input_variables=["text", "foo", "tmp"], template=prompt_txt_2 - ) - - test_prompt_selector = PromptSelector( - prompt=prompt, conditionals=([is_openai], [prompt_2]) - ) - prompt_new = Prompt.from_langchain_prompt_selector(test_prompt_selector) + assert template_fmt.format(llm=mock_llm, text="world") == "hello world bar mock" diff --git a/tests/prompts/test_utils.py b/tests/prompts/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f09ffc2004811fd6eecf4a03979f21da6af292b0 --- /dev/null +++ b/tests/prompts/test_utils.py @@ -0,0 +1,7 @@ +from llama_index.prompts.utils import get_template_vars + + +def test_get_template_vars() -> None: + template = "hello {text} {foo}" + template_vars = get_template_vars(template) + assert template_vars == ["text", "foo"]