From 0148354dbe0513ee1aa26718abfed5e929f6c2cc Mon Sep 17 00:00:00 2001 From: Alex Yang <himself65@outlook.com> Date: Fri, 6 Sep 2024 11:22:08 -0700 Subject: [PATCH] refactor: prompt system (#1154) --- .changeset/tall-kangaroos-sleep.md | 10 + examples/prompts/promptMixin.ts | 13 +- examples/readers/src/csv.ts | 14 +- examples/readers/src/llamaparse-json.ts | 5 +- packages/core/package.json | 17 +- packages/core/src/prompts/base.ts | 225 ++++++++++ packages/core/src/prompts/index.ts | 33 ++ packages/core/src/prompts/mixin.ts | 79 ++++ packages/core/src/prompts/prompt-type.ts | 64 +++ packages/core/src/prompts/prompt.ts | 253 +++++++++++ packages/core/src/schema/index.ts | 1 + .../src/schema/type/base-output-parser.ts | 8 + packages/core/src/utils/index.ts | 4 + packages/core/src/utils/llms.ts | 23 + packages/core/src/utils/object-entries.ts | 14 + packages/core/tests/prompts.test.ts | 161 +++++++ .../tests/prompts/mixin.test.ts} | 69 +-- packages/llamaindex/src/ChatHistory.ts | 12 +- packages/llamaindex/src/OutputParser.ts | 3 +- packages/llamaindex/src/Prompt.ts | 411 ------------------ packages/llamaindex/src/PromptHelper.ts | 37 +- packages/llamaindex/src/QuestionGenerator.ts | 22 +- .../chat/CondenseQuestionChatEngine.ts | 22 +- .../src/engines/chat/ContextChatEngine.ts | 24 +- .../engines/chat/DefaultContextGenerator.ts | 13 +- .../src/engines/query/RetrieverQueryEngine.ts | 8 +- .../src/engines/query/RouterQueryEngine.ts | 10 +- .../engines/query/SubQuestionQueryEngine.ts | 8 +- .../llamaindex/src/evaluation/Correctness.ts | 19 +- .../llamaindex/src/evaluation/Faithfulness.ts | 6 +- .../llamaindex/src/evaluation/Relevancy.ts | 6 +- packages/llamaindex/src/evaluation/prompts.ts | 132 +++--- packages/llamaindex/src/index.edge.ts | 3 +- .../llamaindex/src/indices/keyword/index.ts | 20 +- .../llamaindex/src/indices/summary/index.ts | 8 +- .../llamaindex/src/outputParsers/selectors.ts | 3 +- packages/llamaindex/src/prompts/Mixin.ts | 90 ---- packages/llamaindex/src/prompts/index.ts | 1 - packages/llamaindex/src/selectors/base.ts | 2 +- .../llamaindex/src/selectors/llmSelectors.ts | 37 +- packages/llamaindex/src/selectors/prompts.ts | 47 +- .../MultiModalResponseSynthesizer.ts | 21 +- .../src/synthesizers/ResponseSynthesizer.ts | 11 +- .../llamaindex/src/synthesizers/builders.ts | 114 +++-- packages/llamaindex/src/synthesizers/types.ts | 6 +- packages/llamaindex/src/synthesizers/utils.ts | 8 +- packages/llamaindex/src/types.ts | 9 - pnpm-lock.yaml | 26 +- 48 files changed, 1309 insertions(+), 823 deletions(-) create mode 100644 .changeset/tall-kangaroos-sleep.md create mode 100644 packages/core/src/prompts/base.ts create mode 100644 packages/core/src/prompts/index.ts create mode 100644 packages/core/src/prompts/mixin.ts create mode 100644 packages/core/src/prompts/prompt-type.ts create mode 100644 packages/core/src/prompts/prompt.ts create mode 100644 packages/core/src/schema/type/base-output-parser.ts create mode 100644 packages/core/src/utils/object-entries.ts create mode 100644 packages/core/tests/prompts.test.ts rename packages/{llamaindex/tests/prompts/Mixin.test.ts => core/tests/prompts/mixin.test.ts} (61%) delete mode 100644 packages/llamaindex/src/Prompt.ts delete mode 100644 packages/llamaindex/src/prompts/Mixin.ts delete mode 100644 packages/llamaindex/src/prompts/index.ts diff --git a/.changeset/tall-kangaroos-sleep.md b/.changeset/tall-kangaroos-sleep.md new file mode 100644 index 000000000..557a673fd --- /dev/null +++ b/.changeset/tall-kangaroos-sleep.md @@ -0,0 +1,10 @@ +--- 
+"@llamaindex/core": patch +"llamaindex": patch +"@llamaindex/core-tests": patch +"llamaindex-loader-example": patch +--- + +refactor: prompt system + +Add `PromptTemplate` module with strong type check. diff --git a/examples/prompts/promptMixin.ts b/examples/prompts/promptMixin.ts index 7200d526e..074d2be77 100644 --- a/examples/prompts/promptMixin.ts +++ b/examples/prompts/promptMixin.ts @@ -1,21 +1,22 @@ import { Document, + PromptTemplate, ResponseSynthesizer, TreeSummarize, TreeSummarizePrompt, VectorStoreIndex, } from "llamaindex"; -const treeSummarizePrompt: TreeSummarizePrompt = ({ context, query }) => { - return `Context information from multiple sources is below. +const treeSummarizePrompt: TreeSummarizePrompt = new PromptTemplate({ + template: `Context information from multiple sources is below. --------------------- -${context} +{context} --------------------- Given the information from multiple sources and not prior knowledge. Answer the query in the style of a Shakespeare play" -Query: ${query} -Answer:`; -}; +Query: {query} +Answer:`, +}); async function main() { const documents = new Document({ diff --git a/examples/readers/src/csv.ts b/examples/readers/src/csv.ts index 6d9f6d901..17a6511ea 100644 --- a/examples/readers/src/csv.ts +++ b/examples/readers/src/csv.ts @@ -1,6 +1,7 @@ import { CompactAndRefine, OpenAI, + PromptTemplate, ResponseSynthesizer, Settings, VectorStoreIndex, @@ -18,14 +19,15 @@ async function main() { // Split text and create embeddings. Store them in a VectorStoreIndex const index = await VectorStoreIndex.fromDocuments(documents); - const csvPrompt = ({ context = "", query = "" }) => { - return `The following CSV file is loaded from ${path} + const csvPrompt = new PromptTemplate({ + templateVars: ["query", "context"], + template: `The following CSV file is loaded from ${path} \`\`\`csv -${context} +{context} \`\`\` -Given the CSV file, generate me Typescript code to answer the question: ${query}. You can use built in NodeJS functions but avoid using third party libraries. -`; - }; +Given the CSV file, generate me Typescript code to answer the question: {query}. You can use built in NodeJS functions but avoid using third party libraries. 
+`, + }); const responseSynthesizer = new ResponseSynthesizer({ responseBuilder: new CompactAndRefine(undefined, csvPrompt), diff --git a/examples/readers/src/llamaparse-json.ts b/examples/readers/src/llamaparse-json.ts index cc989bd0c..1c8576769 100644 --- a/examples/readers/src/llamaparse-json.ts +++ b/examples/readers/src/llamaparse-json.ts @@ -3,6 +3,7 @@ import { ImageNode, LlamaParseReader, OpenAI, + PromptTemplate, VectorStoreIndex, } from "llamaindex"; import { createMessageContent } from "llamaindex/synthesizers/utils"; @@ -50,7 +51,9 @@ async function getImageTextDocs( for (const imageDict of imageDicts) { const imageDoc = new ImageNode({ image: imageDict.path }); - const prompt = () => `Describe the image as alt text`; + const prompt = new PromptTemplate({ + template: `Describe the image as alt text`, + }); const message = await createMessageContent(prompt, [imageDoc]); const response = await llm.complete({ diff --git a/packages/core/package.json b/packages/core/package.json index 4a51d2c15..595a43bc4 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -115,6 +115,20 @@ "types": "./dist/utils/index.d.ts", "default": "./dist/utils/index.js" } + }, + "./prompts": { + "require": { + "types": "./dist/prompts/index.d.cts", + "default": "./dist/prompts/index.cjs" + }, + "import": { + "types": "./dist/prompts/index.d.ts", + "default": "./dist/prompts/index.js" + }, + "default": { + "types": "./dist/prompts/index.d.ts", + "default": "./dist/prompts/index.js" + } } }, "files": [ @@ -132,7 +146,8 @@ "devDependencies": { "ajv": "^8.17.1", "bunchee": "5.3.2", - "natural": "^8.0.1" + "natural": "^8.0.1", + "python-format-js": "^1.4.3" }, "dependencies": { "@llamaindex/env": "workspace:*", diff --git a/packages/core/src/prompts/base.ts b/packages/core/src/prompts/base.ts new file mode 100644 index 000000000..40312f5ac --- /dev/null +++ b/packages/core/src/prompts/base.ts @@ -0,0 +1,225 @@ +import format from "python-format-js"; +import type { ChatMessage } from "../llms"; +import type { BaseOutputParser, Metadata } from "../schema"; +import { objectEntries } from "../utils"; +import { PromptType } from "./prompt-type"; + +type MappingFn<TemplatesVar extends string[] = string[]> = ( + options: Record<TemplatesVar[number], string>, +) => string; + +export type BasePromptTemplateOptions< + TemplatesVar extends readonly string[], + Vars extends readonly string[], +> = { + metadata?: Metadata; + templateVars?: + | TemplatesVar + // loose type for better type inference + | readonly string[]; + options?: Partial<Record<TemplatesVar[number] | (string & {}), string>>; + outputParser?: BaseOutputParser; + templateVarMappings?: Partial< + Record<Vars[number] | (string & {}), TemplatesVar[number] | (string & {})> + >; + functionMappings?: Partial< + Record<TemplatesVar[number] | (string & {}), MappingFn> + >; +}; + +export abstract class BasePromptTemplate< + const TemplatesVar extends readonly string[] = string[], + const Vars extends readonly string[] = string[], +> { + metadata: Metadata = {}; + templateVars: Set<string> = new Set(); + options: Partial<Record<TemplatesVar[number] | (string & {}), string>> = {}; + outputParser?: BaseOutputParser; + templateVarMappings: Partial< + Record<Vars[number] | (string & {}), TemplatesVar[number] | (string & {})> + > = {}; + functionMappings: Partial< + Record<TemplatesVar[number] | (string & {}), MappingFn> + > = {}; + + protected constructor( + options: BasePromptTemplateOptions<TemplatesVar, Vars>, + ) { + const { + metadata, + 
templateVars, + outputParser, + templateVarMappings, + functionMappings, + } = options; + if (metadata) { + this.metadata = metadata; + } + if (templateVars) { + this.templateVars = new Set(templateVars); + } + if (options.options) { + this.options = options.options; + } + this.outputParser = outputParser; + if (templateVarMappings) { + this.templateVarMappings = templateVarMappings; + } + if (functionMappings) { + this.functionMappings = functionMappings; + } + } + + protected mapTemplateVars( + options: Record<TemplatesVar[number] | (string & {}), string>, + ) { + const templateVarMappings = this.templateVarMappings; + return Object.fromEntries( + objectEntries(options).map(([k, v]) => [templateVarMappings[k] || k, v]), + ); + } + + protected mapFunctionVars( + options: Record<TemplatesVar[number] | (string & {}), string>, + ) { + const functionMappings = this.functionMappings; + const newOptions = {} as Record<TemplatesVar[number], string>; + for (const [k, v] of objectEntries(functionMappings)) { + newOptions[k] = v!(options); + } + + for (const [k, v] of objectEntries(options)) { + if (!(k in newOptions)) { + newOptions[k] = v; + } + } + + return newOptions; + } + + protected mapAllVars( + options: Record<TemplatesVar[number] | (string & {}), string>, + ): Record<string, string> { + const newOptions = this.mapFunctionVars(options); + return this.mapTemplateVars(newOptions); + } + + abstract partialFormat( + options: Partial<Record<TemplatesVar[number] | (string & {}), string>>, + ): BasePromptTemplate<TemplatesVar, Vars>; + + abstract format( + options?: Partial<Record<TemplatesVar[number] | (string & {}), string>>, + ): string; + + abstract formatMessages( + options?: Partial<Record<TemplatesVar[number] | (string & {}), string>>, + ): ChatMessage[]; + + abstract get template(): string; +} + +type Permutation<T, K = T> = [T] extends [never] + ? [] + : K extends K + ? [K, ...Permutation<Exclude<T, K>>] + : never; + +type Join<T extends any[], U extends string> = T extends [infer F, ...infer R] + ? R["length"] extends 0 + ? `${F & string}` + : `${F & string}${U}${Join<R, U>}` + : never; + +type WrapStringWithBracket<T extends string> = `{${T}}`; + +export type StringTemplate<Var extends readonly string[]> = + Var["length"] extends 0 + ? string + : Var["length"] extends number + ? number extends Var["length"] + ? string + : `${string}${Join<Permutation<WrapStringWithBracket<Var[number]>>, `${string}`>}${string}` + : never; + +export type PromptTemplateOptions< + TemplatesVar extends readonly string[], + Vars extends readonly string[], + Template extends StringTemplate<TemplatesVar>, +> = BasePromptTemplateOptions<TemplatesVar, Vars> & { + template: Template; + promptType?: PromptType; +}; + +export class PromptTemplate< + const TemplatesVar extends readonly string[] = string[], + const Vars extends readonly string[] = string[], + const Template extends + StringTemplate<TemplatesVar> = StringTemplate<TemplatesVar>, +> extends BasePromptTemplate<TemplatesVar, Vars> { + #template: Template; + promptType: PromptType; + + constructor(options: PromptTemplateOptions<TemplatesVar, Vars, Template>) { + const { template, promptType, ...rest } = options; + super(rest); + this.#template = template; + this.promptType = promptType ?? 
PromptType.custom; + } + + partialFormat( + options: Partial<Record<TemplatesVar[number] | (string & {}), string>>, + ): PromptTemplate<TemplatesVar, Vars, Template> { + const prompt = new PromptTemplate({ + template: this.template, + templateVars: [...this.templateVars], + options: this.options, + outputParser: this.outputParser, + templateVarMappings: this.templateVarMappings, + functionMappings: this.functionMappings, + metadata: this.metadata, + promptType: this.promptType, + }); + + prompt.options = { + ...prompt.options, + ...options, + }; + + return prompt; + } + + format( + options?: Partial<Record<TemplatesVar[number] | (string & {}), string>>, + ): string { + const allOptions = { + ...this.options, + ...options, + } as Record<TemplatesVar[number], string>; + + const mappedAllOptions = this.mapAllVars(allOptions); + + const prompt = format(this.template, mappedAllOptions); + + if (this.outputParser) { + return this.outputParser.format(prompt); + } + return prompt; + } + + formatMessages( + options?: Partial<Record<TemplatesVar[number] | (string & {}), string>>, + ): ChatMessage[] { + const prompt = this.format(options); + return [ + { + role: "user", + content: prompt, + }, + ]; + } + + get template(): Template { + return this.#template; + } +} diff --git a/packages/core/src/prompts/index.ts b/packages/core/src/prompts/index.ts new file mode 100644 index 000000000..c8ae7087f --- /dev/null +++ b/packages/core/src/prompts/index.ts @@ -0,0 +1,33 @@ +export { BasePromptTemplate, PromptTemplate } from "./base"; +export type { + BasePromptTemplateOptions, + PromptTemplateOptions, + StringTemplate, +} from "./base"; +export { PromptMixin, type ModuleRecord, type PromptsRecord } from "./mixin"; +export { + anthropicSummaryPrompt, + anthropicTextQaPrompt, + defaultChoiceSelectPrompt, + defaultCondenseQuestionPrompt, + defaultContextSystemPrompt, + defaultKeywordExtractPrompt, + defaultQueryKeywordExtractPrompt, + defaultRefinePrompt, + defaultSubQuestionPrompt, + defaultSummaryPrompt, + defaultTextQAPrompt, + defaultTreeSummarizePrompt, +} from "./prompt"; +export type { + ChoiceSelectPrompt, + CondenseQuestionPrompt, + ContextSystemPrompt, + KeywordExtractPrompt, + QueryKeywordExtractPrompt, + RefinePrompt, + SubQuestionPrompt, + SummaryPrompt, + TextQAPrompt, + TreeSummarizePrompt, +} from "./prompt"; diff --git a/packages/core/src/prompts/mixin.ts b/packages/core/src/prompts/mixin.ts new file mode 100644 index 000000000..76c9dd723 --- /dev/null +++ b/packages/core/src/prompts/mixin.ts @@ -0,0 +1,79 @@ +import { objectEntries } from "../utils"; +import type { BasePromptTemplate } from "./base"; + +export type PromptsRecord = Record<string, BasePromptTemplate>; +export type ModuleRecord = Record<string, PromptMixin>; + +export abstract class PromptMixin { + validatePrompts(promptsDict: PromptsRecord, moduleDict: ModuleRecord): void { + for (const key of Object.keys(promptsDict)) { + if (key.includes(":")) { + throw new Error(`Prompt key ${key} cannot contain ':'.`); + } + } + + for (const key of Object.keys(moduleDict)) { + if (key.includes(":")) { + throw new Error(`Module key ${key} cannot contain ':'.`); + } + } + } + + getPrompts(): PromptsRecord { + const promptsDict: PromptsRecord = this._getPrompts(); + + const moduleDict = this._getPromptModules(); + + this.validatePrompts(promptsDict, moduleDict); + + const allPrompts: PromptsRecord = { ...promptsDict }; + + for (const [module_name, prompt_module] of objectEntries(moduleDict)) { + for (const [key, prompt] of 
objectEntries(prompt_module.getPrompts())) { + allPrompts[`${module_name}:${key}`] = prompt; + } + } + return allPrompts; + } + + updatePrompts(prompts: PromptsRecord): void { + const promptModules = this._getPromptModules(); + + this._updatePrompts(prompts); + + const subPrompt: Record<string, PromptsRecord> = {}; + + for (const key in prompts) { + if (key.includes(":")) { + const [module_name, sub_key] = key.split(":"); + + if (!subPrompt[module_name]) { + subPrompt[module_name] = {}; + } + subPrompt[module_name][sub_key] = prompts[key]; + } + } + + for (const [module_name, subPromptDict] of Object.entries(subPrompt)) { + if (!promptModules[module_name]) { + throw new Error(`Module ${module_name} not found.`); + } + + const moduleToUpdate = promptModules[module_name]; + + moduleToUpdate.updatePrompts(subPromptDict); + } + } + + protected abstract _getPrompts(): PromptsRecord; + protected abstract _updatePrompts(prompts: PromptsRecord): void; + + /** + * + * Return a dictionary of sub-modules within the current module + * that also implement PromptMixin (so that their prompts can also be get/set). + * + * Can be blank if no sub-modules. + */ + protected abstract _getPromptModules(): ModuleRecord; +} diff --git a/packages/core/src/prompts/prompt-type.ts b/packages/core/src/prompts/prompt-type.ts new file mode 100644 index 000000000..d6b7af41e --- /dev/null +++ b/packages/core/src/prompts/prompt-type.ts @@ -0,0 +1,64 @@ +import { z } from "zod"; + +const promptType = { + SUMMARY: "summary", + TREE_INSERT: "insert", + TREE_SELECT: "tree_select", + TREE_SELECT_MULTIPLE: "tree_select_multiple", + QUESTION_ANSWER: "text_qa", + REFINE: "refine", + KEYWORD_EXTRACT: "keyword_extract", + QUERY_KEYWORD_EXTRACT: "query_keyword_extract", + SCHEMA_EXTRACT: "schema_extract", + TEXT_TO_SQL: "text_to_sql", + TEXT_TO_GRAPH_QUERY: "text_to_graph_query", + TABLE_CONTEXT: "table_context", + KNOWLEDGE_TRIPLET_EXTRACT: "knowledge_triplet_extract", + SIMPLE_INPUT: "simple_input", + PANDAS: "pandas", + JSON_PATH: "json_path", + SINGLE_SELECT: "single_select", + MULTI_SELECT: "multi_select", + VECTOR_STORE_QUERY: "vector_store_query", + SUB_QUESTION: "sub_question", + SQL_RESPONSE_SYNTHESIS: "sql_response_synthesis", + SQL_RESPONSE_SYNTHESIS_V2: "sql_response_synthesis_v2", + CONVERSATION: "conversation", + DECOMPOSE: "decompose", + CHOICE_SELECT: "choice_select", + CUSTOM: "custom", + RANKGPT_RERANK: "rankgpt_rerank", +} as const; + +const promptTypeSchema = z.enum([ + promptType.SUMMARY, + promptType.TREE_INSERT, + promptType.TREE_SELECT, + promptType.TREE_SELECT_MULTIPLE, + promptType.QUESTION_ANSWER, + promptType.REFINE, + promptType.KEYWORD_EXTRACT, + promptType.QUERY_KEYWORD_EXTRACT, + promptType.SCHEMA_EXTRACT, + promptType.TEXT_TO_SQL, + promptType.TEXT_TO_GRAPH_QUERY, + promptType.TABLE_CONTEXT, + promptType.KNOWLEDGE_TRIPLET_EXTRACT, + promptType.SIMPLE_INPUT, + promptType.PANDAS, + promptType.JSON_PATH, + promptType.SINGLE_SELECT, + promptType.MULTI_SELECT, + promptType.VECTOR_STORE_QUERY, + promptType.SUB_QUESTION, + promptType.SQL_RESPONSE_SYNTHESIS, + promptType.SQL_RESPONSE_SYNTHESIS_V2, + promptType.CONVERSATION, + promptType.DECOMPOSE, + promptType.CHOICE_SELECT, + promptType.CUSTOM, + promptType.RANKGPT_RERANK, +]); + +export const PromptType = promptTypeSchema.enum; +export type PromptType = z.infer<typeof promptTypeSchema>; diff --git a/packages/core/src/prompts/prompt.ts b/packages/core/src/prompts/prompt.ts new file mode 100644 index 000000000..53851f25f --- /dev/null +++ 
b/packages/core/src/prompts/prompt.ts @@ -0,0 +1,253 @@ +import type { ChatMessage, ToolMetadata } from "../llms"; +import { PromptTemplate } from "./base"; + +export type TextQAPrompt = PromptTemplate<["context", "query"]>; +export type SummaryPrompt = PromptTemplate<["context"]>; +export type RefinePrompt = PromptTemplate< + ["query", "existingAnswer", "context"] +>; +export type TreeSummarizePrompt = PromptTemplate<["context", "query"]>; +export type ChoiceSelectPrompt = PromptTemplate<["context", "query"]>; +export type SubQuestionPrompt = PromptTemplate<["toolsStr", "queryStr"]>; +export type CondenseQuestionPrompt = PromptTemplate< + ["chatHistory", "question"] +>; +export type ContextSystemPrompt = PromptTemplate<["context"]>; +export type KeywordExtractPrompt = PromptTemplate<["context"]>; +export type QueryKeywordExtractPrompt = PromptTemplate<["question"]>; + +export const defaultTextQAPrompt: TextQAPrompt = new PromptTemplate({ + templateVars: ["context", "query"], + template: `Context information is below. +--------------------- +{context} +--------------------- +Given the context information and not prior knowledge, answer the query. +Query: {query} +Answer:`, +}); + +export const anthropicTextQaPrompt: TextQAPrompt = new PromptTemplate({ + templateVars: ["context", "query"], + template: `Context information: +<context> +{context} +</context> +Given the context information and not prior knowledge, answer the query. +Query: {query}`, +}); + +export const defaultSummaryPrompt: SummaryPrompt = new PromptTemplate({ + templateVars: ["context"], + template: `Write a summary of the following. Try to use only the information provided. Try to include as many key details as possible. + + +{context} + + +SUMMARY:""" +`, +}); + +export const anthropicSummaryPrompt: SummaryPrompt = new PromptTemplate({ + templateVars: ["context"], + template: `Summarize the following text. Try to use only the information provided. Try to include as many key details as possible. +<original-text> +{context} +</original-text> + +SUMMARY: +`, +}); + +export const defaultRefinePrompt: RefinePrompt = new PromptTemplate({ + templateVars: ["query", "existingAnswer", "context"], + template: `The original query is as follows: {query} +We have provided an existing answer: {existingAnswer} +We have the opportunity to refine the existing answer (only if needed) with some more context below. +------------ +{context} +------------ +Given the new context, refine the original answer to better answer the query. If the context isn't useful, return the original answer. +Refined Answer:`, +}); + +export const defaultTreeSummarizePrompt = new PromptTemplate({ + templateVars: ["context", "query"], + template: `Context information from multiple sources is below. +--------------------- +{context} +--------------------- +Given the information from multiple sources and not prior knowledge, answer the query. +Query: {query} +Answer:`, +}); + +export const defaultChoiceSelectPrompt = new PromptTemplate({ + templateVars: ["context", "query"], + template: `A list of documents is shown below. Each document has a number next to it along +with a summary of the document. A question is also provided. +Respond with the numbers of the documents +you should consult to answer the question, in order of relevance, as well +as the relevance score. The relevance score is a number from 1-10 based on +how relevant you think the document is to the question. +Do not include any documents that are not relevant to the question. 
+Example format:
+Document 1:
+<summary of document 1>
+
+Document 2:
+<summary of document 2>
+
+...
+
+Document 10:\n<summary of document 10>
+
+Question: <question>
+Answer:
+Doc: 9, Relevance: 7
+Doc: 3, Relevance: 4
+Doc: 7, Relevance: 3
+
+Let's try this now:
+
+{context}
+Question: {query}
+Answer:`,
+});
+
+export function buildToolsText(tools: ToolMetadata[]) {
+  const toolsObj = tools.reduce<Record<string, string>>((acc, tool) => {
+    acc[tool.name] = tool.description;
+    return acc;
+  }, {});
+
+  return JSON.stringify(toolsObj, null, 4);
+}
+
+const exampleTools: ToolMetadata[] = [
+  {
+    name: "uber_10k",
+    description: "Provides information about Uber financials for year 2021",
+  },
+  {
+    name: "lyft_10k",
+    description: "Provides information about Lyft financials for year 2021",
+  },
+];
+
+const exampleQueryStr = `Compare and contrast the revenue growth and EBITDA of Uber and Lyft for year 2021`;
+
+const exampleOutput = [
+  {
+    subQuestion: "What is the revenue growth of Uber",
+    toolName: "uber_10k",
+  },
+  {
+    subQuestion: "What is the EBITDA of Uber",
+    toolName: "uber_10k",
+  },
+  {
+    subQuestion: "What is the revenue growth of Lyft",
+    toolName: "lyft_10k",
+  },
+  {
+    subQuestion: "What is the EBITDA of Lyft",
+    toolName: "lyft_10k",
+  },
+] as const;
+
+export const defaultSubQuestionPrompt: SubQuestionPrompt = new PromptTemplate({
+  templateVars: ["toolsStr", "queryStr"],
+  template: `Given a user question, and a list of tools, output a list of relevant sub-questions that when composed can help answer the full user question:
+
+# Example 1
+<Tools>
+\`\`\`json
+${buildToolsText(exampleTools)}
+\`\`\`
+
+<User Question>
+${exampleQueryStr}
+
+<Output>
+\`\`\`json
+${JSON.stringify(exampleOutput, null, 4)}
+\`\`\`
+
+# Example 2
+<Tools>
+\`\`\`json
+{toolsStr}
+\`\`\`
+
+<User Question>
+{queryStr}
+
+<Output>
+`,
+});
+
+export const defaultCondenseQuestionPrompt = new PromptTemplate({
+  templateVars: ["chatHistory", "question"],
+  template: `Given a conversation (between Human and Assistant) and a follow up message from Human, rewrite the message to be a standalone question that captures all relevant context from the conversation.
+
+<Chat History>
+{chatHistory}
+
+<Follow Up Message>
+{question}
+
+<Standalone question>
+`,
+});
+
+export function messagesToHistoryStr(messages: ChatMessage[]) {
+  return messages.reduce((acc, message) => {
+    acc += acc ? "\n" : "";
+    if (message.role === "user") {
+      acc += `Human: ${message.content}`;
+    } else {
+      acc += `Assistant: ${message.content}`;
+    }
+    return acc;
+  }, "");
+}
+
+export const defaultContextSystemPrompt: ContextSystemPrompt =
+  new PromptTemplate({
+    templateVars: ["context"],
+    template: `Context information is below.
+---------------------
+{context}
+---------------------`,
+  });
+
+export const defaultKeywordExtractPrompt: KeywordExtractPrompt =
+  new PromptTemplate({
+    templateVars: ["maxKeywords", "context"],
+    template: `
+Some text is provided below. Given the text, extract up to {maxKeywords} keywords from the text. Avoid stopwords.
+---------------------
+{context}
+---------------------
+Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'
+`,
+  }).partialFormat({
+    maxKeywords: "10",
+  });
+
+export const defaultQueryKeywordExtractPrompt = new PromptTemplate({
+  templateVars: ["maxKeywords", "question"],
+  template: `A question is provided below. Given the question, extract up to {maxKeywords} keywords from the text. Focus on extracting the keywords that we can use to best lookup answers to the question. Avoid stopwords.
+---------------------
+{question}
+---------------------
+Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'
+`,
+}).partialFormat({
+  maxKeywords: "10",
+});
diff --git a/packages/core/src/schema/index.ts b/packages/core/src/schema/index.ts
index a38ba9975..76d8e824d 100644
--- a/packages/core/src/schema/index.ts
+++ b/packages/core/src/schema/index.ts
@@ -1,4 +1,5 @@
 export * from "./node";
 export { FileReader, TransformComponent, type BaseReader } from "./type";
+export type { BaseOutputParser } from "./type/base-output-parser";
 export { EngineResponse } from "./type/engine-response";
 export * from "./zod";
diff --git a/packages/core/src/schema/type/base-output-parser.ts b/packages/core/src/schema/type/base-output-parser.ts
new file mode 100644
index 000000000..dbbc5b2db
--- /dev/null
+++ b/packages/core/src/schema/type/base-output-parser.ts
@@ -0,0 +1,8 @@
+/**
+ * An OutputParser is used to extract structured data from the raw output of the LLM.
+ */
+export interface BaseOutputParser<T = any> {
+  parse(output: string): T;
+
+  format(output: string): string;
+}
diff --git a/packages/core/src/utils/index.ts b/packages/core/src/utils/index.ts
index afa21bcd5..6c397467e 100644
--- a/packages/core/src/utils/index.ts
+++ b/packages/core/src/utils/index.ts
@@ -51,4 +51,8 @@ export {
   extractImage,
   extractSingleText,
   extractText,
+  messagesToHistory,
+  toToolDescriptions,
 } from "./llms";
+
+export { objectEntries } from "./object-entries";
diff --git a/packages/core/src/utils/llms.ts b/packages/core/src/utils/llms.ts
index d46bce7bd..664e1ff3a 100644
--- a/packages/core/src/utils/llms.ts
+++ b/packages/core/src/utils/llms.ts
@@ -1,7 +1,9 @@
 import type {
+  ChatMessage,
   MessageContent,
   MessageContentDetail,
   MessageContentTextDetail,
+  ToolMetadata,
 } from "../llms";
 import type { QueryType } from "../query-engine";
 import type { ImageType } from "../schema";
@@ -84,3 +86,24 @@ export const extractDataUrlComponents = (
     base64,
   };
 };
+
+export function messagesToHistory(messages: ChatMessage[]): string {
+  return messages.reduce((acc, message) => {
+    acc += acc ?
"\n" : ""; + if (message.role === "user") { + acc += `Human: ${message.content}`; + } else { + acc += `Assistant: ${message.content}`; + } + return acc; + }, ""); +} + +export function toToolDescriptions(tools: ToolMetadata[]): string { + const toolsObj = tools.reduce<Record<string, string>>((acc, tool) => { + acc[tool.name] = tool.description; + return acc; + }, {}); + + return JSON.stringify(toolsObj, null, 4); +} diff --git a/packages/core/src/utils/object-entries.ts b/packages/core/src/utils/object-entries.ts new file mode 100644 index 000000000..18fe9796c --- /dev/null +++ b/packages/core/src/utils/object-entries.ts @@ -0,0 +1,14 @@ +type ObjectEntries<T extends Record<string, any>> = { + [K in keyof T]: [K, T[K]]; +}[keyof T][]; + +/** + * Type safe version of `Object.entries` + */ +export function objectEntries<T extends Record<string, any>>( + obj: T, +): ObjectEntries<{ + [K in keyof T]-?: NonNullable<T[K]>; +}> { + return Object.entries(obj); +} diff --git a/packages/core/tests/prompts.test.ts b/packages/core/tests/prompts.test.ts new file mode 100644 index 000000000..9a599851c --- /dev/null +++ b/packages/core/tests/prompts.test.ts @@ -0,0 +1,161 @@ +import { PromptTemplate, type StringTemplate } from "@llamaindex/core/prompts"; +import type { BaseOutputParser } from "@llamaindex/core/schema"; +import { describe, expect, expectTypeOf, test } from "vitest"; + +describe("type system", () => { + test("StringTemplate", () => { + { + type Test = StringTemplate<["var1", "var2"]>; + expectTypeOf<"{var1}{var2}">().toMatchTypeOf<Test>(); + expectTypeOf<"{var1}">().not.toMatchTypeOf<Test>(); + expectTypeOf<"{var1} var2">().not.toMatchTypeOf<Test>(); + expectTypeOf<"{var2}{var1}">().toMatchTypeOf<Test>(); + } + { + const arr = ["var1", "var2"] as const; + type Test = StringTemplate<typeof arr>; + expectTypeOf<"{var1}{var2}">().toMatchTypeOf<Test>(); + expectTypeOf<"{var1}">().not.toMatchTypeOf<Test>(); + expectTypeOf<"{var1} var2">().not.toMatchTypeOf<Test>(); + expectTypeOf<"{var2}{var1}">().toMatchTypeOf<Test>(); + } + { + const template = + `Act as a natural language processing software. Analyze the given text and return me only a parsable and minified JSON object. + + +Here's the JSON Object structure: +{ + "key1": /* Some instructions */, + "key2": /* Some instructions */, +} + +Here are the rules you must follow: +- You MUST return a valid, parsable JSON object. 
+- More rules… + +Here are some examples to help you out: +- Example 1… +- Example 2… + +Text: {selection} + +JSON Data:` as const; + type Test = StringTemplate<["selection"]>; + expectTypeOf(template).toMatchTypeOf<Test>(); + } + { + // matrix + type Test = StringTemplate<["a", "b", "c"]>; + expectTypeOf<"{a}{b}{c}">().toMatchTypeOf<Test>(); + expectTypeOf<"{a}{c}{b}">().toMatchTypeOf<Test>(); + expectTypeOf<"{b}{a}{c}">().toMatchTypeOf<Test>(); + expectTypeOf<"{b}{c}{a}">().toMatchTypeOf<Test>(); + expectTypeOf<"{c}{a}{b}">().toMatchTypeOf<Test>(); + expectTypeOf<"{c}{b}{a}">().toMatchTypeOf<Test>(); + } + }); + + test("PromptTemplate", () => { + { + new PromptTemplate({ + // @ts-expect-error + template: "", + templateVars: ["var1"], + }); + } + { + new PromptTemplate({ + template: "something{var1}", + templateVars: ["var1"], + }); + } + { + new PromptTemplate({ + // @ts-expect-error + template: "{var1 }", + templateVars: ["var1"], + }); + } + { + // in this case, type won't be inferred + const template = "{var2}"; + const templateVars = ["var1"]; + new PromptTemplate({ + template, + templateVars, + }); + } + { + const template = "{var2}" as const; + const templateVars = ["var1"] as const; + new PromptTemplate({ + // @ts-expect-error + template, + templateVars, + }); + } + { + const prompt = new PromptTemplate({ + template: "hello {text} {foo}", + templateVars: ["text", "foo"], + }); + + prompt.partialFormat({ + foo: "bar", + goo: "baz", + }); + } + }); +}); + +describe("PromptTemplate", () => { + test("basic usage", () => { + { + const template = "hello {text} {foo}"; + const prompt = new PromptTemplate({ + template, + }); + const partialPrompt = prompt.partialFormat({ + foo: "bar", + }); + expect(partialPrompt).instanceof(PromptTemplate); + expect( + partialPrompt.format({ + text: "world", + }), + ).toBe("hello world bar"); + } + }); + test("should partially format and fully format a prompt", () => { + const prompt = new PromptTemplate({ + template: "hello {text} {foo}", + templateVars: ["text", "foo"], + }); + + const partialPrompt = prompt.partialFormat({ foo: "bar" }); + expect(partialPrompt).toBeInstanceOf(PromptTemplate); + expect(partialPrompt.format({ text: "world" })).toBe("hello world bar"); + }); + + test("should use output parser in formatting", () => { + const outputParser: BaseOutputParser = { + parse(output: string) { + return { output: output }; + }, + + format(query: string) { + return `${query}\noutput_instruction`; + }, + }; + + const prompt = new PromptTemplate({ + template: "hello {text} {foo}", + templateVars: ["text", "foo"], + outputParser: outputParser, + }); + + const formatted = prompt.format({ text: "world", foo: "bar" }); + expect(formatted).toBe("hello world bar\noutput_instruction"); + }); +}); diff --git a/packages/llamaindex/tests/prompts/Mixin.test.ts b/packages/core/tests/prompts/mixin.test.ts similarity index 61% rename from packages/llamaindex/tests/prompts/Mixin.test.ts rename to packages/core/tests/prompts/mixin.test.ts index 4af197633..5ace7bed0 100644 --- a/packages/llamaindex/tests/prompts/Mixin.test.ts +++ b/packages/core/tests/prompts/mixin.test.ts @@ -1,45 +1,52 @@ -import { PromptMixin } from "llamaindex/prompts/index"; +import { + type ModuleRecord, + PromptMixin, + PromptTemplate, +} from "@llamaindex/core/prompts"; import { describe, expect, it } from "vitest"; -type MockPrompt = { - context: string; - query: string; -}; - -const mockPrompt = ({ context, query }: MockPrompt) => - `context: ${context} query: ${query}`; +const mockPrompt = 
new PromptTemplate({ + templateVars: ["context", "query"], + template: `context: {context} query: {query}`, +}); -const mockPrompt2 = ({ context, query }: MockPrompt) => - `query: ${query} context: ${context}`; +const mockPrompt2 = new PromptTemplate({ + templateVars: ["context", "query"], + template: `query: {query} context: {context}`, +}); -type MockPromptFunction = typeof mockPrompt; +type MockPrompt = typeof mockPrompt; class MockObject2 extends PromptMixin { - _prompt_dict_2: MockPromptFunction; + _prompt_dict_2: MockPrompt; constructor() { super(); this._prompt_dict_2 = mockPrompt; } - protected _getPrompts(): { [x: string]: MockPromptFunction } { + protected _getPrompts() { return { abc: this._prompt_dict_2, }; } - _updatePrompts(promptsDict: { [x: string]: MockPromptFunction }): void { - if ("abc" in promptsDict) { - this._prompt_dict_2 = promptsDict["abc"]; + protected _updatePrompts(prompts: { abc: MockPrompt }): void { + if ("abc" in prompts) { + this._prompt_dict_2 = prompts["abc"]; } } + + protected _getPromptModules(): ModuleRecord { + return {}; + } } class MockObject1 extends PromptMixin { mockObject2: MockObject2; - fooPrompt: MockPromptFunction; - barPrompt: MockPromptFunction; + fooPrompt: MockPrompt; + barPrompt: MockPrompt; constructor() { super(); @@ -76,14 +83,14 @@ describe("PromptMixin", () => { const prompts = mockObj1.getPrompts(); expect( - mockObj1.fooPrompt({ + mockObj1.fooPrompt.format({ context: "{foo}", query: "{foo}", }), ).toEqual("context: {foo} query: {foo}"); expect( - mockObj1.barPrompt({ + mockObj1.barPrompt.format({ context: "{foo} {bar}", query: "{foo} {bar}", }), @@ -103,18 +110,22 @@ describe("PromptMixin", () => { const mockObj1 = new MockObject1(); expect( - mockObj1.barPrompt({ + mockObj1.barPrompt.format({ context: "{foo} {bar}", query: "{foo} {bar}", }), - ).toEqual(mockPrompt({ context: "{foo} {bar}", query: "{foo} {bar}" })); + ).toEqual( + mockPrompt.format({ context: "{foo} {bar}", query: "{foo} {bar}" }), + ); expect( - mockObj1.mockObject2._prompt_dict_2({ + mockObj1.mockObject2._prompt_dict_2.format({ context: "{bar} testing", query: "{bar} testing", }), - ).toEqual(mockPrompt({ context: "{bar} testing", query: "{bar} testing" })); + ).toEqual( + mockPrompt.format({ context: "{bar} testing", query: "{bar} testing" }), + ); mockObj1.updatePrompts({ bar: mockPrompt2, @@ -122,18 +133,20 @@ describe("PromptMixin", () => { }); expect( - mockObj1.barPrompt({ + mockObj1.barPrompt.format({ context: "{foo} {bar}", query: "{bar} {foo}", }), - ).toEqual(mockPrompt2({ context: "{foo} {bar}", query: "{bar} {foo}" })); + ).toEqual( + mockPrompt2.format({ context: "{foo} {bar}", query: "{bar} {foo}" }), + ); expect( - mockObj1.mockObject2._prompt_dict_2({ + mockObj1.mockObject2._prompt_dict_2.format({ context: "{bar} testing", query: "{bar} testing", }), ).toEqual( - mockPrompt2({ context: "{bar} testing", query: "{bar} testing" }), + mockPrompt2.format({ context: "{bar} testing", query: "{bar} testing" }), ); }); }); diff --git a/packages/llamaindex/src/ChatHistory.ts b/packages/llamaindex/src/ChatHistory.ts index 7fa9868f9..16a9f2756 100644 --- a/packages/llamaindex/src/ChatHistory.ts +++ b/packages/llamaindex/src/ChatHistory.ts @@ -1,8 +1,10 @@ import type { ChatMessage, LLM, MessageType } from "@llamaindex/core/llms"; -import { extractText } from "@llamaindex/core/utils"; +import { + defaultSummaryPrompt, + type SummaryPrompt, +} from "@llamaindex/core/prompts"; +import { extractText, messagesToHistory } from "@llamaindex/core/utils"; import { 
tokenizers, type Tokenizer } from "@llamaindex/env"; -import type { SummaryPrompt } from "./Prompt.js"; -import { defaultSummaryPrompt, messagesToHistoryStr } from "./Prompt.js"; import { OpenAI } from "./llm/openai.js"; /** @@ -106,8 +108,8 @@ export class SummaryChatHistory extends ChatHistory { do { promptMessages = [ { - content: this.summaryPrompt({ - context: messagesToHistoryStr(messagesToSummarize), + content: this.summaryPrompt.format({ + context: messagesToHistory(messagesToSummarize), }), role: "user" as MessageType, options: {}, diff --git a/packages/llamaindex/src/OutputParser.ts b/packages/llamaindex/src/OutputParser.ts index 488f33576..da4a2e145 100644 --- a/packages/llamaindex/src/OutputParser.ts +++ b/packages/llamaindex/src/OutputParser.ts @@ -1,5 +1,6 @@ +import type { BaseOutputParser } from "@llamaindex/core/schema"; import type { SubQuestion } from "./engines/query/types.js"; -import type { BaseOutputParser, StructuredOutput } from "./types.js"; +import type { StructuredOutput } from "./types.js"; /** * Error class for output parsing. Due to the nature of LLMs, anytime we use LLM diff --git a/packages/llamaindex/src/Prompt.ts b/packages/llamaindex/src/Prompt.ts deleted file mode 100644 index 0763df484..000000000 --- a/packages/llamaindex/src/Prompt.ts +++ /dev/null @@ -1,411 +0,0 @@ -import type { ChatMessage, ToolMetadata } from "@llamaindex/core/llms"; -import type { SubQuestion } from "./engines/query/types.js"; - -/** - * A SimplePrompt is a function that takes a dictionary of inputs and returns a string. - * NOTE this is a different interface compared to LlamaIndex Python - * NOTE 2: we default to empty string to make it easy to calculate prompt sizes - */ -export type SimplePrompt = ( - input: Record<string, string | undefined>, -) => string; - -/* -DEFAULT_TEXT_QA_PROMPT_TMPL = ( - "Context information is below.\n" - "---------------------\n" - "{context_str}\n" - "---------------------\n" - "Given the context information and not prior knowledge, " - "answer the query.\n" - "Query: {query_str}\n" - "Answer: " -) -*/ - -export const defaultTextQaPrompt = ({ context = "", query = "" }) => { - return `Context information is below. ---------------------- -${context} ---------------------- -Given the context information and not prior knowledge, answer the query. -Query: ${query} -Answer:`; -}; - -export type TextQaPrompt = typeof defaultTextQaPrompt; - -export const anthropicTextQaPrompt: TextQaPrompt = ({ - context = "", - query = "", -}) => { - return `Context information: -<context> -${context} -</context> -Given the context information and not prior knowledge, answer the query. -Query: ${query}`; -}; - -/* -DEFAULT_SUMMARY_PROMPT_TMPL = ( - "Write a summary of the following. Try to use only the " - "information provided. " - "Try to include as many key details as possible.\n" - "\n" - "\n" - "{context_str}\n" - "\n" - "\n" - 'SUMMARY:"""\n' -) -*/ - -export const defaultSummaryPrompt = ({ context = "" }) => { - return `Write a summary of the following. Try to use only the information provided. Try to include as many key details as possible. - - -${context} - - -SUMMARY:""" -`; -}; - -export type SummaryPrompt = typeof defaultSummaryPrompt; - -export const anthropicSummaryPrompt: SummaryPrompt = ({ context = "" }) => { - return `Summarize the following text. Try to use only the information provided. Try to include as many key details as possible. 
-<original-text> -${context} -</original-text> - -SUMMARY: -`; -}; - -/* -DEFAULT_REFINE_PROMPT_TMPL = ( - "The original query is as follows: {query_str}\n" - "We have provided an existing answer: {existing_answer}\n" - "We have the opportunity to refine the existing answer " - "(only if needed) with some more context below.\n" - "------------\n" - "{context_msg}\n" - "------------\n" - "Given the new context, refine the original answer to better " - "answer the query. " - "If the context isn't useful, return the original answer.\n" - "Refined Answer: " -) -*/ - -export const defaultRefinePrompt = ({ - query = "", - existingAnswer = "", - context = "", -}) => { - return `The original query is as follows: ${query} -We have provided an existing answer: ${existingAnswer} -We have the opportunity to refine the existing answer (only if needed) with some more context below. ------------- -${context} ------------- -Given the new context, refine the original answer to better answer the query. If the context isn't useful, return the original answer. -Refined Answer:`; -}; - -export type RefinePrompt = typeof defaultRefinePrompt; - -/* -DEFAULT_TREE_SUMMARIZE_TMPL = ( - "Context information from multiple sources is below.\n" - "---------------------\n" - "{context_str}\n" - "---------------------\n" - "Given the information from multiple sources and not prior knowledge, " - "answer the query.\n" - "Query: {query_str}\n" - "Answer: " -) -*/ - -export const defaultTreeSummarizePrompt = ({ context = "", query = "" }) => { - return `Context information from multiple sources is below. ---------------------- -${context} ---------------------- -Given the information from multiple sources and not prior knowledge, answer the query. -Query: ${query} -Answer:`; -}; - -export type TreeSummarizePrompt = typeof defaultTreeSummarizePrompt; - -export const defaultChoiceSelectPrompt = ({ context = "", query = "" }) => { - return `A list of documents is shown below. Each document has a number next to it along -with a summary of the document. A question is also provided. -Respond with the numbers of the documents -you should consult to answer the question, in order of relevance, as well -as the relevance score. The relevance score is a number from 1-10 based on -how relevant you think the document is to the question. -Do not include any documents that are not relevant to the question. -Example format: -Document 1: -<summary of document 1> - -Document 2: -<summary of document 2> - -... 
- -Document 10:\n<summary of document 10> - -Question: <question> -Answer: -Doc: 9, Relevance: 7 -Doc: 3, Relevance: 4 -Doc: 7, Relevance: 3 - -Let's try this now: - -${context} -Question: ${query} -Answer:`; -}; - -export type ChoiceSelectPrompt = typeof defaultChoiceSelectPrompt; - -/* -PREFIX = """\ -Given a user question, and a list of tools, output a list of relevant sub-questions \ -that when composed can help answer the full user question: - -""" - - -example_query_str = ( - "Compare and contrast the revenue growth and EBITDA of Uber and Lyft for year 2021" -) -example_tools = [ - ToolMetadata( - name="uber_10k", - description="Provides information about Uber financials for year 2021", - ), - ToolMetadata( - name="lyft_10k", - description="Provides information about Lyft financials for year 2021", - ), -] -example_tools_str = build_tools_text(example_tools) -example_output = [ - SubQuestion( - sub_question="What is the revenue growth of Uber", tool_name="uber_10k" - ), - SubQuestion(sub_question="What is the EBITDA of Uber", tool_name="uber_10k"), - SubQuestion( - sub_question="What is the revenue growth of Lyft", tool_name="lyft_10k" - ), - SubQuestion(sub_question="What is the EBITDA of Lyft", tool_name="lyft_10k"), -] -example_output_str = json.dumps([x.dict() for x in example_output], indent=4) - -EXAMPLES = ( - """\ -# Example 1 -<Tools> -```json -{tools_str} -``` - -<User Question> -{query_str} - - -<Output> -```json -{output_str} -``` - -""".format( - query_str=example_query_str, - tools_str=example_tools_str, - output_str=example_output_str, - ) - .replace("{", "{{") - .replace("}", "}}") -) - -SUFFIX = """\ -# Example 2 -<Tools> -```json -{tools_str} -``` - -<User Question> -{query_str} - -<Output> -""" - -DEFAULT_SUB_QUESTION_PROMPT_TMPL = PREFIX + EXAMPLES + SUFFIX -*/ - -export function buildToolsText(tools: ToolMetadata[]) { - const toolsObj = tools.reduce<Record<string, string>>((acc, tool) => { - acc[tool.name] = tool.description; - return acc; - }, {}); - - return JSON.stringify(toolsObj, null, 4); -} - -const exampleTools: ToolMetadata[] = [ - { - name: "uber_10k", - description: "Provides information about Uber financials for year 2021", - }, - { - name: "lyft_10k", - description: "Provides information about Lyft financials for year 2021", - }, -]; - -const exampleQueryStr = `Compare and contrast the revenue growth and EBITDA of Uber and Lyft for year 2021`; - -const exampleOutput: SubQuestion[] = [ - { - subQuestion: "What is the revenue growth of Uber", - toolName: "uber_10k", - }, - { - subQuestion: "What is the EBITDA of Uber", - toolName: "uber_10k", - }, - { - subQuestion: "What is the revenue growth of Lyft", - toolName: "lyft_10k", - }, - { - subQuestion: "What is the EBITDA of Lyft", - toolName: "lyft_10k", - }, -]; - -export const defaultSubQuestionPrompt = ({ toolsStr = "", queryStr = "" }) => { - return `Given a user question, and a list of tools, output a list of relevant sub-questions that when composed can help answer the full user question: - -# Example 1 -<Tools> -\`\`\`json -${buildToolsText(exampleTools)} -\`\`\` - -<User Question> -${exampleQueryStr} - -<Output> -\`\`\`json -${JSON.stringify(exampleOutput, null, 4)} -\`\`\` - -# Example 2 -<Tools> -\`\`\`json -${toolsStr} -\`\`\` - -<User Question> -${queryStr} - -<Output> -`; -}; - -export type SubQuestionPrompt = typeof defaultSubQuestionPrompt; - -// DEFAULT_TEMPLATE = """\ -// Given a conversation (between Human and Assistant) and a follow up message from Human, \ -// rewrite the message to 
be a standalone question that captures all relevant context \ -// from the conversation. - -// <Chat History> -// {chat_history} - -// <Follow Up Message> -// {question} - -// <Standalone question> -// """ - -export const defaultCondenseQuestionPrompt = ({ - chatHistory = "", - question = "", -}) => { - return `Given a conversation (between Human and Assistant) and a follow up message from Human, rewrite the message to be a standalone question that captures all relevant context from the conversation. - -<Chat History> -${chatHistory} - -<Follow Up Message> -${question} - -<Standalone question> -`; -}; - -export type CondenseQuestionPrompt = typeof defaultCondenseQuestionPrompt; - -export function messagesToHistoryStr(messages: ChatMessage[]) { - return messages.reduce((acc, message) => { - acc += acc ? "\n" : ""; - if (message.role === "user") { - acc += `Human: ${message.content}`; - } else { - acc += `Assistant: ${message.content}`; - } - return acc; - }, ""); -} - -export const defaultContextSystemPrompt = ({ context = "" }) => { - if (!context) return ""; - return `Context information is below. ---------------------- -${context} ----------------------`; -}; - -export type ContextSystemPrompt = typeof defaultContextSystemPrompt; - -export const defaultKeywordExtractPrompt = ({ - context = "", - maxKeywords = 10, -}) => { - return ` -Some text is provided below. Given the text, extract up to ${maxKeywords} keywords from the text. Avoid stopwords. ---------------------- -${context} ---------------------- -Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>' -`; -}; - -export type KeywordExtractPrompt = typeof defaultKeywordExtractPrompt; - -export const defaultQueryKeywordExtractPrompt = ({ - question = "", - maxKeywords = 10, -}) => { - return `( - "A question is provided below. Given the question, extract up to ${maxKeywords} " - "keywords from the text. Focus on extracting the keywords that we can use " - "to best lookup answers to the question. Avoid stopwords." - "---------------------" - "${question}" - "---------------------" - "Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'" -)`; -}; -export type QueryKeywordExtractPrompt = typeof defaultQueryKeywordExtractPrompt; diff --git a/packages/llamaindex/src/PromptHelper.ts b/packages/llamaindex/src/PromptHelper.ts index 411c4487f..40ff503d5 100644 --- a/packages/llamaindex/src/PromptHelper.ts +++ b/packages/llamaindex/src/PromptHelper.ts @@ -1,6 +1,6 @@ import { SentenceSplitter } from "@llamaindex/core/node-parser"; +import type { PromptTemplate } from "@llamaindex/core/prompts"; import { type Tokenizer, tokenizers } from "@llamaindex/env"; -import type { SimplePrompt } from "./Prompt.js"; import { DEFAULT_CHUNK_OVERLAP_RATIO, DEFAULT_CONTEXT_WINDOW, @@ -8,17 +8,22 @@ import { DEFAULT_PADDING, } from "./constants.js"; -export function getEmptyPromptTxt(prompt: SimplePrompt) { - return prompt({}); +/** + * Get the empty prompt text given a prompt. + */ +export function getEmptyPromptTxt(prompt: PromptTemplate) { + return prompt.format({ + ...Object.fromEntries( + [...prompt.templateVars.keys()].map((key) => [key, ""]), + ), + }); } /** * Get biggest empty prompt size from a list of prompts. * Used to calculate the maximum size of inputs to the LLM. 
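A hedged aside on the refactored helper above: `getEmptyPromptTxt` now formats a `PromptTemplate` with every template variable set to "", so only the template's fixed text is measured. A minimal sketch, assuming the `@llamaindex/core/prompts` export added earlier in this patch:

import { PromptTemplate } from "@llamaindex/core/prompts";

const qa = new PromptTemplate({
  templateVars: ["context", "query"],
  template: "Context:\n{context}\nQuery: {query}\nAnswer:",
});

// Formatting with { context: "", query: "" } leaves only the fixed text:
// "Context:\n\nQuery: \nAnswer:". getBiggestPrompt() then keeps the prompt
// whose empty rendering is longest, i.e. the one with the most fixed overhead.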
- * @param prompts - * @returns */ -export function getBiggestPrompt(prompts: SimplePrompt[]) { +export function getBiggestPrompt(prompts: PromptTemplate[]) { const emptyPromptTexts = prompts.map(getEmptyPromptTxt); const emptyPromptLengths = emptyPromptTexts.map((text) => text.length); const maxEmptyPromptLength = Math.max(...emptyPromptLengths); @@ -59,7 +64,7 @@ export class PromptHelper { * @param prompt * @returns */ - private getAvailableContextSize(prompt: SimplePrompt) { + private getAvailableContextSize(prompt: PromptTemplate) { const emptyPromptText = getEmptyPromptTxt(prompt); const promptTokens = this.tokenizer.encode(emptyPromptText); const numPromptTokens = promptTokens.length; @@ -69,13 +74,9 @@ export class PromptHelper { /** * Find the maximum size of each chunk given a prompt. - * @param prompt - * @param numChunks - * @param padding - * @returns */ private getAvailableChunkSize( - prompt: SimplePrompt, + prompt: PromptTemplate, numChunks = 1, padding = 5, ) { @@ -92,13 +93,9 @@ export class PromptHelper { /** * Creates a text splitter with the correct chunk sizes and overlaps given a prompt. - * @param prompt - * @param numChunks - * @param padding - * @returns */ getTextSplitterGivenPrompt( - prompt: SimplePrompt, + prompt: PromptTemplate, numChunks = 1, padding = DEFAULT_PADDING, ) { @@ -112,13 +109,9 @@ export class PromptHelper { /** * Repack resplits the strings based on the optimal text splitter. - * @param prompt - * @param textChunks - * @param padding - * @returns */ repack( - prompt: SimplePrompt, + prompt: PromptTemplate, textChunks: string[], padding = DEFAULT_PADDING, ) { diff --git a/packages/llamaindex/src/QuestionGenerator.ts b/packages/llamaindex/src/QuestionGenerator.ts index e943a3861..47ea6001e 100644 --- a/packages/llamaindex/src/QuestionGenerator.ts +++ b/packages/llamaindex/src/QuestionGenerator.ts @@ -1,16 +1,20 @@ import type { LLM, ToolMetadata } from "@llamaindex/core/llms"; +import { + defaultSubQuestionPrompt, + type ModuleRecord, + PromptMixin, + type SubQuestionPrompt, +} from "@llamaindex/core/prompts"; import type { QueryType } from "@llamaindex/core/query-engine"; -import { extractText } from "@llamaindex/core/utils"; +import type { BaseOutputParser } from "@llamaindex/core/schema"; +import { extractText, toToolDescriptions } from "@llamaindex/core/utils"; import { SubQuestionOutputParser } from "./OutputParser.js"; -import type { SubQuestionPrompt } from "./Prompt.js"; -import { buildToolsText, defaultSubQuestionPrompt } from "./Prompt.js"; import type { BaseQuestionGenerator, SubQuestion, } from "./engines/query/types.js"; import { OpenAI } from "./llm/openai.js"; -import { PromptMixin } from "./prompts/index.js"; -import type { BaseOutputParser, StructuredOutput } from "./types.js"; +import type { StructuredOutput } from "./types.js"; /** * LLMQuestionGenerator uses the LLM to generate new questions for the LLM using tools and a user query. 
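For orientation, a usage sketch of the generator described above; this is hedged: it assumes `LLMQuestionGenerator` is re-exported from the `llamaindex` entry point and that its no-argument constructor (not shown in this hunk) defaults to the imported OpenAI LLM. The `ToolMetadata` shape matches the `{ name, description }` objects used elsewhere in this patch.

import type { ToolMetadata } from "@llamaindex/core/llms";
import { LLMQuestionGenerator } from "llamaindex";

async function main() {
  const generator = new LLMQuestionGenerator();
  const tools: ToolMetadata[] = [
    {
      name: "uber_10k",
      description: "Provides information about Uber financials for year 2021",
    },
  ];
  // generate() renders defaultSubQuestionPrompt with the tool descriptions
  // and the query text, then parses the completion into SubQuestion objects.
  const subQuestions = await generator.generate(
    tools,
    "What was Uber's revenue growth in 2021?",
  );
  console.log(subQuestions); // e.g. [{ subQuestion: "...", toolName: "uber_10k" }]
}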
@@ -49,11 +53,11 @@ export class LLMQuestionGenerator
     tools: ToolMetadata[],
     query: QueryType,
   ): Promise<SubQuestion[]> {
-    const toolsStr = buildToolsText(tools);
+    const toolsStr = toToolDescriptions(tools);
     const queryStr = extractText(query);
     const prediction = (
       await this.llm.complete({
-        prompt: this.prompt({
+        prompt: this.prompt.format({
           toolsStr,
           queryStr,
         }),
@@ -64,4 +68,8 @@
 
     return structuredOutput.parsedOutput;
   }
+
+  protected _getPromptModules(): ModuleRecord {
+    return {};
+  }
 }
diff --git a/packages/llamaindex/src/engines/chat/CondenseQuestionChatEngine.ts b/packages/llamaindex/src/engines/chat/CondenseQuestionChatEngine.ts
index 009a50a06..f097a1985 100644
--- a/packages/llamaindex/src/engines/chat/CondenseQuestionChatEngine.ts
+++ b/packages/llamaindex/src/engines/chat/CondenseQuestionChatEngine.ts
@@ -1,26 +1,28 @@
 import type { ChatMessage, LLM } from "@llamaindex/core/llms";
+import {
+  type CondenseQuestionPrompt,
+  defaultCondenseQuestionPrompt,
+  type ModuleRecord,
+  PromptMixin,
+} from "@llamaindex/core/prompts";
 import type { EngineResponse } from "@llamaindex/core/schema";
 import {
   extractText,
+  messagesToHistory,
   streamReducer,
   wrapEventCaller,
 } from "@llamaindex/core/utils";
 import type { ChatHistory } from "../../ChatHistory.js";
 import { getHistory } from "../../ChatHistory.js";
-import type { CondenseQuestionPrompt } from "../../Prompt.js";
-import {
-  defaultCondenseQuestionPrompt,
-  messagesToHistoryStr,
-} from "../../Prompt.js";
 import type { ServiceContext } from "../../ServiceContext.js";
 import { llmFromSettingsOrContext } from "../../Settings.js";
-import { PromptMixin } from "../../prompts/index.js";
 import type { QueryEngine } from "../../types.js";
 import type {
   ChatEngine,
   ChatEngineParamsNonStreaming,
   ChatEngineParamsStreaming,
 } from "./types.js";
+
 /**
  * CondenseQuestionChatEngine is used in conjunction with an Index (for example VectorStoreIndex).
  * It does two steps when taking a user's chat message: first, it condenses the chat message
@@ -56,6 +58,10 @@ export class CondenseQuestionChatEngine
       init?.condenseMessagePrompt ??
defaultCondenseQuestionPrompt; } + protected _getPromptModules(): ModuleRecord { + return {}; + } + protected _getPrompts(): { condenseMessagePrompt: CondenseQuestionPrompt } { return { condenseMessagePrompt: this.condenseMessagePrompt, @@ -71,12 +77,12 @@ export class CondenseQuestionChatEngine } private async condenseQuestion(chatHistory: ChatHistory, question: string) { - const chatHistoryStr = messagesToHistoryStr( + const chatHistoryStr = messagesToHistory( await chatHistory.requestMessages(), ); return this.llm.complete({ - prompt: this.condenseMessagePrompt({ + prompt: this.condenseMessagePrompt.format({ question: question, chatHistory: chatHistoryStr, }), diff --git a/packages/llamaindex/src/engines/chat/ContextChatEngine.ts b/packages/llamaindex/src/engines/chat/ContextChatEngine.ts index ea1f04350..c7021cd77 100644 --- a/packages/llamaindex/src/engines/chat/ContextChatEngine.ts +++ b/packages/llamaindex/src/engines/chat/ContextChatEngine.ts @@ -4,6 +4,12 @@ import type { MessageContent, MessageType, } from "@llamaindex/core/llms"; +import { + type ContextSystemPrompt, + type ModuleRecord, + PromptMixin, + type PromptsRecord, +} from "@llamaindex/core/prompts"; import { EngineResponse, MetadataMode } from "@llamaindex/core/schema"; import { extractText, @@ -13,11 +19,9 @@ import { } from "@llamaindex/core/utils"; import type { ChatHistory } from "../../ChatHistory.js"; import { getHistory } from "../../ChatHistory.js"; -import type { ContextSystemPrompt } from "../../Prompt.js"; import type { BaseRetriever } from "../../Retriever.js"; import { Settings } from "../../Settings.js"; import type { BaseNodePostprocessor } from "../../postprocessors/index.js"; -import { PromptMixin } from "../../prompts/Mixin.js"; import { DefaultContextGenerator } from "./DefaultContextGenerator.js"; import type { ChatEngine, @@ -33,7 +37,7 @@ import type { export class ContextChatEngine extends PromptMixin implements ChatEngine { chatModel: LLM; chatHistory: ChatHistory; - contextGenerator: ContextGenerator; + contextGenerator: ContextGenerator & PromptMixin; systemPrompt?: string; constructor(init: { @@ -58,7 +62,19 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { this.systemPrompt = init.systemPrompt; } - protected _getPromptModules(): Record<string, ContextGenerator> { + protected _getPrompts(): PromptsRecord { + return { + ...this.contextGenerator.getPrompts(), + }; + } + + protected _updatePrompts(prompts: { + contextSystemPrompt: ContextSystemPrompt; + }): void { + this.contextGenerator.updatePrompts(prompts); + } + + protected _getPromptModules(): ModuleRecord { return { contextGenerator: this.contextGenerator, }; diff --git a/packages/llamaindex/src/engines/chat/DefaultContextGenerator.ts b/packages/llamaindex/src/engines/chat/DefaultContextGenerator.ts index 2b5edba05..533ed3b53 100644 --- a/packages/llamaindex/src/engines/chat/DefaultContextGenerator.ts +++ b/packages/llamaindex/src/engines/chat/DefaultContextGenerator.ts @@ -1,9 +1,12 @@ import type { MessageContent, MessageType } from "@llamaindex/core/llms"; +import { + type ContextSystemPrompt, + defaultContextSystemPrompt, + type ModuleRecord, + PromptMixin, +} from "@llamaindex/core/prompts"; import { MetadataMode, type NodeWithScore } from "@llamaindex/core/schema"; import type { BaseNodePostprocessor } from "../../postprocessors/index.js"; -import type { ContextSystemPrompt } from "../../Prompt.js"; -import { defaultContextSystemPrompt } from "../../Prompt.js"; -import { PromptMixin } from 
"../../prompts/index.js"; import type { BaseRetriever } from "../../Retriever.js"; import { createMessageContent } from "../../synthesizers/utils.js"; import type { Context, ContextGenerator } from "./types.js"; @@ -35,6 +38,10 @@ export class DefaultContextGenerator this.metadataMode = init.metadataMode ?? MetadataMode.NONE; } + protected _getPromptModules(): ModuleRecord { + return {}; + } + protected _getPrompts(): { contextSystemPrompt: ContextSystemPrompt } { return { contextSystemPrompt: this.contextSystemPrompt, diff --git a/packages/llamaindex/src/engines/query/RetrieverQueryEngine.ts b/packages/llamaindex/src/engines/query/RetrieverQueryEngine.ts index 659c077c9..665d31500 100644 --- a/packages/llamaindex/src/engines/query/RetrieverQueryEngine.ts +++ b/packages/llamaindex/src/engines/query/RetrieverQueryEngine.ts @@ -1,7 +1,7 @@ +import { PromptMixin } from "@llamaindex/core/prompts"; import { EngineResponse, type NodeWithScore } from "@llamaindex/core/schema"; import { wrapEventCaller } from "@llamaindex/core/utils"; import type { BaseNodePostprocessor } from "../../postprocessors/index.js"; -import { PromptMixin } from "../../prompts/Mixin.js"; import type { BaseRetriever } from "../../Retriever.js"; import type { BaseSynthesizer } from "../../synthesizers/index.js"; import { ResponseSynthesizer } from "../../synthesizers/index.js"; @@ -38,6 +38,12 @@ export class RetrieverQueryEngine extends PromptMixin implements QueryEngine { this.nodePostprocessors = nodePostprocessors || []; } + protected _getPrompts() { + return {}; + } + + protected _updatePrompts() {} + _getPromptModules() { return { responseSynthesizer: this.responseSynthesizer, diff --git a/packages/llamaindex/src/engines/query/RouterQueryEngine.ts b/packages/llamaindex/src/engines/query/RouterQueryEngine.ts index 8c709b4f3..4f3c9b154 100644 --- a/packages/llamaindex/src/engines/query/RouterQueryEngine.ts +++ b/packages/llamaindex/src/engines/query/RouterQueryEngine.ts @@ -1,9 +1,9 @@ +import { PromptMixin } from "@llamaindex/core/prompts"; import type { QueryType } from "@llamaindex/core/query-engine"; import { EngineResponse, type NodeWithScore } from "@llamaindex/core/schema"; import { extractText } from "@llamaindex/core/utils"; import type { ServiceContext } from "../../ServiceContext.js"; import { llmFromSettingsOrContext } from "../../Settings.js"; -import { PromptMixin } from "../../prompts/index.js"; import type { BaseSelector } from "../../selectors/index.js"; import { LLMSingleSelector } from "../../selectors/index.js"; import { TreeSummarize } from "../../synthesizers/index.js"; @@ -79,7 +79,13 @@ export class RouterQueryEngine extends PromptMixin implements QueryEngine { this.verbose = init.verbose ?? 
false; } - _getPromptModules(): Record<string, any> { + protected _getPrompts() { + return {}; + } + + protected _updatePrompts() {} + + protected _getPromptModules() { return { selector: this.selector, summarizer: this.summarizer, diff --git a/packages/llamaindex/src/engines/query/SubQuestionQueryEngine.ts b/packages/llamaindex/src/engines/query/SubQuestionQueryEngine.ts index 79ad860b4..a03de63c0 100644 --- a/packages/llamaindex/src/engines/query/SubQuestionQueryEngine.ts +++ b/packages/llamaindex/src/engines/query/SubQuestionQueryEngine.ts @@ -5,7 +5,6 @@ import { } from "@llamaindex/core/schema"; import { LLMQuestionGenerator } from "../../QuestionGenerator.js"; import type { ServiceContext } from "../../ServiceContext.js"; -import { PromptMixin } from "../../prompts/Mixin.js"; import type { BaseSynthesizer } from "../../synthesizers/index.js"; import { CompactAndRefine, @@ -13,6 +12,7 @@ import { } from "../../synthesizers/index.js"; import type { BaseTool, ToolMetadata } from "@llamaindex/core/llms"; +import { PromptMixin, type PromptsRecord } from "@llamaindex/core/prompts"; import type { BaseQueryEngine, QueryType } from "@llamaindex/core/query-engine"; import { wrapEventCaller } from "@llamaindex/core/utils"; import type { BaseQuestionGenerator, SubQuestion } from "./types.js"; @@ -43,6 +43,12 @@ export class SubQuestionQueryEngine this.metadatas = init.queryEngineTools.map((tool) => tool.metadata); } + protected _getPrompts(): PromptsRecord { + return {}; + } + + protected _updatePrompts() {} + protected _getPromptModules(): Record<string, any> { return { questionGen: this.questionGen, diff --git a/packages/llamaindex/src/evaluation/Correctness.ts b/packages/llamaindex/src/evaluation/Correctness.ts index 3dbf34a6e..3bc3a5146 100644 --- a/packages/llamaindex/src/evaluation/Correctness.ts +++ b/packages/llamaindex/src/evaluation/Correctness.ts @@ -1,7 +1,7 @@ import type { ChatMessage, LLM } from "@llamaindex/core/llms"; +import { PromptMixin } from "@llamaindex/core/prompts"; import { MetadataMode } from "@llamaindex/core/schema"; import { extractText } from "@llamaindex/core/utils"; -import { PromptMixin } from "../prompts/Mixin.js"; import type { ServiceContext } from "../ServiceContext.js"; import { llmFromSettingsOrContext } from "../Settings.js"; import type { CorrectnessSystemPrompt } from "./prompts.js"; @@ -41,9 +41,18 @@ export class CorrectnessEvaluator extends PromptMixin implements BaseEvaluator { this.parserFunction = params?.parserFunction ?? 
defaultEvaluationParser; } - _updatePrompts(prompts: { + protected _getPrompts() { + return { + correctnessPrompt: this.correctnessPrompt, + }; + } + protected _getPromptModules() { + return {}; + } + + protected _updatePrompts(prompts: { correctnessPrompt: CorrectnessSystemPrompt; - }): void { + }) { if ("correctnessPrompt" in prompts) { this.correctnessPrompt = prompts["correctnessPrompt"]; } @@ -69,11 +78,11 @@ export class CorrectnessEvaluator extends PromptMixin implements BaseEvaluator { const messages: ChatMessage[] = [ { role: "system", - content: this.correctnessPrompt(), + content: this.correctnessPrompt.format(), }, { role: "user", - content: defaultUserPrompt({ + content: defaultUserPrompt.format({ query: extractText(query), generatedAnswer: response, referenceAnswer: reference || "(NO REFERENCE ANSWER SUPPLIED)", diff --git a/packages/llamaindex/src/evaluation/Faithfulness.ts b/packages/llamaindex/src/evaluation/Faithfulness.ts index b1a84a551..dce2b50f3 100644 --- a/packages/llamaindex/src/evaluation/Faithfulness.ts +++ b/packages/llamaindex/src/evaluation/Faithfulness.ts @@ -1,8 +1,8 @@ +import { PromptMixin, type ModuleRecord } from "@llamaindex/core/prompts"; import { Document, MetadataMode } from "@llamaindex/core/schema"; import { extractText } from "@llamaindex/core/utils"; import type { ServiceContext } from "../ServiceContext.js"; import { SummaryIndex } from "../indices/summary/index.js"; -import { PromptMixin } from "../prompts/Mixin.js"; import type { FaithfulnessRefinePrompt, FaithfulnessTextQAPrompt, @@ -43,6 +43,10 @@ export class FaithfulnessEvaluator params?.faithFulnessRefinePrompt ?? defaultFaithfulnessRefinePrompt; } + protected _getPromptModules(): ModuleRecord { + return {}; + } + protected _getPrompts(): { [x: string]: any } { return { faithfulnessSystemPrompt: this.evalTemplate, diff --git a/packages/llamaindex/src/evaluation/Relevancy.ts b/packages/llamaindex/src/evaluation/Relevancy.ts index e738b328e..d06c45897 100644 --- a/packages/llamaindex/src/evaluation/Relevancy.ts +++ b/packages/llamaindex/src/evaluation/Relevancy.ts @@ -1,8 +1,8 @@ +import { PromptMixin, type ModuleRecord } from "@llamaindex/core/prompts"; import { Document, MetadataMode } from "@llamaindex/core/schema"; import { extractText } from "@llamaindex/core/utils"; import type { ServiceContext } from "../ServiceContext.js"; import { SummaryIndex } from "../indices/summary/index.js"; -import { PromptMixin } from "../prompts/Mixin.js"; import type { RelevancyEvalPrompt, RelevancyRefinePrompt } from "./prompts.js"; import { defaultRelevancyEvalPrompt, @@ -39,6 +39,10 @@ export class RelevancyEvaluator extends PromptMixin implements BaseEvaluator { params?.refineTemplate ?? 
defaultRelevancyRefinePrompt; } + protected _getPromptModules(): ModuleRecord { + return {}; + } + _getPrompts() { return { evalTemplate: this.evalTemplate, diff --git a/packages/llamaindex/src/evaluation/prompts.ts b/packages/llamaindex/src/evaluation/prompts.ts index ed51156a6..a1b8e13ea 100644 --- a/packages/llamaindex/src/evaluation/prompts.ts +++ b/packages/llamaindex/src/evaluation/prompts.ts @@ -1,26 +1,26 @@ -export const defaultUserPrompt = ({ - query, - referenceAnswer, - generatedAnswer, -}: { - query: string; - referenceAnswer: string; - generatedAnswer: string; -}) => ` +import { PromptTemplate } from "@llamaindex/core/prompts"; + +export const defaultUserPrompt = new PromptTemplate({ + templateVars: ["query", "referenceAnswer", "generatedAnswer"], + template: ` ## User Query -${query} +{query} ## Reference Answer -${referenceAnswer} +{referenceAnswer} ## Generated Answer -${generatedAnswer} -`; +{generatedAnswer} +`, +}); -export type UserPrompt = typeof defaultUserPrompt; +export type UserPrompt = PromptTemplate< + ["query", "referenceAnswer", "generatedAnswer"] +>; -export const defaultCorrectnessSystemPrompt = - () => `You are an expert evaluation system for a question answering chatbot. +export const defaultCorrectnessSystemPrompt: CorrectnessSystemPrompt = + new PromptTemplate({ + template: `You are an expert evaluation system for a question answering chatbot. You are given the following information: - a user query, and @@ -47,41 +47,35 @@ Example Response: 4.0 The generated answer has the exact same metrics as the reference answer but it is not as concise. -`; - -export type CorrectnessSystemPrompt = typeof defaultCorrectnessSystemPrompt; - -export const defaultFaithfulnessRefinePrompt = ({ - query, - context, - existingAnswer, -}: { - query: string; - context: string; - existingAnswer: string; -}) => ` +`, + }); + +export type CorrectnessSystemPrompt = PromptTemplate<[]>; + +export const defaultFaithfulnessRefinePrompt = new PromptTemplate({ + templateVars: ["query", "existingAnswer", "context"], + template: ` We want to understand if the following information is present -in the context information: ${query} -We have provided an existing YES/NO answer: ${existingAnswer} +in the context information: {query} +We have provided an existing YES/NO answer: {existingAnswer} We have the opportunity to refine the existing answer (only if needed) with some more context below. ------------ -${context} +{context} ------------ If the existing answer was already YES, still answer YES. If the information is present in the new context, answer YES. Otherwise answer NO. -`; +`, +}); -export type FaithfulnessRefinePrompt = typeof defaultFaithfulnessRefinePrompt; +export type FaithfulnessRefinePrompt = PromptTemplate< + ["query", "existingAnswer", "context"] +>; -export const defaultFaithfulnessTextQaPrompt = ({ - query, - context, -}: { - query: string; - context: string; -}) => ` +export const defaultFaithfulnessTextQaPrompt = new PromptTemplate({ + templateVars: ["context", "query"], + template: ` Please tell if a given piece of information is supported by the context. You need to answer with either YES or NO. @@ -107,49 +101,43 @@ It is generally double-crusted, with pastry both above and below the filling; the upper crust may be solid or latticed (woven of crosswise strips). 
Answer: NO -Information: ${query} -Context: ${context} +Information: {query} +Context: {context} Answer: -`; +`, +}); -export type FaithfulnessTextQAPrompt = typeof defaultFaithfulnessTextQaPrompt; +export type FaithfulnessTextQAPrompt = PromptTemplate<["query", "context"]>; -export const defaultRelevancyEvalPrompt = ({ - query, - context, -}: { - query: string; - context: string; -}) => `Your task is to evaluate if the response for the query is in line with the context information provided. +export type RelevancyEvalPrompt = PromptTemplate<["context", "query"]>; +export const defaultRelevancyEvalPrompt = new PromptTemplate({ + templateVars: ["context", "query"], + template: `Your task is to evaluate if the response for the query is in line with the context information provided. You have two options to answer. Either YES/ NO. Answer - YES, if the response for the query is in line with context information otherwise NO. -Query and Response: ${query} -Context: ${context} -Answer: `; - -export type RelevancyEvalPrompt = typeof defaultRelevancyEvalPrompt; - -export const defaultRelevancyRefinePrompt = ({ - query, - existingAnswer, - contextMsg, -}: { - query: string; - existingAnswer: string; - contextMsg: string; -}) => `We want to understand if the following query and response is +Query and Response: {query} +Context: {context} +Answer: `, +}); + +export const defaultRelevancyRefinePrompt = new PromptTemplate({ + templateVars: ["query", "existingAnswer", "contextMsg"], + template: `We want to understand if the following query and response is in line with the context information: -${query} +{query} We have provided an existing YES/NO answer: -${existingAnswer} +{existingAnswer} We have the opportunity to refine the existing answer (only if needed) with some more context below. ------------ -${contextMsg} +{contextMsg} ------------ If the existing answer was already YES, still answer YES. If the information is present in the new context, answer YES. Otherwise answer NO. 
-`; +`, +}); -export type RelevancyRefinePrompt = typeof defaultRelevancyRefinePrompt; +export type RelevancyRefinePrompt = PromptTemplate< + ["query", "existingAnswer", "contextMsg"] +>; diff --git a/packages/llamaindex/src/index.edge.ts b/packages/llamaindex/src/index.edge.ts index 100361ef2..4332921b4 100644 --- a/packages/llamaindex/src/index.edge.ts +++ b/packages/llamaindex/src/index.edge.ts @@ -1,6 +1,7 @@ import type { AgentEndEvent, AgentStartEvent } from "./agent/types.js"; import type { RetrievalEndEvent, RetrievalStartEvent } from "./llm/types.js"; +export * from "@llamaindex/core/prompts"; export * from "@llamaindex/core/schema"; declare module "@llamaindex/core/global" { @@ -43,9 +44,7 @@ export * from "./nodeParsers/index.js"; export * from "./objects/index.js"; export * from "./OutputParser.js"; export * from "./postprocessors/index.js"; -export * from "./Prompt.js"; export * from "./PromptHelper.js"; -export * from "./prompts/index.js"; export * from "./QuestionGenerator.js"; export * from "./Retriever.js"; export * from "./selectors/index.js"; diff --git a/packages/llamaindex/src/indices/keyword/index.ts b/packages/llamaindex/src/indices/keyword/index.ts index 6bfb5da04..4f53b34ff 100644 --- a/packages/llamaindex/src/indices/keyword/index.ts +++ b/packages/llamaindex/src/indices/keyword/index.ts @@ -4,14 +4,6 @@ import type { NodeWithScore, } from "@llamaindex/core/schema"; import { MetadataMode } from "@llamaindex/core/schema"; -import type { - KeywordExtractPrompt, - QueryKeywordExtractPrompt, -} from "../../Prompt.js"; -import { - defaultKeywordExtractPrompt, - defaultQueryKeywordExtractPrompt, -} from "../../Prompt.js"; import type { BaseRetriever, RetrieveParams } from "../../Retriever.js"; import type { ServiceContext } from "../../ServiceContext.js"; import { serviceContextFromDefaults } from "../../ServiceContext.js"; @@ -32,6 +24,12 @@ import { } from "./utils.js"; import type { LLM } from "@llamaindex/core/llms"; +import { + defaultKeywordExtractPrompt, + defaultQueryKeywordExtractPrompt, + type KeywordExtractPrompt, + type QueryKeywordExtractPrompt, +} from "@llamaindex/core/prompts"; import { extractText } from "@llamaindex/core/utils"; import { llmFromSettingsOrContext } from "../../Settings.js"; @@ -116,9 +114,9 @@ abstract class BaseKeywordTableRetriever implements BaseRetriever { export class KeywordTableLLMRetriever extends BaseKeywordTableRetriever { async getKeywords(query: string): Promise<string[]> { const response = await this.llm.complete({ - prompt: this.queryKeywordExtractTemplate({ + prompt: this.queryKeywordExtractTemplate.format({ question: query, - maxKeywords: this.maxKeywordsPerQuery, + maxKeywords: `${this.maxKeywordsPerQuery}`, }), }); const keywords = extractKeywordsGivenResponse(response.text, "KEYWORDS:"); @@ -256,7 +254,7 @@ export class KeywordTableIndex extends BaseIndex<KeywordTable> { const llm = llmFromSettingsOrContext(serviceContext); const response = await llm.complete({ - prompt: defaultKeywordExtractPrompt({ + prompt: defaultKeywordExtractPrompt.format({ context: text, }), }); diff --git a/packages/llamaindex/src/indices/summary/index.ts b/packages/llamaindex/src/indices/summary/index.ts index 7800070a8..512576c7b 100644 --- a/packages/llamaindex/src/indices/summary/index.ts +++ b/packages/llamaindex/src/indices/summary/index.ts @@ -1,3 +1,7 @@ +import { + type ChoiceSelectPrompt, + defaultChoiceSelectPrompt, +} from "@llamaindex/core/prompts"; import type { BaseNode, Document, @@ -5,8 +9,6 @@ import type { } from 
"@llamaindex/core/schema"; import { extractText, wrapEventCaller } from "@llamaindex/core/utils"; import _ from "lodash"; -import type { ChoiceSelectPrompt } from "../../Prompt.js"; -import { defaultChoiceSelectPrompt } from "../../Prompt.js"; import type { BaseRetriever, RetrieveParams } from "../../Retriever.js"; import type { ServiceContext } from "../../ServiceContext.js"; import { @@ -345,7 +347,7 @@ export class SummaryIndexLLMRetriever implements BaseRetriever { const rawResponse = ( await llm.complete({ - prompt: this.choiceSelectPrompt(input), + prompt: this.choiceSelectPrompt.format(input), }) ).text; diff --git a/packages/llamaindex/src/outputParsers/selectors.ts b/packages/llamaindex/src/outputParsers/selectors.ts index 491f2e617..c2da7afbd 100644 --- a/packages/llamaindex/src/outputParsers/selectors.ts +++ b/packages/llamaindex/src/outputParsers/selectors.ts @@ -1,5 +1,6 @@ +import type { BaseOutputParser } from "@llamaindex/core/schema"; import { parseJsonMarkdown } from "../OutputParser.js"; -import type { BaseOutputParser, StructuredOutput } from "../types.js"; +import type { StructuredOutput } from "../types.js"; export type Answer = { choice: number; diff --git a/packages/llamaindex/src/prompts/Mixin.ts b/packages/llamaindex/src/prompts/Mixin.ts deleted file mode 100644 index 5d0942137..000000000 --- a/packages/llamaindex/src/prompts/Mixin.ts +++ /dev/null @@ -1,90 +0,0 @@ -type PromptsDict = Record<string, any>; -type ModuleDict = Record<string, any>; - -export class PromptMixin { - /** - * Validates the prompt keys and module keys - * @param promptsDict - * @param moduleDict - */ - validatePrompts(promptsDict: PromptsDict, moduleDict: ModuleDict): void { - for (const key in promptsDict) { - if (key.includes(":")) { - throw new Error(`Prompt key ${key} cannot contain ':'.`); - } - } - - for (const key in moduleDict) { - if (key.includes(":")) { - throw new Error(`Module key ${key} cannot contain ':'.`); - } - } - } - - /** - * Returns all prompts from the mixin and its modules - */ - getPrompts(): PromptsDict { - const promptsDict: PromptsDict = this._getPrompts(); - - const moduleDict = this._getPromptModules(); - - this.validatePrompts(promptsDict, moduleDict); - - const allPrompts: PromptsDict = { ...promptsDict }; - - for (const [module_name, prompt_module] of Object.entries(moduleDict)) { - for (const [key, prompt] of Object.entries(prompt_module.getPrompts())) { - allPrompts[`${module_name}:${key}`] = prompt; - } - } - return allPrompts; - } - - /** - * Updates the prompts in the mixin and its modules - * @param promptsDict - */ - updatePrompts(promptsDict: PromptsDict): void { - const promptModules = this._getPromptModules(); - - this._updatePrompts(promptsDict); - - const subPromptDicts: Record<string, PromptsDict> = {}; - - for (const key in promptsDict) { - if (key.includes(":")) { - const [module_name, sub_key] = key.split(":"); - - if (!subPromptDicts[module_name]) { - subPromptDicts[module_name] = {}; - } - subPromptDicts[module_name][sub_key] = promptsDict[key]; - } - } - - for (const [module_name, subPromptDict] of Object.entries(subPromptDicts)) { - if (!promptModules[module_name]) { - throw new Error(`Module ${module_name} not found.`); - } - - const moduleToUpdate = promptModules[module_name]; - - moduleToUpdate.updatePrompts(subPromptDict); - } - } - - // Must be implemented by subclasses - // fixme: says must but never implemented - protected _getPrompts(): PromptsDict { - return {}; - } - - protected _getPromptModules(): Record<string, any> { - 
return {};
-  }
-
-  protected _updatePrompts(promptsDict: PromptsDict): void {
-    return;
-  }
-}
diff --git a/packages/llamaindex/src/prompts/index.ts b/packages/llamaindex/src/prompts/index.ts
deleted file mode 100644
index 958a2b0f3..000000000
--- a/packages/llamaindex/src/prompts/index.ts
+++ /dev/null
@@ -1 +0,0 @@
-export * from "./Mixin.js";
diff --git a/packages/llamaindex/src/selectors/base.ts b/packages/llamaindex/src/selectors/base.ts
index cb0d18873..c6fef7e28 100644
--- a/packages/llamaindex/src/selectors/base.ts
+++ b/packages/llamaindex/src/selectors/base.ts
@@ -1,5 +1,5 @@
+import { PromptMixin } from "@llamaindex/core/prompts";
 import type { QueryType } from "@llamaindex/core/query-engine";
-import { PromptMixin } from "../prompts/Mixin.js";
 import type { ToolMetadataOnlyDescription } from "../types.js";
 export interface SingleSelection {
diff --git a/packages/llamaindex/src/selectors/llmSelectors.ts b/packages/llamaindex/src/selectors/llmSelectors.ts
index 7ca44b869..38db9116a 100644
--- a/packages/llamaindex/src/selectors/llmSelectors.ts
+++ b/packages/llamaindex/src/selectors/llmSelectors.ts
@@ -1,10 +1,11 @@
 import type { LLM } from "@llamaindex/core/llms";
+import type { ModuleRecord } from "@llamaindex/core/prompts";
 import type { QueryBundle } from "@llamaindex/core/query-engine";
+import type { BaseOutputParser } from "@llamaindex/core/schema";
 import { extractText } from "@llamaindex/core/utils";
 import type { Answer } from "../outputParsers/selectors.js";
 import { SelectionOutputParser } from "../outputParsers/selectors.js";
 import type {
-  BaseOutputParser,
   StructuredOutput,
   ToolMetadataOnlyDescription,
 } from "../types.js";
@@ -63,16 +64,20 @@ export class LLMMultiSelector extends BaseSelector {
     this.outputParser = init.outputParser ?? new SelectionOutputParser();
   }
-  _getPrompts(): Record<string, MultiSelectPrompt> {
+  _getPrompts() {
     return { prompt: this.prompt };
   }
-  _updatePrompts(prompts: Record<string, MultiSelectPrompt>): void {
+  _updatePrompts(prompts: { prompt: MultiSelectPrompt }) {
     if ("prompt" in prompts) {
       this.prompt = prompts.prompt;
     }
   }
+  protected _getPromptModules(): ModuleRecord {
+    throw new Error("Method not implemented.");
+  }
+
  /**
   * Selects one or more choices from a list of choices.
* @param choices @@ -84,12 +89,12 @@ export class LLMMultiSelector extends BaseSelector { ): Promise<SelectorResult> { const choicesText = buildChoicesText(choices); - const prompt = this.prompt( - choicesText.length, - choicesText, - extractText(query.query), - this.maxOutputs, - ); + const prompt = this.prompt.format({ + contextList: choicesText, + query: extractText(query.query), + maxOutputs: `${this.maxOutputs}`, + numChoices: `${choicesText.length}`, + }); const formattedPrompt = this.outputParser?.format(prompt); @@ -151,11 +156,11 @@ export class LLMSingleSelector extends BaseSelector { ): Promise<SelectorResult> { const choicesText = buildChoicesText(choices); - const prompt = this.prompt( - choicesText.length, - choicesText, - extractText(query.query), - ); + const prompt = this.prompt.format({ + numChoices: `${choicesText.length}`, + context: choicesText, + query: extractText(query.query), + }); const formattedPrompt = this.outputParser.format(prompt); @@ -175,4 +180,8 @@ export class LLMSingleSelector extends BaseSelector { asQueryComponent(): unknown { throw new Error("Method not implemented."); } + + protected _getPromptModules() { + return {}; + } } diff --git a/packages/llamaindex/src/selectors/prompts.ts b/packages/llamaindex/src/selectors/prompts.ts index b91527121..90e59fd4d 100644 --- a/packages/llamaindex/src/selectors/prompts.ts +++ b/packages/llamaindex/src/selectors/prompts.ts @@ -1,30 +1,31 @@ -export const defaultSingleSelectPrompt = ( - numChoices: number, - contextList: string, - queryStr: string, -): string => { - return `Some choices are given below. It is provided in a numbered list (1 to ${numChoices}), where each item in the list corresponds to a summary. +import { PromptTemplate } from "@llamaindex/core/prompts"; + +export const defaultSingleSelectPrompt: SingleSelectPrompt = new PromptTemplate( + { + templateVars: ["context", "query", "numChoices"], + template: `Some choices are given below. It is provided in a numbered list (1 to {numChoices}), where each item in the list corresponds to a summary. --------------------- -${contextList} +{context} --------------------- -Using only the choices above and not prior knowledge, return the choice that is most relevant to the question: '${queryStr}' -`; -}; +Using only the choices above and not prior knowledge, return the choice that is most relevant to the question: '{query}' +`, + }, +); -export type SingleSelectPrompt = typeof defaultSingleSelectPrompt; +export type SingleSelectPrompt = PromptTemplate< + ["context", "query", "numChoices"] +>; -export const defaultMultiSelectPrompt = ( - numChoices: number, - contextList: string, - queryStr: string, - maxOutputs: number, -) => { - return `Some choices are given below. It is provided in a numbered list (1 to ${numChoices}), where each item in the list corresponds to a summary. +export const defaultMultiSelectPrompt: MultiSelectPrompt = new PromptTemplate({ + templateVars: ["contextList", "query", "maxOutputs", "numChoices"], + template: `Some choices are given below. It is provided in a numbered list (1 to {numChoices}), where each item in the list corresponds to a summary. 
--------------------- -${contextList} +{contextList} --------------------- -Using only the choices above and not prior knowledge, return the top choices (no more than ${maxOutputs}, but only select what is needed) that are most relevant to the question: '${queryStr}' -`; -}; +Using only the choices above and not prior knowledge, return the top choices (no more than {maxOutputs}, but only select what is needed) that are most relevant to the question: '{query}' +`, +}); -export type MultiSelectPrompt = typeof defaultMultiSelectPrompt; +export type MultiSelectPrompt = PromptTemplate< + ["contextList", "query", "maxOutputs", "numChoices"] +>; diff --git a/packages/llamaindex/src/synthesizers/MultiModalResponseSynthesizer.ts b/packages/llamaindex/src/synthesizers/MultiModalResponseSynthesizer.ts index d7cfd5a53..599f9eee8 100644 --- a/packages/llamaindex/src/synthesizers/MultiModalResponseSynthesizer.ts +++ b/packages/llamaindex/src/synthesizers/MultiModalResponseSynthesizer.ts @@ -1,10 +1,13 @@ +import { + defaultTextQAPrompt, + PromptMixin, + type ModuleRecord, + type TextQAPrompt, +} from "@llamaindex/core/prompts"; import { EngineResponse, MetadataMode } from "@llamaindex/core/schema"; import { streamConverter } from "@llamaindex/core/utils"; import type { ServiceContext } from "../ServiceContext.js"; import { llmFromSettingsOrContext } from "../Settings.js"; -import { PromptMixin } from "../prompts/Mixin.js"; -import type { TextQaPrompt } from "./../Prompt.js"; -import { defaultTextQaPrompt } from "./../Prompt.js"; import type { BaseSynthesizer, SynthesizeQuery } from "./types.js"; import { createMessageContent } from "./utils.js"; @@ -14,7 +17,7 @@ export class MultiModalResponseSynthesizer { serviceContext?: ServiceContext; metadataMode: MetadataMode; - textQATemplate: TextQaPrompt; + textQATemplate: TextQAPrompt; constructor({ serviceContext, @@ -25,17 +28,21 @@ export class MultiModalResponseSynthesizer this.serviceContext = serviceContext; this.metadataMode = metadataMode ?? MetadataMode.NONE; - this.textQATemplate = textQATemplate ?? defaultTextQaPrompt; + this.textQATemplate = textQATemplate ?? 
defaultTextQAPrompt; } - protected _getPrompts(): { textQATemplate: TextQaPrompt } { + protected _getPromptModules(): ModuleRecord { + return {}; + } + + protected _getPrompts(): { textQATemplate: TextQAPrompt } { return { textQATemplate: this.textQATemplate, }; } protected _updatePrompts(promptsDict: { - textQATemplate: TextQaPrompt; + textQATemplate: TextQAPrompt; }): void { if (promptsDict.textQATemplate) { this.textQATemplate = promptsDict.textQATemplate; diff --git a/packages/llamaindex/src/synthesizers/ResponseSynthesizer.ts b/packages/llamaindex/src/synthesizers/ResponseSynthesizer.ts index 7ddbedb27..78dee1696 100644 --- a/packages/llamaindex/src/synthesizers/ResponseSynthesizer.ts +++ b/packages/llamaindex/src/synthesizers/ResponseSynthesizer.ts @@ -1,8 +1,7 @@ +import { PromptMixin, type PromptsRecord } from "@llamaindex/core/prompts"; import { EngineResponse, MetadataMode } from "@llamaindex/core/schema"; import { streamConverter } from "@llamaindex/core/utils"; import type { ServiceContext } from "../ServiceContext.js"; -import { PromptMixin } from "../prompts/Mixin.js"; -import type { ResponseBuilderPrompts } from "./builders.js"; import { getResponseBuilder } from "./builders.js"; import type { BaseSynthesizer, @@ -40,17 +39,15 @@ export class ResponseSynthesizer return {}; } - protected _getPrompts(): { [x: string]: ResponseBuilderPrompts } { + protected _getPrompts() { const prompts = this.responseBuilder.getPrompts?.(); return { ...prompts, }; } - protected _updatePrompts(promptsDict: { - [x: string]: ResponseBuilderPrompts; - }): void { - this.responseBuilder.updatePrompts?.(promptsDict); + protected _updatePrompts(promptsRecord: PromptsRecord): void { + this.responseBuilder.updatePrompts?.(promptsRecord); } synthesize( diff --git a/packages/llamaindex/src/synthesizers/builders.ts b/packages/llamaindex/src/synthesizers/builders.ts index 46831b445..f5abc2335 100644 --- a/packages/llamaindex/src/synthesizers/builders.ts +++ b/packages/llamaindex/src/synthesizers/builders.ts @@ -1,20 +1,18 @@ import type { LLM } from "@llamaindex/core/llms"; -import type { QueryType } from "@llamaindex/core/query-engine"; -import { extractText, streamConverter } from "@llamaindex/core/utils"; -import type { - RefinePrompt, - SimplePrompt, - TextQaPrompt, - TreeSummarizePrompt, -} from "../Prompt.js"; import { + PromptMixin, defaultRefinePrompt, - defaultTextQaPrompt, + defaultTextQAPrompt, defaultTreeSummarizePrompt, -} from "../Prompt.js"; -import type { PromptHelper } from "../PromptHelper.js"; -import { getBiggestPrompt } from "../PromptHelper.js"; -import { PromptMixin } from "../prompts/Mixin.js"; + type ModuleRecord, + type PromptsRecord, + type RefinePrompt, + type TextQAPrompt, + type TreeSummarizePrompt, +} from "@llamaindex/core/prompts"; +import type { QueryType } from "@llamaindex/core/query-engine"; +import { extractText, streamConverter } from "@llamaindex/core/utils"; +import { getBiggestPrompt, type PromptHelper } from "../PromptHelper.js"; import type { ServiceContext } from "../ServiceContext.js"; import { llmFromSettingsOrContext, @@ -35,13 +33,31 @@ enum ResponseMode { /** * A response builder that just concatenates responses. 
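 * It formats the textQATemplate once, with all chunks joined by blank lines,
 * and issues a single LLM completion (optionally streamed); there is no
 * refine pass.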
*/ -export class SimpleResponseBuilder implements ResponseBuilder { +export class SimpleResponseBuilder + extends PromptMixin + implements ResponseBuilder +{ llm: LLM; - textQATemplate: TextQaPrompt; + textQATemplate: TextQAPrompt; - constructor(serviceContext?: ServiceContext, textQATemplate?: TextQaPrompt) { + constructor(serviceContext?: ServiceContext, textQATemplate?: TextQAPrompt) { + super(); this.llm = llmFromSettingsOrContext(serviceContext); - this.textQATemplate = textQATemplate ?? defaultTextQaPrompt; + this.textQATemplate = textQATemplate ?? defaultTextQAPrompt; + } + + protected _getPrompts(): PromptsRecord { + return { + textQATemplate: this.textQATemplate, + }; + } + protected _updatePrompts(prompts: { textQATemplate: TextQAPrompt }): void { + if (prompts.textQATemplate) { + this.textQATemplate = prompts.textQATemplate; + } + } + protected _getPromptModules(): ModuleRecord { + return {}; } getResponse( @@ -53,12 +69,10 @@ export class SimpleResponseBuilder implements ResponseBuilder { { query, textChunks }: ResponseBuilderQuery, stream?: boolean, ): Promise<AsyncIterable<string> | string> { - const input = { + const prompt = this.textQATemplate.format({ query: extractText(query), context: textChunks.join("\n\n"), - }; - - const prompt = this.textQATemplate(input); + }); if (stream) { const response = await this.llm.complete({ prompt, stream }); return streamConverter(response, (chunk) => chunk.text); @@ -75,24 +89,28 @@ export class SimpleResponseBuilder implements ResponseBuilder { export class Refine extends PromptMixin implements ResponseBuilder { llm: LLM; promptHelper: PromptHelper; - textQATemplate: TextQaPrompt; + textQATemplate: TextQAPrompt; refineTemplate: RefinePrompt; constructor( serviceContext?: ServiceContext, - textQATemplate?: TextQaPrompt, + textQATemplate?: TextQAPrompt, refineTemplate?: RefinePrompt, ) { super(); this.llm = llmFromSettingsOrContext(serviceContext); this.promptHelper = promptHelperFromSettingsOrContext(serviceContext); - this.textQATemplate = textQATemplate ?? defaultTextQaPrompt; + this.textQATemplate = textQATemplate ?? defaultTextQAPrompt; this.refineTemplate = refineTemplate ?? 
defaultRefinePrompt; } + protected _getPromptModules(): ModuleRecord { + return {}; + } + protected _getPrompts(): { - textQATemplate: RefinePrompt; + textQATemplate: TextQAPrompt; refineTemplate: RefinePrompt; } { return { @@ -102,7 +120,7 @@ export class Refine extends PromptMixin implements ResponseBuilder { } protected _updatePrompts(prompts: { - textQATemplate: RefinePrompt; + textQATemplate: TextQAPrompt; refineTemplate: RefinePrompt; }): void { if (prompts.textQATemplate) { @@ -151,9 +169,10 @@ export class Refine extends PromptMixin implements ResponseBuilder { query: QueryType, textChunk: string, stream: boolean, - ) { - const textQATemplate: SimplePrompt = (input) => - this.textQATemplate({ ...input, query: extractText(query) }); + ): Promise<AsyncIterable<string> | string> { + const textQATemplate: TextQAPrompt = this.textQATemplate.partialFormat({ + query: extractText(query), + }); const textChunks = this.promptHelper.repack(textQATemplate, [textChunk]); let response: AsyncIterable<string> | string | undefined = undefined; @@ -163,7 +182,7 @@ export class Refine extends PromptMixin implements ResponseBuilder { const lastChunk = i === textChunks.length - 1; if (!response) { response = await this.complete({ - prompt: textQATemplate({ + prompt: textQATemplate.format({ context: chunk, }), stream: stream && lastChunk, @@ -178,7 +197,7 @@ export class Refine extends PromptMixin implements ResponseBuilder { } } - return response; + return response as AsyncIterable<string> | string; } // eslint-disable-next-line max-params @@ -188,8 +207,9 @@ export class Refine extends PromptMixin implements ResponseBuilder { textChunk: string, stream: boolean, ) { - const refineTemplate: SimplePrompt = (input) => - this.refineTemplate({ ...input, query: extractText(query) }); + const refineTemplate: RefinePrompt = this.refineTemplate.partialFormat({ + query: extractText(query), + }); const textChunks = this.promptHelper.repack(refineTemplate, [textChunk]); @@ -199,7 +219,7 @@ export class Refine extends PromptMixin implements ResponseBuilder { const chunk = textChunks[i]; const lastChunk = i === textChunks.length - 1; response = await this.complete({ - prompt: refineTemplate({ + prompt: refineTemplate.format({ context: chunk, existingAnswer: response as string, }), @@ -236,16 +256,12 @@ export class CompactAndRefine extends Refine { { query, textChunks, prevResponse }: ResponseBuilderQuery, stream?: boolean, ): Promise<AsyncIterable<string> | string> { - const textQATemplate: SimplePrompt = (input) => - this.textQATemplate({ - ...input, - query: extractText(query), - }); - const refineTemplate: SimplePrompt = (input) => - this.refineTemplate({ - ...input, - query: extractText(query), - }); + const textQATemplate: TextQAPrompt = this.textQATemplate.partialFormat({ + query: extractText(query), + }); + const refineTemplate: RefinePrompt = this.refineTemplate.partialFormat({ + query: extractText(query), + }); const maxPrompt = getBiggestPrompt([textQATemplate, refineTemplate]); const newTexts = this.promptHelper.repack(maxPrompt, textChunks); @@ -285,6 +301,10 @@ export class TreeSummarize extends PromptMixin implements ResponseBuilder { this.summaryTemplate = summaryTemplate ?? 
defaultTreeSummarizePrompt; } + protected _getPromptModules(): ModuleRecord { + return {}; + } + protected _getPrompts(): { summaryTemplate: TreeSummarizePrompt } { return { summaryTemplate: this.summaryTemplate, @@ -320,7 +340,7 @@ export class TreeSummarize extends PromptMixin implements ResponseBuilder { if (packedTextChunks.length === 1) { const params = { - prompt: this.summaryTemplate({ + prompt: this.summaryTemplate.format({ context: packedTextChunks[0], query: extractText(query), }), @@ -334,7 +354,7 @@ export class TreeSummarize extends PromptMixin implements ResponseBuilder { const summaries = await Promise.all( packedTextChunks.map((chunk) => this.llm.complete({ - prompt: this.summaryTemplate({ + prompt: this.summaryTemplate.format({ context: chunk, query: extractText(query), }), @@ -376,6 +396,6 @@ export function getResponseBuilder( } export type ResponseBuilderPrompts = - | TextQaPrompt + | TextQAPrompt | TreeSummarizePrompt | RefinePrompt; diff --git a/packages/llamaindex/src/synthesizers/types.ts b/packages/llamaindex/src/synthesizers/types.ts index 4fb269165..a31e65cf6 100644 --- a/packages/llamaindex/src/synthesizers/types.ts +++ b/packages/llamaindex/src/synthesizers/types.ts @@ -1,6 +1,6 @@ +import type { PromptMixin } from "@llamaindex/core/prompts"; import type { QueryType } from "@llamaindex/core/query-engine"; import { EngineResponse, type NodeWithScore } from "@llamaindex/core/schema"; -import type { PromptMixin } from "../prompts/Mixin.js"; export interface SynthesizeQuery { query: QueryType; @@ -11,7 +11,7 @@ export interface SynthesizeQuery { /** * A BaseSynthesizer is used to generate a response from a query and a list of nodes. */ -export interface BaseSynthesizer { +export interface BaseSynthesizer extends PromptMixin { synthesize( query: SynthesizeQuery, stream: true, @@ -28,7 +28,7 @@ export interface ResponseBuilderQuery { /** * A ResponseBuilder is used in a response synthesizer to generate a response from multiple response chunks. */ -export interface ResponseBuilder extends Partial<PromptMixin> { +export interface ResponseBuilder extends PromptMixin { /** * Get the response from a query and a list of text chunks. 
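   * Implementations take a ResponseBuilderQuery ({ query, textChunks,
   * prevResponse }) and may return either a full string or an async iterable
   * of chunks when streaming.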
*/ diff --git a/packages/llamaindex/src/synthesizers/utils.ts b/packages/llamaindex/src/synthesizers/utils.ts index 489c33efd..fee771dd8 100644 --- a/packages/llamaindex/src/synthesizers/utils.ts +++ b/packages/llamaindex/src/synthesizers/utils.ts @@ -1,4 +1,5 @@ import type { MessageContentDetail } from "@llamaindex/core/llms"; +import type { BasePromptTemplate } from "@llamaindex/core/prompts"; import { ImageNode, MetadataMode, @@ -6,11 +7,10 @@ import { splitNodesByType, type BaseNode, } from "@llamaindex/core/schema"; -import type { SimplePrompt } from "../Prompt.js"; import { imageToDataUrl } from "../internal/utils.js"; export async function createMessageContent( - prompt: SimplePrompt, + prompt: BasePromptTemplate, nodes: BaseNode[], extraParams: Record<string, string | undefined> = {}, metadataMode: MetadataMode = MetadataMode.NONE, @@ -37,7 +37,7 @@ export async function createMessageContent( // eslint-disable-next-line max-params async function createContentPerModality( - prompt: SimplePrompt, + prompt: BasePromptTemplate, type: ModalityType, nodes: BaseNode[], extraParams: Record<string, string | undefined>, @@ -48,7 +48,7 @@ async function createContentPerModality( return [ { type: "text", - text: prompt({ + text: prompt.format({ ...extraParams, context: nodes.map((r) => r.getContent(metadataMode)).join("\n\n"), }), diff --git a/packages/llamaindex/src/types.ts b/packages/llamaindex/src/types.ts index 20964586f..ad36e9b3e 100644 --- a/packages/llamaindex/src/types.ts +++ b/packages/llamaindex/src/types.ts @@ -33,15 +33,6 @@ export interface QueryEngine { query(params: QueryEngineParamsNonStreaming): Promise<EngineResponse>; } -/** - * An OutputParser is used to extract structured data from the raw output of the LLM. - */ -export interface BaseOutputParser<T> { - parse(output: string): T; - - format(output: string): string; -} - /** * StructuredOutput is just a combo of the raw output and the parsed output. 
*/ diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 0dc87238c..692ada893 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -162,7 +162,7 @@ importers: version: link:../packages/llamaindex mongodb: specifier: ^6.7.0 - version: 6.8.0(@aws-sdk/credential-providers@3.637.0(@aws-sdk/client-sso-oidc@3.637.0(@aws-sdk/client-sts@3.637.0))) + version: 6.8.0(@aws-sdk/credential-providers@3.637.0) pathe: specifier: ^1.1.2 version: 1.1.2 @@ -383,6 +383,9 @@ importers: natural: specifier: ^8.0.1 version: 8.0.1(@aws-sdk/credential-providers@3.637.0) + python-format-js: + specifier: ^1.4.3 + version: 1.4.3 packages/core/tests: devDependencies: @@ -567,7 +570,7 @@ importers: version: 2.0.0 mongodb: specifier: ^6.7.0 - version: 6.8.0(@aws-sdk/credential-providers@3.637.0(@aws-sdk/client-sso-oidc@3.637.0(@aws-sdk/client-sts@3.637.0))) + version: 6.8.0(@aws-sdk/credential-providers@3.637.0) notion-md-crawler: specifier: ^1.0.0 version: 1.0.0(encoding@0.1.13) @@ -9143,6 +9146,9 @@ packages: resolution: {integrity: sha512-FLpr4flz5xZTSJxSeaheeMKN/EDzMdK7b8PTOC6a5PYFKTucWbdqjgqaEyH0shFiSJrVB1+Qqi4Tk19ccU6Aug==} engines: {node: '>=12.20'} + python-format-js@1.4.3: + resolution: {integrity: sha512-0iK5zP5HMf4F3Xc3Uo6hggPu4ylEQCKNoLXUYe3S1YfYkFG6DxGDO3KozCrySntAZTPmP9yRI+eMq0MXweHqIw==} + qs@6.11.0: resolution: {integrity: sha512-MvjoMCJwEarSbUYk5O+nmoSzSutSsTwF85zcHPQ9OrlFoZOYIjaqBAJIqIXjptyD5vThxGq52Xu/MaJzRkIk4Q==} engines: {node: '>=0.6'} @@ -17719,16 +17725,6 @@ snapshots: transitivePeerDependencies: - supports-color - eslint-module-utils@2.8.2(@typescript-eslint/parser@8.3.0(eslint@8.57.0)(typescript@5.5.4))(eslint-import-resolver-node@0.3.9)(eslint@8.57.0): - dependencies: - debug: 3.2.7 - optionalDependencies: - '@typescript-eslint/parser': 8.3.0(eslint@8.57.0)(typescript@5.5.4) - eslint: 8.57.0 - eslint-import-resolver-node: 0.3.9 - transitivePeerDependencies: - - supports-color - eslint-plugin-import@2.29.1(@typescript-eslint/parser@8.3.0(eslint@8.57.0)(typescript@5.5.4))(eslint@8.57.0): dependencies: array-includes: 3.1.8 @@ -17739,7 +17735,7 @@ snapshots: doctrine: 2.1.0 eslint: 8.57.0 eslint-import-resolver-node: 0.3.9 - eslint-module-utils: 2.8.2(@typescript-eslint/parser@8.3.0(eslint@8.57.0)(typescript@5.5.4))(eslint-import-resolver-node@0.3.9)(eslint@8.57.0) + eslint-module-utils: 2.8.2(@typescript-eslint/parser@7.2.0(eslint@8.57.0)(typescript@5.5.4))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.6.3(@typescript-eslint/parser@7.2.0(eslint@8.57.0)(typescript@5.5.4))(eslint-import-resolver-node@0.3.9)(eslint-plugin-import@2.29.1)(eslint@8.57.0))(eslint@8.57.0) hasown: 2.0.2 is-core-module: 2.15.1 is-glob: 4.0.3 @@ -20467,7 +20463,7 @@ snapshots: optionalDependencies: '@aws-sdk/credential-providers': 3.637.0(@aws-sdk/client-sso-oidc@3.637.0(@aws-sdk/client-sts@3.637.0)) - mongodb@6.8.0(@aws-sdk/credential-providers@3.637.0(@aws-sdk/client-sso-oidc@3.637.0(@aws-sdk/client-sts@3.637.0))): + mongodb@6.8.0(@aws-sdk/credential-providers@3.637.0): dependencies: '@mongodb-js/saslprep': 1.1.7 bson: 6.8.0 @@ -21674,6 +21670,8 @@ snapshots: dependencies: escape-goat: 4.0.0 + python-format-js@1.4.3: {} + qs@6.11.0: dependencies: side-channel: 1.0.6 -- GitLab
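
The hunks above converge on a small, typed prompt surface in `@llamaindex/core/prompts`. The following TypeScript sketch shows how the pieces compose; it assumes only the exports and call signatures that appear in this patch (`PromptTemplate`, `PromptMixin`, `ModuleRecord`, `PromptsRecord`, `format`, `partialFormat`). The `SketchSynthesizer` class, `defaultQAPrompt`, and the sample strings are hypothetical illustrations, not code from this commit.

import {
  PromptMixin,
  PromptTemplate,
  type ModuleRecord,
  type PromptsRecord,
} from "@llamaindex/core/prompts";

// Variables are declared up front and referenced as {placeholders}, so a
// prompt is serializable data rather than a closure over `${...}` strings.
const defaultQAPrompt = new PromptTemplate({
  templateVars: ["context", "query"],
  template: `Context information is below.
---------------------
{context}
---------------------
Given the context information and not prior knowledge, answer the query: {query}
`,
});

// format() substitutes every declared variable and returns the final string.
const prompt = defaultQAPrompt.format({
  context: "Alice joined the project in 2021.",
  query: "When did Alice join?",
});

// partialFormat() binds a subset of variables and returns a new template of
// the same shape, mirroring how Refine and CompactAndRefine pre-bind `query`.
const boundToQuery = defaultQAPrompt.partialFormat({
  query: "When did Alice join?",
});

// Hypothetical class for illustration only: prompt-owning modules extend
// PromptMixin and implement all three hooks. _getPromptModules() is how
// nested modules (e.g. a contextGenerator) expose their prompts under
// "module:key" names through the public getPrompts()/updatePrompts().
class SketchSynthesizer extends PromptMixin {
  qaPrompt = defaultQAPrompt;

  protected _getPrompts(): PromptsRecord {
    return { qaPrompt: this.qaPrompt };
  }

  protected _updatePrompts(prompts: {
    qaPrompt: typeof defaultQAPrompt;
  }): void {
    if (prompts.qaPrompt) {
      this.qaPrompt = prompts.qaPrompt;
    }
  }

  protected _getPromptModules(): ModuleRecord {
    return {}; // no nested prompt modules in this sketch
  }
}

const synth = new SketchSynthesizer();
synth.updatePrompts({ qaPrompt: boundToQuery }); // swap a prompt at runtime
console.log(Object.keys(synth.getPrompts())); // ["qaPrompt"]
console.log(prompt);

Because templates are plain data with declared variables, `validatePrompts`, `getPrompts`, and `updatePrompts` can treat every prompt in a module tree uniformly; this is the behavior that the `module:key` flattening in the deleted `prompts/Mixin.ts` above was approximating with untyped records.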