diff --git a/.changeset/plenty-wolves-smile.md b/.changeset/plenty-wolves-smile.md new file mode 100644 index 0000000000000000000000000000000000000000..9dad46ce525e6e5ef20dbc91fbe547846655eaba --- /dev/null +++ b/.changeset/plenty-wolves-smile.md @@ -0,0 +1,6 @@ +--- +"llamaindex": minor +"@llamaindex/core": minor +--- + +update metadata extractors to use PromptTemplate diff --git a/packages/core/src/prompts/index.ts b/packages/core/src/prompts/index.ts index 992411a5b9a8be89b2562da4503aa92f6cd637cb..bfcf42cdcce869cd13f744f8326a90eac1f4c4db 100644 --- a/packages/core/src/prompts/index.ts +++ b/packages/core/src/prompts/index.ts @@ -12,12 +12,15 @@ export { defaultCondenseQuestionPrompt, defaultContextSystemPrompt, defaultKeywordExtractPrompt, + defaultNodeTextTemplate, defaultQueryKeywordExtractPrompt, defaultQuestionExtractPrompt, defaultRefinePrompt, defaultSubQuestionPrompt, defaultSummaryPrompt, defaultTextQAPrompt, + defaultTitleCombinePromptTemplate, + defaultTitleExtractorPromptTemplate, defaultTreeSummarizePrompt, } from "./prompt"; export type { @@ -31,5 +34,7 @@ export type { SubQuestionPrompt, SummaryPrompt, TextQAPrompt, + TitleCombinePrompt, + TitleExtractorPrompt, TreeSummarizePrompt, } from "./prompt"; diff --git a/packages/core/src/prompts/prompt.ts b/packages/core/src/prompts/prompt.ts index 1efa0289f12a96fb9734fca71d5b6981bf9d663d..1ba69f7295e222f0502301794691725a982b818e 100644 --- a/packages/core/src/prompts/prompt.ts +++ b/packages/core/src/prompts/prompt.ts @@ -13,9 +13,12 @@ export type CondenseQuestionPrompt = PromptTemplate< ["chatHistory", "question"] >; export type ContextSystemPrompt = PromptTemplate<["context"]>; -export type KeywordExtractPrompt = PromptTemplate<["context"]>; +export type KeywordExtractPrompt = PromptTemplate<["context", "maxKeywords"]>; export type QueryKeywordExtractPrompt = PromptTemplate<["question"]>; export type QuestionExtractPrompt = PromptTemplate<["context", "numQuestions"]>; +export type TitleExtractorPrompt = PromptTemplate<["context"]>; +export type TitleCombinePrompt = PromptTemplate<["context"]>; +export type KeywordExtractorPrompt = PromptTemplate<["context", "numKeywords"]>; export const defaultTextQAPrompt: TextQAPrompt = new PromptTemplate({ templateVars: ["context", "query"], @@ -268,3 +271,41 @@ export const defaultQuestionExtractPrompt = new PromptTemplate({ }).partialFormat({ numQuestions: "5", }); + +export const defaultTitleExtractorPromptTemplate = new PromptTemplate({ + templateVars: ["context"], + template: `{context} +Give a title that summarizes all of the unique entities, titles or themes found in the context. +Title: `, +}); + +export const defaultTitleCombinePromptTemplate = new PromptTemplate({ + templateVars: ["context"], + template: `{context} +Based on the above candidate titles and contents, what is the comprehensive title for this document? +Title: `, +}); + +export const defaultKeywordExtractorPromptTemplate = new PromptTemplate({ + templateVars: ["context", "numKeywords"], + template: `{context} +Give {numKeywords} unique keywords for this document. +Format as comma separated. +Keywords: `, +}).partialFormat({ + keywordCount: "5", +}); + +export const defaultNodeTextTemplate = new PromptTemplate({ + templateVars: ["metadataStr", "content"], + template: `[Excerpt from document] +{metadataStr} +Excerpt: +----- +{content} +----- +`, +}).partialFormat({ + metadataStr: "", + content: "", +}); diff --git a/packages/llamaindex/src/extractors/MetadataExtractors.ts b/packages/llamaindex/src/extractors/MetadataExtractors.ts index e43ef941521cfda2602cde871639bc66cbc15a5b..2c75d54295c025a6ae3ab3578cd8ec5a020451dc 100644 --- a/packages/llamaindex/src/extractors/MetadataExtractors.ts +++ b/packages/llamaindex/src/extractors/MetadataExtractors.ts @@ -1,18 +1,20 @@ import type { LLM } from "@llamaindex/core/llms"; import { PromptTemplate, + defaultKeywordExtractPrompt, defaultQuestionExtractPrompt, + defaultSummaryPrompt, + defaultTitleCombinePromptTemplate, + defaultTitleExtractorPromptTemplate, + type KeywordExtractPrompt, type QuestionExtractPrompt, + type SummaryPrompt, + type TitleCombinePrompt, + type TitleExtractorPrompt, } from "@llamaindex/core/prompts"; import type { BaseNode } from "@llamaindex/core/schema"; import { MetadataMode, TextNode } from "@llamaindex/core/schema"; import { OpenAI } from "@llamaindex/openai"; -import { - defaultKeywordExtractorPromptTemplate, - defaultSummaryExtractorPromptTemplate, - defaultTitleCombinePromptTemplate, - defaultTitleExtractorPromptTemplate, -} from "./prompts.js"; import { BaseExtractor } from "./types.js"; const STRIP_REGEX = /(\r\n|\n|\r)/gm; @@ -20,6 +22,7 @@ const STRIP_REGEX = /(\r\n|\n|\r)/gm; type KeywordExtractArgs = { llm?: LLM; keywords?: number; + promptTemplate?: KeywordExtractPrompt["template"]; }; type ExtractKeyword = { @@ -43,6 +46,12 @@ export class KeywordExtractor extends BaseExtractor { */ keywords: number = 5; + /** + * The prompt template to use for the question extractor. + * @type {string} + */ + promptTemplate: KeywordExtractPrompt; + /** * Constructor for the KeywordExtractor class. * @param {LLM} llm LLM instance. @@ -57,6 +66,12 @@ export class KeywordExtractor extends BaseExtractor { this.llm = options?.llm ?? new OpenAI(); this.keywords = options?.keywords ?? 5; + this.promptTemplate = options?.promptTemplate + ? new PromptTemplate({ + templateVars: ["context", "maxKeywords"], + template: options.promptTemplate, + }) + : defaultKeywordExtractPrompt; } /** @@ -70,9 +85,9 @@ export class KeywordExtractor extends BaseExtractor { } const completion = await this.llm.complete({ - prompt: defaultKeywordExtractorPromptTemplate({ - contextStr: node.getContent(MetadataMode.ALL), - keywords: this.keywords, + prompt: this.promptTemplate.format({ + context: node.getContent(MetadataMode.ALL), + maxKeywords: this.keywords.toString(), }), }); @@ -97,8 +112,8 @@ export class KeywordExtractor extends BaseExtractor { type TitleExtractorsArgs = { llm?: LLM; nodes?: number; - nodeTemplate?: string; - combineTemplate?: string; + nodeTemplate?: TitleExtractorPrompt["template"]; + combineTemplate?: TitleCombinePrompt["template"]; }; type ExtractTitle = { @@ -133,19 +148,19 @@ export class TitleExtractor extends BaseExtractor { * The prompt template to use for the title extractor. * @type {string} */ - nodeTemplate: string; + nodeTemplate: TitleExtractorPrompt; /** * The prompt template to merge title with.. * @type {string} */ - combineTemplate: string; + combineTemplate: TitleCombinePrompt; /** * Constructor for the TitleExtractor class. * @param {LLM} llm LLM instance. * @param {number} nodes Number of nodes to extract titles from. - * @param {string} nodeTemplate The prompt template to use for the title extractor. + * @param {TitleExtractorPrompt} nodeTemplate The prompt template to use for the title extractor. * @param {string} combineTemplate The prompt template to merge title with.. */ constructor(options?: TitleExtractorsArgs) { @@ -154,10 +169,19 @@ export class TitleExtractor extends BaseExtractor { this.llm = options?.llm ?? new OpenAI(); this.nodes = options?.nodes ?? 5; - this.nodeTemplate = - options?.nodeTemplate ?? defaultTitleExtractorPromptTemplate(); - this.combineTemplate = - options?.combineTemplate ?? defaultTitleCombinePromptTemplate(); + this.nodeTemplate = options?.nodeTemplate + ? new PromptTemplate({ + templateVars: ["context"], + template: options.nodeTemplate, + }) + : defaultTitleExtractorPromptTemplate; + + this.combineTemplate = options?.combineTemplate + ? new PromptTemplate({ + templateVars: ["context"], + template: options.combineTemplate, + }) + : defaultTitleCombinePromptTemplate; } /** @@ -222,8 +246,8 @@ export class TitleExtractor extends BaseExtractor { const titleCandidates = await this.getTitlesCandidates(nodes); const combinedTitles = titleCandidates.join(", "); const completion = await this.llm.complete({ - prompt: defaultTitleCombinePromptTemplate({ - contextStr: combinedTitles, + prompt: this.combineTemplate.format({ + context: combinedTitles, }), }); @@ -236,8 +260,8 @@ export class TitleExtractor extends BaseExtractor { private async getTitlesCandidates(nodes: BaseNode[]): Promise<string[]> { const titleJobs = nodes.map(async (node) => { const completion = await this.llm.complete({ - prompt: defaultTitleExtractorPromptTemplate({ - contextStr: node.getContent(MetadataMode.ALL), + prompt: this.nodeTemplate.format({ + context: node.getContent(MetadataMode.ALL), }), }); @@ -362,7 +386,7 @@ export class QuestionsAnsweredExtractor extends BaseExtractor { type SummaryExtractArgs = { llm?: LLM; summaries?: string[]; - promptTemplate?: string; + promptTemplate?: SummaryPrompt["template"]; }; type ExtractSummary = { @@ -391,7 +415,7 @@ export class SummaryExtractor extends BaseExtractor { * The prompt template to use for the summary extractor. * @type {string} */ - promptTemplate: string; + promptTemplate: SummaryPrompt; private selfSummary: boolean; private prevSummary: boolean; @@ -410,8 +434,12 @@ export class SummaryExtractor extends BaseExtractor { this.llm = options?.llm ?? new OpenAI(); this.summaries = summaries; - this.promptTemplate = - options?.promptTemplate ?? defaultSummaryExtractorPromptTemplate(); + this.promptTemplate = options?.promptTemplate + ? new PromptTemplate({ + templateVars: ["context"], + template: options.promptTemplate, + }) + : defaultSummaryPrompt; this.selfSummary = summaries?.includes("self") ?? false; this.prevSummary = summaries?.includes("prev") ?? false; @@ -428,10 +456,10 @@ export class SummaryExtractor extends BaseExtractor { return ""; } - const contextStr = node.getContent(this.metadataMode); + const context = node.getContent(this.metadataMode); - const prompt = defaultSummaryExtractorPromptTemplate({ - contextStr, + const prompt = this.promptTemplate.format({ + context, }); const summary = await this.llm.complete({ diff --git a/packages/llamaindex/src/extractors/prompts.ts b/packages/llamaindex/src/extractors/prompts.ts deleted file mode 100644 index 0065e5d834887588caa3da1c80f6e5634285d48c..0000000000000000000000000000000000000000 --- a/packages/llamaindex/src/extractors/prompts.ts +++ /dev/null @@ -1,59 +0,0 @@ -export interface DefaultPromptTemplate { - contextStr: string; -} - -export interface DefaultKeywordExtractorPromptTemplate - extends DefaultPromptTemplate { - keywords: number; -} - -export interface DefaultNodeTextTemplate { - metadataStr: string; - content: string; -} - -export const defaultKeywordExtractorPromptTemplate = ({ - contextStr = "", - keywords = 5, -}: DefaultKeywordExtractorPromptTemplate) => `${contextStr} -Give ${keywords} unique keywords for this document. -Format as comma separated. -Keywords: `; - -export const defaultTitleExtractorPromptTemplate = ( - { contextStr = "" }: DefaultPromptTemplate = { - contextStr: "", - }, -) => `${contextStr} -Give a title that summarizes all of the unique entities, titles or themes found in the context. -Title: `; - -export const defaultTitleCombinePromptTemplate = ( - { contextStr = "" }: DefaultPromptTemplate = { - contextStr: "", - }, -) => `${contextStr} -Based on the above candidate titles and contents, what is the comprehensive title for this document? -Title: `; - -export const defaultSummaryExtractorPromptTemplate = ( - { contextStr = "" }: DefaultPromptTemplate = { - contextStr: "", - }, -) => `${contextStr} -Summarize the key topics and entities of the sections. -Summary: `; - -export const defaultNodeTextTemplate = ({ - metadataStr = "", - content = "", -}: { - metadataStr?: string; - content?: string; -} = {}) => `[Excerpt from document] -${metadataStr} -Excerpt: ------ -${content} ------ -`; diff --git a/packages/llamaindex/src/extractors/types.ts b/packages/llamaindex/src/extractors/types.ts index ae8dbce602e7aa889951435c1111a0bdc9d9674a..8984109e08741458ebc73996773712a80ffce185 100644 --- a/packages/llamaindex/src/extractors/types.ts +++ b/packages/llamaindex/src/extractors/types.ts @@ -1,10 +1,10 @@ +import { defaultNodeTextTemplate } from "@llamaindex/core/prompts"; import { BaseNode, MetadataMode, TextNode, TransformComponent, } from "@llamaindex/core/schema"; -import { defaultNodeTextTemplate } from "./prompts.js"; /* * Abstract class for all extractors. @@ -71,7 +71,7 @@ export abstract class BaseExtractor extends TransformComponent { if (newNodes[idx] instanceof TextNode) { newNodes[idx] = new TextNode({ ...newNodes[idx], - textTemplate: defaultNodeTextTemplate(), + textTemplate: defaultNodeTextTemplate.format(), }); } } diff --git a/packages/llamaindex/tests/MetadataExtractors.test.ts b/packages/llamaindex/tests/MetadataExtractors.test.ts index 6cda9693866157d6aad6265cfa8b3df1c2da1b9e..e35537c1ed0cc4e218016a3ef7d2454751cfe719 100644 --- a/packages/llamaindex/tests/MetadataExtractors.test.ts +++ b/packages/llamaindex/tests/MetadataExtractors.test.ts @@ -148,4 +148,33 @@ describe("[MetadataExtractor]: Extractors should populate the metadata", () => { sectionSummary: DEFAULT_LLM_TEXT_OUTPUT, }); }); + + test("[KeywordExtractor] KeywordExtractor uses custom prompt template", async () => { + const nodeParser = new SentenceSplitter(); + + const nodes = nodeParser.getNodesFromDocuments([ + new Document({ text: DEFAULT_LLM_TEXT_OUTPUT }), + ]); + + const llmCompleteSpy = vi.spyOn(serviceContext.llm, "complete"); + + const keywordExtractor = new KeywordExtractor({ + llm: serviceContext.llm, + keywords: 5, + promptTemplate: `This is a custom prompt template for {context} with {maxKeywords} keywords`, + }); + + await keywordExtractor.processNodes(nodes); + + expect(llmCompleteSpy).toHaveBeenCalled(); + + // Build the expected prompt + const expectedPrompt = `This is a custom prompt template for ${DEFAULT_LLM_TEXT_OUTPUT} with 5 keywords`; + + // Get the actual prompt used in llm.complete + const actualPrompt = llmCompleteSpy.mock?.calls?.[0]?.[0]; + + // Assert that the prompts match + expect(actualPrompt).toEqual({ prompt: expectedPrompt }); + }); });