From cfd6f3ca8c85b1176ae58150edfcefd3008732a3 Mon Sep 17 00:00:00 2001
From: Emanuel Ferreira <contatoferreirads@gmail.com>
Date: Sun, 18 Feb 2024 18:44:08 -0300
Subject: [PATCH] feat: prompt mixin (#543)

---
 apps/docs/docs/modules/prompt/_category_.yml  |   2 +
 apps/docs/docs/modules/prompt/index.md        |  76 ++++++++++
 examples/prompts/promptMixin.ts               |  51 +++++++
 packages/core/src/QuestionGenerator.ts        |  22 ++-
 .../chat/CondenseQuestionChatEngine.ts        |  24 ++-
 .../src/engines/chat/ContextChatEngine.ts     |  11 +-
 .../engines/chat/DefaultContextGenerator.ts   |  22 ++-
 .../src/engines/query/RetrieverQueryEngine.ts |  14 +-
 .../src/engines/query/RouterQueryEngine.ts    |  19 ++-
 .../engines/query/SubQuestionQueryEngine.ts   |  15 +-
 packages/core/src/index.ts                    |   1 +
 .../indices/vectorStore/VectorStoreIndex.ts   |   2 +-
 packages/core/src/prompts/Mixin.ts            |  89 +++++++++++
 packages/core/src/prompts/index.ts            |   1 +
 packages/core/src/selectors/base.ts           |   3 +-
 packages/core/src/selectors/llmSelectors.ts   | 108 +++++++-------
 .../MultiModalResponseSynthesizer.ts          |  22 ++-
 .../src/synthesizers/ResponseSynthesizer.ts   |  27 +++-
 packages/core/src/synthesizers/builders.ts    |  51 ++++++-
 packages/core/src/synthesizers/types.ts       |   3 +-
 packages/core/src/tests/prompts/Mixin.test.ts | 138 ++++++++++++++++++
 21 files changed, 631 insertions(+), 70 deletions(-)
 create mode 100644 apps/docs/docs/modules/prompt/_category_.yml
 create mode 100644 apps/docs/docs/modules/prompt/index.md
 create mode 100644 examples/prompts/promptMixin.ts
 create mode 100644 packages/core/src/prompts/Mixin.ts
 create mode 100644 packages/core/src/prompts/index.ts
 create mode 100644 packages/core/src/tests/prompts/Mixin.test.ts

diff --git a/apps/docs/docs/modules/prompt/_category_.yml b/apps/docs/docs/modules/prompt/_category_.yml
new file mode 100644
index 000000000..597aa54df
--- /dev/null
+++ b/apps/docs/docs/modules/prompt/_category_.yml
@@ -0,0 +1,2 @@
+label: "Prompts"
+position: 0
diff --git a/apps/docs/docs/modules/prompt/index.md b/apps/docs/docs/modules/prompt/index.md
new file mode 100644
index 000000000..fa9374643
--- /dev/null
+++ b/apps/docs/docs/modules/prompt/index.md
@@ -0,0 +1,76 @@
+# Prompts
+
+Prompting is the fundamental input that gives LLMs their expressive power. LlamaIndex uses prompts to build the index, do insertion, perform traversal during querying, and synthesize the final answer.
+
+Users may also provide their own prompt templates to further customize the behavior of the framework. The best way to customize is to copy the default prompt and use it as the base for any modifications.
+
+## Usage Pattern
+
+Currently, there are two ways to customize prompts in LlamaIndex. For both methods, you will need to create a function that overrides the default prompt.
+
+```ts
+// Define a custom prompt
+const newTextQaPrompt: TextQaPrompt = ({ context, query }) => {
+  return `Context information is below.
+---------------------
+${context}
+---------------------
+Given the context information and not prior knowledge, answer the query.
+Answer the query in the style of a Sherlock Holmes detective novel.
+Query: ${query}
+Answer:`;
+};
+```
+
+### 1. Customizing the default prompt on initialization
+
+The first method is to create a new instance of `ResponseSynthesizer` (or of whichever module's prompt you would like to update) and pass the custom prompt to its `responseBuilder` parameter. Then, pass the instance to the `asQueryEngine` method of the index.
+
+```ts
+// Create an instance of response synthesizer
+const responseSynthesizer = new ResponseSynthesizer({
+  responseBuilder: new CompactAndRefine(serviceContext, newTextQaPrompt),
+});
+
+// Create index
+const index = await VectorStoreIndex.fromDocuments([document], {
+  serviceContext,
+});
+
+// Query the index
+const queryEngine = index.asQueryEngine({ responseSynthesizer });
+
+const response = await queryEngine.query({
+  query: "What did the author do in college?",
+});
+```
+
+### 2. Customizing submodule prompts
+
+The second method relies on the fact that most modules in LlamaIndex expose a `getPrompts` and an `updatePrompts` method, which let you override the default prompts. This is useful when you want to change prompts on the fly, or in submodules at a more granular level.
+
+```ts
+// Create index
+const index = await VectorStoreIndex.fromDocuments([document], {
+  serviceContext,
+});
+
+// Query the index
+const queryEngine = index.asQueryEngine();
+
+// Get a list of prompts for the query engine
+const prompts = queryEngine.getPrompts();
+
+// output: { "responseSynthesizer:textQATemplate": defaultTextQaPrompt, "responseSynthesizer:refineTemplate": defaultRefineTemplatePrompt }
+
+// Now, we can override the default prompt
+queryEngine.updatePrompts({
+  "responseSynthesizer:textQATemplate": newTextQaPrompt,
+});
+
+const response = await queryEngine.query({
+  query: "What did the author do in college?",
+});
+```
diff --git a/examples/prompts/promptMixin.ts b/examples/prompts/promptMixin.ts
new file mode 100644
index 000000000..d0b940796
--- /dev/null
+++ b/examples/prompts/promptMixin.ts
@@ -0,0 +1,51 @@
+import {
+  Document,
+  ResponseSynthesizer,
+  TreeSummarize,
+  TreeSummarizePrompt,
+  VectorStoreIndex,
+  serviceContextFromDefaults,
+} from "llamaindex";
+
+const treeSummarizePrompt: TreeSummarizePrompt = ({ context, query }) => {
+  return `Context information from multiple sources is below.
+---------------------
+${context}
+---------------------
+Given the information from multiple sources and not prior knowledge,
+answer the query in the style of a Shakespeare play.
+Query: ${query}
+Answer:`;
+};
+
+async function main() {
+  const document = new Document({
+    text: "The quick brown fox jumps over the lazy dog",
+  });
+
+  const index = await VectorStoreIndex.fromDocuments([document]);
+
+  const query = "The quick brown fox jumps over the lazy dog";
+
+  const ctx = serviceContextFromDefaults({});
+
+  const responseSynthesizer = new ResponseSynthesizer({
+    responseBuilder: new TreeSummarize(ctx),
+  });
+
+  const queryEngine = index.asQueryEngine({
+    responseSynthesizer,
+  });
+
+  console.log({
+    promptsToUse: queryEngine.getPrompts(),
+  });
+
+  queryEngine.updatePrompts({
+    "responseSynthesizer:summaryTemplate": treeSummarizePrompt,
+  });
+
+  await queryEngine.query({ query });
+}
+
+main().catch(console.error);
diff --git a/packages/core/src/QuestionGenerator.ts b/packages/core/src/QuestionGenerator.ts
index ae4c0feb0..5dc1fa9ae 100644
--- a/packages/core/src/QuestionGenerator.ts
+++ b/packages/core/src/QuestionGenerator.ts
@@ -7,22 +7,42 @@ import {
 import { BaseQuestionGenerator, SubQuestion } from "./engines/query/types";
 import { OpenAI } from "./llm/LLM";
 import { LLM } from "./llm/types";
+import { PromptMixin } from "./prompts";
 import { BaseOutputParser, StructuredOutput, ToolMetadata } from "./types";
 
 /**
  * LLMQuestionGenerator uses the LLM to generate new questions for the LLM using tools and a user query.
  */
-export class LLMQuestionGenerator implements BaseQuestionGenerator {
+export class LLMQuestionGenerator
+  extends PromptMixin
+  implements BaseQuestionGenerator
+{
   llm: LLM;
   prompt: SubQuestionPrompt;
   outputParser: BaseOutputParser<StructuredOutput<SubQuestion[]>>;
 
   constructor(init?: Partial<LLMQuestionGenerator>) {
+    super();
+
     this.llm = init?.llm ?? new OpenAI();
     this.prompt = init?.prompt ?? defaultSubQuestionPrompt;
     this.outputParser = init?.outputParser ?? new SubQuestionOutputParser();
   }
 
+  protected _getPrompts(): { [x: string]: SubQuestionPrompt } {
+    return {
+      subQuestion: this.prompt,
+    };
+  }
+
+  protected _updatePrompts(promptsDict: {
+    subQuestion: SubQuestionPrompt;
+  }): void {
+    if ("subQuestion" in promptsDict) {
+      this.prompt = promptsDict.subQuestion;
+    }
+  }
+
   async generate(tools: ToolMetadata[], query: string): Promise<SubQuestion[]> {
     const toolsStr = buildToolsText(tools);
     const queryStr = query;
diff --git a/packages/core/src/engines/chat/CondenseQuestionChatEngine.ts b/packages/core/src/engines/chat/CondenseQuestionChatEngine.ts
index 6d8e93b49..0c73a1017 100644
--- a/packages/core/src/engines/chat/CondenseQuestionChatEngine.ts
+++ b/packages/core/src/engines/chat/CondenseQuestionChatEngine.ts
@@ -11,6 +11,7 @@ import {
 } from "../../ServiceContext";
 import { ChatMessage, LLM } from "../../llm";
 import { extractText, streamReducer } from "../../llm/utils";
+import { PromptMixin } from "../../prompts";
 import { BaseQueryEngine } from "../../types";
 import {
   ChatEngine,
@@ -29,7 +30,10 @@ import {
  * data, or are very referential to previous context.
  */
 
-export class CondenseQuestionChatEngine implements ChatEngine {
+export class CondenseQuestionChatEngine
+  extends PromptMixin
+  implements ChatEngine
+{
   queryEngine: BaseQueryEngine;
   chatHistory: ChatHistory;
   llm: LLM;
@@ -41,6 +45,8 @@ export class CondenseQuestionChatEngine implements ChatEngine {
     serviceContext?: ServiceContext;
     condenseMessagePrompt?: CondenseQuestionPrompt;
   }) {
+    super();
+
     this.queryEngine = init.queryEngine;
     this.chatHistory = getHistory(init?.chatHistory);
     this.llm = init?.serviceContext?.llm ?? serviceContextFromDefaults().llm;
@@ -48,13 +54,27 @@ export class CondenseQuestionChatEngine implements ChatEngine {
       init?.condenseMessagePrompt ?? defaultCondenseQuestionPrompt;
   }
 
+  protected _getPrompts(): { condenseMessagePrompt: CondenseQuestionPrompt } {
+    return {
+      condenseMessagePrompt: this.condenseMessagePrompt,
+    };
+  }
+
+  protected _updatePrompts(promptsDict: {
+    condenseMessagePrompt: CondenseQuestionPrompt;
+  }): void {
+    if (promptsDict.condenseMessagePrompt) {
+      this.condenseMessagePrompt = promptsDict.condenseMessagePrompt;
+    }
+  }
+
   private async condenseQuestion(chatHistory: ChatHistory, question: string) {
     const chatHistoryStr = messagesToHistoryStr(
       await chatHistory.requestMessages(),
     );
 
     return this.llm.complete({
-      prompt: defaultCondenseQuestionPrompt({
+      prompt: this.condenseMessagePrompt({
         question: question,
         chatHistory: chatHistoryStr,
       }),
diff --git a/packages/core/src/engines/chat/ContextChatEngine.ts b/packages/core/src/engines/chat/ContextChatEngine.ts
index 43cdd2af2..c56ce0c50 100644
--- a/packages/core/src/engines/chat/ContextChatEngine.ts
+++ b/packages/core/src/engines/chat/ContextChatEngine.ts
@@ -8,6 +8,7 @@ import { ChatMessage, ChatResponseChunk, LLM, OpenAI } from "../../llm";
 import { MessageContent } from "../../llm/types";
 import { extractText, streamConverter, streamReducer } from "../../llm/utils";
 import { BaseNodePostprocessor } from "../../postprocessors";
+import { PromptMixin } from "../../prompts";
 import { DefaultContextGenerator } from "./DefaultContextGenerator";
 import {
   ChatEngine,
@@ -21,7 +22,7 @@ import {
  * The context is stored in the system prompt, and the chat history is preserved,
  * ideally allowing the appropriate context to be surfaced for each query.
  */
-export class ContextChatEngine implements ChatEngine {
+export class ContextChatEngine extends PromptMixin implements ChatEngine {
   chatModel: LLM;
   chatHistory: ChatHistory;
   contextGenerator: ContextGenerator;
@@ -33,6 +34,8 @@ export class ContextChatEngine implements ChatEngine {
     contextSystemPrompt?: ContextSystemPrompt;
     nodePostprocessors?: BaseNodePostprocessor[];
   }) {
+    super();
+
     this.chatModel =
       init.chatModel ?? new OpenAI({ model: "gpt-3.5-turbo-16k" });
     this.chatHistory = getHistory(init?.chatHistory);
@@ -43,6 +46,12 @@ export class ContextChatEngine implements ChatEngine {
     });
   }
 
+  protected _getPromptModules(): Record<string, ContextGenerator> {
+    return {
+      contextGenerator: this.contextGenerator,
+    };
+  }
+
   chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>;
   chat(params: ChatEngineParamsNonStreaming): Promise<Response>;
   async chat(
diff --git a/packages/core/src/engines/chat/DefaultContextGenerator.ts b/packages/core/src/engines/chat/DefaultContextGenerator.ts
index 11566b85d..4b359ac0f 100644
--- a/packages/core/src/engines/chat/DefaultContextGenerator.ts
+++ b/packages/core/src/engines/chat/DefaultContextGenerator.ts
@@ -4,9 +4,13 @@ import { BaseRetriever } from "../../Retriever";
 import { Event } from "../../callbacks/CallbackManager";
 import { randomUUID } from "../../env";
 import { BaseNodePostprocessor } from "../../postprocessors";
+import { PromptMixin } from "../../prompts";
 import { Context, ContextGenerator } from "./types";
 
-export class DefaultContextGenerator implements ContextGenerator {
+export class DefaultContextGenerator
+  extends PromptMixin
+  implements ContextGenerator
+{
   retriever: BaseRetriever;
   contextSystemPrompt: ContextSystemPrompt;
   nodePostprocessors: BaseNodePostprocessor[];
@@ -16,12 +20,28 @@ export class DefaultContextGenerator implements ContextGenerator {
     contextSystemPrompt?: ContextSystemPrompt;
     nodePostprocessors?: BaseNodePostprocessor[];
   }) {
+    super();
+
     this.retriever = init.retriever;
     this.contextSystemPrompt =
       init?.contextSystemPrompt ?? defaultContextSystemPrompt;
     this.nodePostprocessors = init.nodePostprocessors || [];
   }
 
+  protected _getPrompts(): { contextSystemPrompt: ContextSystemPrompt } {
+    return {
+      contextSystemPrompt: this.contextSystemPrompt,
+    };
+  }
+
+  protected _updatePrompts(promptsDict: {
+    contextSystemPrompt: ContextSystemPrompt;
+  }): void {
+    if (promptsDict.contextSystemPrompt) {
+      this.contextSystemPrompt = promptsDict.contextSystemPrompt;
+    }
+  }
+
   private async applyNodePostprocessors(nodes: NodeWithScore[], query: string) {
     let nodesWithScore = nodes;
 
diff --git a/packages/core/src/engines/query/RetrieverQueryEngine.ts b/packages/core/src/engines/query/RetrieverQueryEngine.ts
index cbec0f7e3..94c5b86c4 100644
--- a/packages/core/src/engines/query/RetrieverQueryEngine.ts
+++ b/packages/core/src/engines/query/RetrieverQueryEngine.ts
@@ -5,6 +5,7 @@ import { ServiceContext } from "../../ServiceContext";
 import { Event } from "../../callbacks/CallbackManager";
 import { randomUUID } from "../../env";
 import { BaseNodePostprocessor } from "../../postprocessors";
+import { PromptMixin } from "../../prompts";
 import { BaseSynthesizer, ResponseSynthesizer } from "../../synthesizers";
 import {
   BaseQueryEngine,
@@ -15,7 +16,10 @@ import {
 /**
  * A query engine that uses a retriever to query an index and then synthesizes the response.
  */
-export class RetrieverQueryEngine implements BaseQueryEngine {
+export class RetrieverQueryEngine
+  extends PromptMixin
+  implements BaseQueryEngine
+{
   retriever: BaseRetriever;
   responseSynthesizer: BaseSynthesizer;
   nodePostprocessors: BaseNodePostprocessor[];
@@ -27,6 +31,8 @@ export class RetrieverQueryEngine implements BaseQueryEngine {
     preFilters?: unknown,
     nodePostprocessors?: BaseNodePostprocessor[],
   ) {
+    super();
+
     this.retriever = retriever;
     const serviceContext: ServiceContext | undefined =
       this.retriever.getServiceContext();
@@ -36,6 +42,12 @@ export class RetrieverQueryEngine implements BaseQueryEngine {
     this.nodePostprocessors = nodePostprocessors || [];
   }
 
+  _getPromptModules() {
+    return {
+      responseSynthesizer: this.responseSynthesizer,
+    };
+  }
+
   private async applyNodePostprocessors(nodes: NodeWithScore[], query: string) {
     let nodesWithScore = nodes;
 
diff --git a/packages/core/src/engines/query/RouterQueryEngine.ts b/packages/core/src/engines/query/RouterQueryEngine.ts
index babe01d14..d48a23cf4 100644
--- a/packages/core/src/engines/query/RouterQueryEngine.ts
+++ b/packages/core/src/engines/query/RouterQueryEngine.ts
@@ -1,8 +1,10 @@
+import { BaseNode } from "../../Node";
 import { Response } from "../../Response";
 import {
   ServiceContext,
   serviceContextFromDefaults,
 } from "../../ServiceContext";
+import { PromptMixin } from "../../prompts";
 import { BaseSelector, LLMSingleSelector } from "../../selectors";
 import { TreeSummarize } from "../../synthesizers";
 import {
@@ -31,8 +33,8 @@ async function combineResponses(
     console.log("Combining responses from multiple query engines.");
   }
 
-  const responseStrs = [];
-  const sourceNodes = [];
+  const responseStrs: string[] = [];
+  const sourceNodes: BaseNode[] = [];
 
   for (const response of responses) {
     if (response?.sourceNodes) {
@@ -53,7 +55,7 @@ async function combineResponses(
 /**
  * A query engine that uses multiple query engines and selects the best one.
  */
-export class RouterQueryEngine implements BaseQueryEngine {
+export class RouterQueryEngine extends PromptMixin implements BaseQueryEngine {
   serviceContext: ServiceContext;
 
   private selector: BaseSelector;
@@ -69,6 +71,8 @@ export class RouterQueryEngine implements BaseQueryEngine {
     summarizer?: TreeSummarize;
     verbose?: boolean;
   }) {
+    super();
+
     this.serviceContext = init.serviceContext || serviceContextFromDefaults({});
     this.selector = init.selector;
     this.queryEngines = init.queryEngineTools.map((tool) => tool.queryEngine);
@@ -79,6 +83,13 @@ export class RouterQueryEngine implements BaseQueryEngine {
     this.verbose = init.verbose ?? false;
   }
 
+  _getPromptModules(): Record<string, any> {
+    return {
+      selector: this.selector,
+      summarizer: this.summarizer,
+    };
+  }
+
   static fromDefaults(init: {
     queryEngineTools: RouterQueryEngineTool[];
     selector?: BaseSelector;
@@ -119,7 +130,7 @@ export class RouterQueryEngine implements BaseQueryEngine {
     const result = await this.selector.select(this.metadatas, queryBundle);
 
     if (result.selections.length > 1) {
-      const responses = [];
+      const responses: Response[] = [];
       for (let i = 0; i < result.selections.length; i++) {
         const engineInd = result.selections[i];
         const logStr = `Selecting query engine ${engineInd}: ${result.selections[i]}.`;
diff --git a/packages/core/src/engines/query/SubQuestionQueryEngine.ts b/packages/core/src/engines/query/SubQuestionQueryEngine.ts
index 73d0ada6c..0139ecd0a 100644
--- a/packages/core/src/engines/query/SubQuestionQueryEngine.ts
+++ b/packages/core/src/engines/query/SubQuestionQueryEngine.ts
@@ -7,6 +7,7 @@ import {
 } from "../../ServiceContext";
 import { Event } from "../../callbacks/CallbackManager";
 import { randomUUID } from "../../env";
+import { PromptMixin } from "../../prompts";
 import {
   BaseSynthesizer,
   CompactAndRefine,
@@ -26,7 +27,10 @@ import { BaseQuestionGenerator, SubQuestion } from "./types";
 /**
  * SubQuestionQueryEngine decomposes a question into subquestions and then
  */
-export class SubQuestionQueryEngine implements BaseQueryEngine {
+export class SubQuestionQueryEngine
+  extends PromptMixin
+  implements BaseQueryEngine
+{
   responseSynthesizer: BaseSynthesizer;
   questionGen: BaseQuestionGenerator;
   queryEngines: BaseTool[];
@@ -37,6 +41,8 @@ export class SubQuestionQueryEngine implements BaseQueryEngine {
     responseSynthesizer: BaseSynthesizer;
     queryEngineTools: BaseTool[];
   }) {
+    super();
+
     this.questionGen = init.questionGen;
     this.responseSynthesizer =
       init.responseSynthesizer ?? new ResponseSynthesizer();
@@ -44,6 +50,13 @@ export class SubQuestionQueryEngine implements BaseQueryEngine {
     this.metadatas = init.queryEngineTools.map((tool) => tool.metadata);
   }
 
+  protected _getPromptModules(): Record<string, any> {
+    return {
+      questionGen: this.questionGen,
+      responseSynthesizer: this.responseSynthesizer,
+    };
+  }
+
   static fromDefaults(init: {
     queryEngineTools: BaseTool[];
     questionGen?: BaseQuestionGenerator;
diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts
index f2449885b..0af9e53cb 100644
--- a/packages/core/src/index.ts
+++ b/packages/core/src/index.ts
@@ -22,6 +22,7 @@ export * from "./llm";
 export * from "./nodeParsers";
 export * from "./objects";
 export * from "./postprocessors";
+export * from "./prompts";
 export * from "./readers";
 export * from "./selectors";
 export * from "./storage";
diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
index 9515bab44..137d1073a 100644
--- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
+++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
@@ -270,7 +270,7 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
     responseSynthesizer?: BaseSynthesizer;
     preFilters?: MetadataFilters;
     nodePostprocessors?: BaseNodePostprocessor[];
-  }): BaseQueryEngine {
+  }): BaseQueryEngine & RetrieverQueryEngine {
     const { retriever, responseSynthesizer } = options ?? {};
     return new RetrieverQueryEngine(
       retriever ?? this.asRetriever(),
diff --git a/packages/core/src/prompts/Mixin.ts b/packages/core/src/prompts/Mixin.ts
new file mode 100644
index 000000000..eac847b91
--- /dev/null
+++ b/packages/core/src/prompts/Mixin.ts
@@ -0,0 +1,89 @@
+type PromptsDict = Record<string, any>;
+type ModuleDict = Record<string, any>;
+
+export class PromptMixin {
+  /**
+   * Validates the prompt keys and module keys
+   * @param promptsDict
+   * @param moduleDict
+   */
+  validatePrompts(promptsDict: PromptsDict, moduleDict: ModuleDict): void {
+    for (let key in promptsDict) {
+      if (key.includes(":")) {
+        throw new Error(`Prompt key ${key} cannot contain ':'.`);
+      }
+    }
+
+    for (let key in moduleDict) {
+      if (key.includes(":")) {
+        throw new Error(`Module key ${key} cannot contain ':'.`);
+      }
+    }
+  }
+
+  /**
+   * Returns all prompts from the mixin and its modules
+   */
+  getPrompts(): PromptsDict {
+    let promptsDict: PromptsDict = this._getPrompts();
+
+    let moduleDict = this._getPromptModules();
+
+    this.validatePrompts(promptsDict, moduleDict);
+
+    let allPrompts: PromptsDict = { ...promptsDict };
+
+    for (let [module_name, prompt_module] of Object.entries(moduleDict)) {
+      for (let [key, prompt] of Object.entries(prompt_module.getPrompts())) {
+        allPrompts[`${module_name}:${key}`] = prompt;
+      }
+    }
+    return allPrompts;
+  }
+
+  /**
+   * Updates the prompts in the mixin and its modules
+   * @param promptsDict
+   */
+  updatePrompts(promptsDict: PromptsDict): void {
+    let promptModules = this._getPromptModules();
+
+    this._updatePrompts(promptsDict);
+
+    let subPromptDicts: Record<string, PromptsDict> = {};
+
+    for (let key in promptsDict) {
+      if (key.includes(":")) {
+        let [module_name, sub_key] = key.split(":");
+
+        if (!subPromptDicts[module_name]) {
+          subPromptDicts[module_name] = {};
+        }
+        subPromptDicts[module_name][sub_key] = promptsDict[key];
+      }
+    }
+
+    for (let [module_name, subPromptDict] of Object.entries(subPromptDicts)) {
+      if (!promptModules[module_name]) {
+        throw new Error(`Module ${module_name} not found.`);
+      }
+
+      let moduleToUpdate = promptModules[module_name];
+
+      moduleToUpdate.updatePrompts(subPromptDict);
+    }
+  }
+
+  // Must be implemented by subclasses
+  protected _getPrompts(): PromptsDict {
+    return {};
+  }
+
+  protected _getPromptModules(): Record<string, any> {
+    return {};
+  }
+
+  protected _updatePrompts(promptsDict: PromptsDict): void {
+    return;
+  }
+}
diff --git a/packages/core/src/prompts/index.ts b/packages/core/src/prompts/index.ts
new file mode 100644
index 000000000..fb7823314
--- /dev/null
+++ b/packages/core/src/prompts/index.ts
@@ -0,0 +1 @@
+export * from "./Mixin";
diff --git a/packages/core/src/selectors/base.ts b/packages/core/src/selectors/base.ts
index 74bd89e37..3268d55e1 100644
--- a/packages/core/src/selectors/base.ts
+++ b/packages/core/src/selectors/base.ts
@@ -1,3 +1,4 @@
+import { PromptMixin } from "../prompts";
 import { QueryBundle, ToolMetadataOnlyDescription } from "../types";
 
 export interface SingleSelection {
@@ -31,7 +32,7 @@ function wrapQuery(query: QueryType): QueryBundle {
 
 type MetadataType = string | ToolMetadataOnlyDescription;
 
-export abstract class BaseSelector {
+export abstract class BaseSelector extends PromptMixin {
   async select(choices: MetadataType[], query: QueryType) {
     const metadatas = choices.map((choice) => wrapChoice(choice));
     const queryBundle = wrapQuery(query);
diff --git a/packages/core/src/selectors/llmSelectors.ts b/packages/core/src/selectors/llmSelectors.ts
index 74acdc87f..c1a51c6e1 100644
--- a/packages/core/src/selectors/llmSelectors.ts
+++ b/packages/core/src/selectors/llmSelectors.ts
@@ -1,4 +1,3 @@
-import { DefaultPromptTemplate } from "../extractors/prompts";
 import { LLM } from "../llm";
 import { Answer, SelectionOutputParser } from "../outputParsers/selectors";
 import {
@@ -8,7 +7,12 @@ import {
   ToolMetadataOnlyDescription,
 } from "../types";
 import { BaseSelector, SelectorResult } from "./base";
-import { defaultSingleSelectPrompt } from "./prompts";
+import {
+  MultiSelectPrompt,
+  SingleSelectPrompt,
+  defaultMultiSelectPrompt,
+  defaultSingleSelectPrompt,
+} from "./prompts";
 
 function buildChoicesText(choices: ToolMetadataOnlyDescription[]): string {
   const texts: string[] = [];
@@ -20,7 +24,7 @@ function buildChoicesText(choices: ToolMetadataOnlyDescription[]): string {
   return texts.join("");
 }
 
-function _structuredOutputToSelectorResult(
+function structuredOutputToSelectorResult(
   output: StructuredOutput<Answer[]>,
 ): SelectorResult {
   const structuredOutput = output;
@@ -40,32 +44,32 @@ type LLMPredictorType = LLM;
  * A selector that uses the LLM to select a single or multiple choices from a list of choices.
  */
 export class LLMMultiSelector extends BaseSelector {
-  _llm: LLMPredictorType;
-  _prompt: DefaultPromptTemplate | undefined;
-  _maxOutputs: number | null;
-  _outputParser: BaseOutputParser<any> | null;
+  llm: LLMPredictorType;
+  prompt: MultiSelectPrompt;
+  maxOutputs: number;
+  outputParser: BaseOutputParser<StructuredOutput<Answer[]>> | null;
 
   constructor(init: {
     llm: LLMPredictorType;
-    prompt?: DefaultPromptTemplate;
+    prompt?: MultiSelectPrompt;
     maxOutputs?: number;
-    outputParser?: BaseOutputParser<any>;
+    outputParser?: BaseOutputParser<StructuredOutput<Answer[]>>;
   }) {
     super();
-    this._llm = init.llm;
-    this._prompt = init.prompt;
-    this._maxOutputs = init.maxOutputs ?? null;
+    this.llm = init.llm;
+    this.prompt = init.prompt ?? defaultMultiSelectPrompt;
+    this.maxOutputs = init.maxOutputs ?? 10;
 
-    this._outputParser = init.outputParser ?? new SelectionOutputParser();
+    this.outputParser = init.outputParser ?? new SelectionOutputParser();
   }
 
-  _getPrompts(): Record<string, any> {
-    return { prompt: this._prompt };
+  _getPrompts(): Record<string, MultiSelectPrompt> {
+    return { prompt: this.prompt };
   }
 
-  _updatePrompts(prompts: Record<string, any>): void {
+  _updatePrompts(prompts: Record<string, MultiSelectPrompt>): void {
     if ("prompt" in prompts) {
-      this._prompt = prompts.prompt;
+      this.prompt = prompts.prompt;
     }
   }
 
@@ -80,22 +84,26 @@ export class LLMMultiSelector extends BaseSelector {
   ): Promise<SelectorResult> {
     const choicesText = buildChoicesText(choices);
 
-    const prompt =
-      this._prompt?.contextStr ??
-      defaultSingleSelectPrompt(
-        choicesText.length,
-        choicesText,
-        query.queryStr,
-      );
-    const formattedPrompt = this._outputParser?.format(prompt);
+    const prompt = this.prompt(
+      choicesText.length,
+      choicesText,
+      query.queryStr,
+      this.maxOutputs,
+    );
+
+    const formattedPrompt = this.outputParser?.format(prompt);
 
-    const prediction = await this._llm.complete({
+    const prediction = await this.llm.complete({
       prompt: formattedPrompt,
     });
 
-    const parsed = this._outputParser?.parse(prediction.text);
+    const parsed = this.outputParser?.parse(prediction.text);
+
+    if (!parsed) {
+      throw new Error("Parsed output is undefined");
+    }
 
-    return _structuredOutputToSelectorResult(parsed);
+    return structuredOutputToSelectorResult(parsed);
   }
 
   asQueryComponent(): unknown {
@@ -107,28 +115,28 @@ export class LLMMultiSelector extends BaseSelector {
  * A selector that uses the LLM to select a single choice from a list of choices.
  */
 export class LLMSingleSelector extends BaseSelector {
-  _llm: LLMPredictorType;
-  _prompt: DefaultPromptTemplate | undefined;
-  _outputParser: BaseOutputParser<any> | null;
+  llm: LLMPredictorType;
+  prompt: SingleSelectPrompt;
+  outputParser: BaseOutputParser<StructuredOutput<Answer[]>> | null;
 
   constructor(init: {
     llm: LLMPredictorType;
-    prompt?: DefaultPromptTemplate;
-    outputParser?: BaseOutputParser<any>;
+    prompt?: SingleSelectPrompt;
+    outputParser?: BaseOutputParser<StructuredOutput<Answer[]>>;
   }) {
     super();
-    this._llm = init.llm;
-    this._prompt = init.prompt;
-    this._outputParser = init.outputParser ?? new SelectionOutputParser();
+    this.llm = init.llm;
+    this.prompt = init.prompt ?? defaultSingleSelectPrompt;
+    this.outputParser = init.outputParser ?? new SelectionOutputParser();
   }
 
-  _getPrompts(): Record<string, any> {
-    return { prompt: this._prompt };
+  _getPrompts(): Record<string, SingleSelectPrompt> {
+    return { prompt: this.prompt };
   }
 
-  _updatePrompts(prompts: Record<string, any>): void {
+  _updatePrompts(prompts: Record<string, SingleSelectPrompt>): void {
     if ("prompt" in prompts) {
-      this._prompt = prompts.prompt;
+      this.prompt = prompts.prompt;
     }
   }
 
@@ -143,23 +151,21 @@ export class LLMSingleSelector extends BaseSelector {
   ): Promise<SelectorResult> {
     const choicesText = buildChoicesText(choices);
 
-    const prompt =
-      this._prompt?.contextStr ??
-      defaultSingleSelectPrompt(
-        choicesText.length,
-        choicesText,
-        query.queryStr,
-      );
+    const prompt = this.prompt(choicesText.length, choicesText, query.queryStr);
 
-    const formattedPrompt = this._outputParser?.format(prompt);
+    const formattedPrompt = this.outputParser?.format(prompt);
 
-    const prediction = await this._llm.complete({
+    const prediction = await this.llm.complete({
       prompt: formattedPrompt,
     });
 
-    const parsed = this._outputParser?.parse(prediction.text);
+    const parsed = this.outputParser?.parse(prediction.text);
+
+    if (!parsed) {
+      throw new Error("Parsed output is undefined");
+    }
 
-    return _structuredOutputToSelectorResult(parsed);
+    return structuredOutputToSelectorResult(parsed);
   }
 
   asQueryComponent(): unknown {
diff --git a/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts b/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts
index 307138350..a6a442a9d 100644
--- a/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts
+++ b/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts
@@ -3,6 +3,7 @@ import { Response } from "../Response";
 import { ServiceContext, serviceContextFromDefaults } from "../ServiceContext";
 import { imageToDataUrl } from "../embeddings";
 import { MessageContentDetail } from "../llm/types";
+import { PromptMixin } from "../prompts";
 import { TextQaPrompt, defaultTextQaPrompt } from "./../Prompt";
 import {
   BaseSynthesizer,
@@ -10,7 +11,10 @@ import {
   SynthesizeParamsStreaming,
 } from "./types";
 
-export class MultiModalResponseSynthesizer implements BaseSynthesizer {
+export class MultiModalResponseSynthesizer
+  extends PromptMixin
+  implements BaseSynthesizer
+{
   serviceContext: ServiceContext;
   metadataMode: MetadataMode;
   textQATemplate: TextQaPrompt;
@@ -20,11 +24,27 @@ export class MultiModalResponseSynthesizer implements BaseSynthesizer {
     textQATemplate,
     metadataMode,
   }: Partial<MultiModalResponseSynthesizer> = {}) {
+    super();
+
     this.serviceContext = serviceContext ?? serviceContextFromDefaults();
     this.metadataMode = metadataMode ?? MetadataMode.NONE;
     this.textQATemplate = textQATemplate ?? defaultTextQaPrompt;
   }
 
+  protected _getPrompts(): { textQATemplate: TextQaPrompt } {
+    return {
+      textQATemplate: this.textQATemplate,
+    };
+  }
+
+  protected _updatePrompts(promptsDict: {
+    textQATemplate: TextQaPrompt;
+  }): void {
+    if (promptsDict.textQATemplate) {
+      this.textQATemplate = promptsDict.textQATemplate;
+    }
+  }
+
   synthesize(
     params: SynthesizeParamsStreaming,
   ): Promise<AsyncIterable<Response>>;
diff --git a/packages/core/src/synthesizers/ResponseSynthesizer.ts b/packages/core/src/synthesizers/ResponseSynthesizer.ts
index 61ed392fe..7b28677b4 100644
--- a/packages/core/src/synthesizers/ResponseSynthesizer.ts
+++ b/packages/core/src/synthesizers/ResponseSynthesizer.ts
@@ -2,7 +2,8 @@ import { MetadataMode } from "../Node";
 import { Response } from "../Response";
 import { ServiceContext, serviceContextFromDefaults } from "../ServiceContext";
 import { streamConverter } from "../llm/utils";
-import { getResponseBuilder } from "./builders";
+import { PromptMixin } from "../prompts";
+import { ResponseBuilderPrompts, getResponseBuilder } from "./builders";
 import {
   BaseSynthesizer,
   ResponseBuilder,
@@ -13,7 +14,10 @@ import {
 /**
  * A ResponseSynthesizer is used to generate a response from a query and a list of nodes.
  */
-export class ResponseSynthesizer implements BaseSynthesizer {
+export class ResponseSynthesizer
+  extends PromptMixin
+  implements BaseSynthesizer
+{
   responseBuilder: ResponseBuilder;
   serviceContext: ServiceContext;
   metadataMode: MetadataMode;
@@ -27,12 +31,31 @@ export class ResponseSynthesizer implements BaseSynthesizer {
     serviceContext?: ServiceContext;
     metadataMode?: MetadataMode;
   } = {}) {
+    super();
+
     this.serviceContext = serviceContext ?? serviceContextFromDefaults();
     this.responseBuilder =
       responseBuilder ?? getResponseBuilder(this.serviceContext);
     this.metadataMode = metadataMode;
   }
 
+  _getPromptModules() {
+    return {};
+  }
+
+  protected _getPrompts(): { [x: string]: ResponseBuilderPrompts } {
+    const prompts = this.responseBuilder.getPrompts?.();
+    return {
+      ...prompts,
+    };
+  }
+
+  protected _updatePrompts(promptsDict: {
+    [x: string]: ResponseBuilderPrompts;
+  }): void {
+    this.responseBuilder.updatePrompts?.(promptsDict);
+  }
+
   synthesize(
     params: SynthesizeParamsStreaming,
   ): Promise<AsyncIterable<Response>>;
diff --git a/packages/core/src/synthesizers/builders.ts b/packages/core/src/synthesizers/builders.ts
index 67f1173b2..edec2cfdb 100644
--- a/packages/core/src/synthesizers/builders.ts
+++ b/packages/core/src/synthesizers/builders.ts
@@ -11,6 +11,7 @@ import {
   TreeSummarizePrompt,
 } from "../Prompt";
 import { getBiggestPrompt, PromptHelper } from "../PromptHelper";
+import { PromptMixin } from "../prompts";
 import { ServiceContext } from "../ServiceContext";
 import {
   ResponseBuilder,
@@ -73,7 +74,7 @@ export class SimpleResponseBuilder implements ResponseBuilder {
 /**
  * A response builder that uses the query to ask the LLM generate a better response using multiple text chunks.
 */
-export class Refine implements ResponseBuilder {
+export class Refine extends PromptMixin implements ResponseBuilder {
   llm: LLM;
   promptHelper: PromptHelper;
   textQATemplate: TextQaPrompt;
@@ -84,12 +85,37 @@ export class Refine implements ResponseBuilder {
     textQATemplate?: TextQaPrompt,
     refineTemplate?: RefinePrompt,
   ) {
+    super();
+
     this.llm = serviceContext.llm;
     this.promptHelper = serviceContext.promptHelper;
     this.textQATemplate = textQATemplate ?? defaultTextQaPrompt;
     this.refineTemplate = refineTemplate ?? defaultRefinePrompt;
   }
 
+  protected _getPrompts(): {
+    textQATemplate: TextQaPrompt;
+    refineTemplate: RefinePrompt;
+  } {
+    return {
+      textQATemplate: this.textQATemplate,
+      refineTemplate: this.refineTemplate,
+    };
+  }
+
+  protected _updatePrompts(prompts: {
+    textQATemplate: TextQaPrompt;
+    refineTemplate: RefinePrompt;
+  }): void {
+    if (prompts.textQATemplate) {
+      this.textQATemplate = prompts.textQATemplate;
+    }
+
+    if (prompts.refineTemplate) {
+      this.refineTemplate = prompts.refineTemplate;
+    }
+  }
+
   getResponse(
     params: ResponseBuilderParamsStreaming,
   ): Promise<AsyncIterable<string>>;
@@ -258,7 +284,7 @@ export class CompactAndRefine extends Refine {
 /**
  * TreeSummarize repacks the text chunks into the smallest possible number of chunks and then summarizes them, then recursively does so until there's one chunk left.
  */
-export class TreeSummarize implements ResponseBuilder {
+export class TreeSummarize extends PromptMixin implements ResponseBuilder {
   llm: LLM;
   promptHelper: PromptHelper;
   summaryTemplate: TreeSummarizePrompt;
@@ -267,11 +293,27 @@ export class TreeSummarize implements ResponseBuilder {
     serviceContext: ServiceContext,
     summaryTemplate?: TreeSummarizePrompt,
   ) {
+    super();
+
     this.llm = serviceContext.llm;
     this.promptHelper = serviceContext.promptHelper;
     this.summaryTemplate = summaryTemplate ?? defaultTreeSummarizePrompt;
   }
 
+  protected _getPrompts(): { summaryTemplate: TreeSummarizePrompt } {
+    return {
+      summaryTemplate: this.summaryTemplate,
+    };
+  }
+
+  protected _updatePrompts(prompts: {
+    summaryTemplate: TreeSummarizePrompt;
+  }): void {
+    if (prompts.summaryTemplate) {
+      this.summaryTemplate = prompts.summaryTemplate;
+    }
+  }
+
   getResponse(
     params: ResponseBuilderParamsStreaming,
   ): Promise<AsyncIterable<string>>;
@@ -352,3 +394,8 @@ export function getResponseBuilder(
       return new CompactAndRefine(serviceContext);
   }
 }
+
+export type ResponseBuilderPrompts =
+  | TextQaPrompt
+  | TreeSummarizePrompt
+  | RefinePrompt;
diff --git a/packages/core/src/synthesizers/types.ts b/packages/core/src/synthesizers/types.ts
index a188a6e76..307982836 100644
--- a/packages/core/src/synthesizers/types.ts
+++ b/packages/core/src/synthesizers/types.ts
@@ -1,5 +1,6 @@
 import { Event } from "../callbacks/CallbackManager";
 import { NodeWithScore } from "../Node";
+import { PromptMixin } from "../prompts";
 import { Response } from "../Response";
 
 export interface SynthesizeParamsBase {
@@ -46,7 +47,7 @@ export interface ResponseBuilderParamsNonStreaming
 /**
  * A ResponseBuilder is used in a response synthesizer to generate a response from multiple response chunks.
 */
-export interface ResponseBuilder {
+export interface ResponseBuilder extends Partial<PromptMixin> {
  /**
   * Get the response from a query and a list of text chunks.
  * @param params
diff --git a/packages/core/src/tests/prompts/Mixin.test.ts b/packages/core/src/tests/prompts/Mixin.test.ts
new file mode 100644
index 000000000..ba9293ce6
--- /dev/null
+++ b/packages/core/src/tests/prompts/Mixin.test.ts
@@ -0,0 +1,138 @@
+import { PromptMixin } from "../../prompts";
+
+type MockPrompt = {
+  context: string;
+  query: string;
+};
+
+const mockPrompt = ({ context, query }: MockPrompt) =>
+  `context: ${context} query: ${query}`;
+
+const mockPrompt2 = ({ context, query }: MockPrompt) =>
+  `query: ${query} context: ${context}`;
+
+type MockPromptFunction = typeof mockPrompt;
+
+class MockObject2 extends PromptMixin {
+  _prompt_dict_2: MockPromptFunction;
+
+  constructor() {
+    super();
+    this._prompt_dict_2 = mockPrompt;
+  }
+
+  protected _getPrompts(): { [x: string]: MockPromptFunction } {
+    return {
+      abc: this._prompt_dict_2,
+    };
+  }
+
+  _updatePrompts(promptsDict: { [x: string]: MockPromptFunction }): void {
+    if ("abc" in promptsDict) {
+      this._prompt_dict_2 = promptsDict["abc"];
+    }
+  }
+}
+
+class MockObject1 extends PromptMixin {
+  mockObject2: MockObject2;
+
+  fooPrompt: MockPromptFunction;
+  barPrompt: MockPromptFunction;
+
+  constructor() {
+    super();
+    this.mockObject2 = new MockObject2();
+    this.fooPrompt = mockPrompt;
+    this.barPrompt = mockPrompt;
+  }
+
+  protected _getPrompts(): { [x: string]: any } {
+    return {
+      bar: this.barPrompt,
+      foo: this.fooPrompt,
+    };
+  }
+
+  protected _getPromptModules(): { [x: string]: any } {
+    return { mock_object_2: this.mockObject2 };
+  }
+
+  _updatePrompts(promptsDict: { [x: string]: any }): void {
+    if ("bar" in promptsDict) {
+      this.barPrompt = promptsDict["bar"];
+    }
+    if ("foo" in promptsDict) {
+      this.fooPrompt = promptsDict["foo"];
+    }
+  }
+}
+
+describe("PromptMixin", () => {
+  it("should return prompts", () => {
+    const mockObj1 = new MockObject1();
+
+    const prompts = mockObj1.getPrompts();
+
+    expect(
+      mockObj1.fooPrompt({
+        context: "{foo}",
+        query: "{foo}",
+      }),
+    ).toEqual("context: {foo} query: {foo}");
+
+    expect(
+      mockObj1.barPrompt({
+        context: "{foo} {bar}",
+        query: "{foo} {bar}",
+      }),
+    ).toEqual("context: {foo} {bar} query: {foo} {bar}");
+
+    expect(mockObj1.fooPrompt).toEqual(prompts.foo);
+    expect(mockObj1.barPrompt).toEqual(prompts.bar);
+
+    expect(mockObj1.getPrompts()).toEqual({
+      bar: mockPrompt,
+      foo: mockPrompt,
+      "mock_object_2:abc": mockPrompt,
+    });
+  });
+
+  it("should update prompts", () => {
+    const mockObj1 = new MockObject1();
+
+    expect(
+      mockObj1.barPrompt({
+        context: "{foo} {bar}",
+        query: "{foo} {bar}",
+      }),
+    ).toEqual(mockPrompt({ context: "{foo} {bar}", query: "{foo} {bar}" }));
+
+    expect(
+      mockObj1.mockObject2._prompt_dict_2({
+        context: "{bar} testing",
+        query: "{bar} testing",
+      }),
+    ).toEqual(mockPrompt({ context: "{bar} testing", query: "{bar} testing" }));
+
+    mockObj1.updatePrompts({
+      bar: mockPrompt2,
+      "mock_object_2:abc": mockPrompt2,
+    });
+
+    expect(
+      mockObj1.barPrompt({
+        context: "{foo} {bar}",
+        query: "{bar} {foo}",
+      }),
+    ).toEqual(mockPrompt2({ context: "{foo} {bar}", query: "{bar} {foo}" }));
+    expect(
+      mockObj1.mockObject2._prompt_dict_2({
+        context: "{bar} testing",
+        query: "{bar} testing",
+      }),
+    ).toEqual(
+      mockPrompt2({ context: "{bar} testing", query: "{bar} testing" }),
+    );
+  });
+});
-- 
GitLab
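---

For readers who want to plug their own modules into the `PromptMixin` API introduced by this patch, the sketch below shows one possible wiring. It is a minimal, hypothetical example — it assumes only the `PromptMixin` export added above (`import { PromptMixin } from "llamaindex"`); the `SummaryPrompt` type and the `Summarizer`/`Pipeline` classes are illustrative names, not part of the library.

```ts
import { PromptMixin } from "llamaindex";

// In LlamaIndexTS, a prompt is a plain function from template variables to a string.
type SummaryPrompt = (args: { context: string }) => string;

const defaultSummaryPrompt: SummaryPrompt = ({ context }) =>
  `Summarize the following text.\n${context}\nSummary:`;

// A leaf module that owns a single prompt and exposes it via the mixin hooks.
class Summarizer extends PromptMixin {
  summaryPrompt: SummaryPrompt = defaultSummaryPrompt;

  protected _getPrompts() {
    return { summary: this.summaryPrompt };
  }

  protected _updatePrompts(prompts: { summary?: SummaryPrompt }): void {
    if (prompts.summary) {
      this.summaryPrompt = prompts.summary;
    }
  }
}

// A parent module with no prompts of its own; it registers the summarizer
// as a prompt sub-module, so its prompts surface through the parent.
class Pipeline extends PromptMixin {
  summarizer = new Summarizer();

  protected _getPromptModules() {
    return { summarizer: this.summarizer };
  }
}

const pipeline = new Pipeline();

// Nested prompts are namespaced as "<module>:<key>" by getPrompts().
console.log(Object.keys(pipeline.getPrompts())); // ["summarizer:summary"]

// updatePrompts() splits on ":" and forwards each entry to the owning sub-module.
pipeline.updatePrompts({
  "summarizer:summary": ({ context }) => `TL;DR: ${context}`,
});
```

This mirrors how `RetrieverQueryEngine` forwards to its `responseSynthesizer` in the patch, and why the docs address nested prompts with keys like `responseSynthesizer:textQATemplate`.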