diff --git a/apps/docs/docs/modules/prompt/_category_.yml b/apps/docs/docs/modules/prompt/_category_.yml
new file mode 100644
index 0000000000000000000000000000000000000000..597aa54df0dbc6bb47f50c02f3621573d55530d1
--- /dev/null
+++ b/apps/docs/docs/modules/prompt/_category_.yml
@@ -0,0 +1,2 @@
+label: "Prompts"
+position: 0
diff --git a/apps/docs/docs/modules/prompt/index.md b/apps/docs/docs/modules/prompt/index.md
new file mode 100644
index 0000000000000000000000000000000000000000..fa93746430263f84393434b1a4e6e0c39027d77c
--- /dev/null
+++ b/apps/docs/docs/modules/prompt/index.md
@@ -0,0 +1,76 @@
+# Prompts
+
+Prompting is the fundamental input that gives LLMs their expressive power. LlamaIndex uses prompts to build the index, do insertion, perform traversal during querying, and to synthesize the final answer.
+
+Users may also provide their own prompt templates to further customize the behavior of the framework. The best way to customize is to copy the default prompt from the source and use it as the base for any modifications.
+
+## Usage Pattern
+
+Currently, there are two ways to customize prompts in LlamaIndex. For both methods, you first need to create a function that overrides the default prompt:
+
+```ts
+// Define a custom prompt
+const newTextQaPrompt: TextQaPrompt = ({ context, query }) => {
+  return `Context information is below.
+---------------------
+${context}
+---------------------
+Given the context information and not prior knowledge, answer the query.
+Answer the query in the style of a Sherlock Holmes detective novel.
+Query: ${query}
+Answer:`;
+};
+```
+
+### 1. Customizing the default prompt on initialization
+
+The first method is to create a new instance of `ResponseSynthesizer` (or of whichever module's prompt you would like to update) and pass the custom prompt to its `responseBuilder` parameter. Then, pass the instance to the `asQueryEngine` method of the index.
+
+```ts
+// Create an instance of response synthesizer
+const responseSynthesizer = new ResponseSynthesizer({
+  responseBuilder: new CompactAndRefine(serviceContext, newTextQaPrompt),
+});
+
+// Create index
+const index = await VectorStoreIndex.fromDocuments([document], {
+  serviceContext,
+});
+
+// Query the index
+const queryEngine = index.asQueryEngine({ responseSynthesizer });
+
+const response = await queryEngine.query({
+  query: "What did the author do in college?",
+});
+```
+
+### 2. Customizing submodule prompts
+
+The second method relies on the fact that most modules in LlamaIndex expose `getPrompts` and `updatePrompts` methods that let you inspect and override the default prompts. This method is useful when you want to change prompts on the fly, or in submodules at a more granular level.
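+Prompt keys returned by `getPrompts` are namespaced by submodule, e.g. `responseSynthesizer:textQATemplate`, and the same keys are passed to `updatePrompts` to override them: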
+
+```ts
+// Create index
+const index = await VectorStoreIndex.fromDocuments([document], {
+  serviceContext,
+});
+
+// Query the index
+const queryEngine = index.asQueryEngine();
+
+// Get a list of prompts for the query engine
+const prompts = queryEngine.getPrompts();
+
+// output: { "responseSynthesizer:textQATemplate": defaultTextQaPrompt, "responseSynthesizer:refineTemplate": defaultRefineTemplatePrompt }
+
+// Now, we can override the default prompt
+queryEngine.updatePrompts({
+  "responseSynthesizer:textQATemplate": newTextQaPrompt,
+});
+
+const response = await queryEngine.query({
+  query: "What did the author do in college?",
+});
+```
diff --git a/examples/prompts/promptMixin.ts b/examples/prompts/promptMixin.ts
new file mode 100644
index 0000000000000000000000000000000000000000..d0b94079670b0be5e39e54b4239984c6c01a1295
--- /dev/null
+++ b/examples/prompts/promptMixin.ts
@@ -0,0 +1,51 @@
+import {
+  Document,
+  ResponseSynthesizer,
+  TreeSummarize,
+  TreeSummarizePrompt,
+  VectorStoreIndex,
+  serviceContextFromDefaults,
+} from "llamaindex";
+
+const treeSummarizePrompt: TreeSummarizePrompt = ({ context, query }) => {
+  return `Context information from multiple sources is below.
+---------------------
+${context}
+---------------------
+Given the information from multiple sources and not prior knowledge, answer the query.
+Answer the query in the style of a Shakespeare play.
+Query: ${query}
+Answer:`;
+};
+
+async function main() {
+  const document = new Document({
+    text: "The quick brown fox jumps over the lazy dog",
+  });
+
+  const index = await VectorStoreIndex.fromDocuments([document]);
+
+  const query = "The quick brown fox jumps over the lazy dog";
+
+  const ctx = serviceContextFromDefaults({});
+
+  const responseSynthesizer = new ResponseSynthesizer({
+    responseBuilder: new TreeSummarize(ctx),
+  });
+
+  const queryEngine = index.asQueryEngine({
+    responseSynthesizer,
+  });
+
+  console.log({
+    promptsToUse: queryEngine.getPrompts(),
+  });
+
+  // Override the tree-summarize prompt under its namespaced key
+  queryEngine.updatePrompts({
+    "responseSynthesizer:summaryTemplate": treeSummarizePrompt,
+  });
+
+  const response = await queryEngine.query({ query });
+
+  console.log(response.response);
+}
+
+main().catch(console.error);
diff --git a/packages/core/src/QuestionGenerator.ts b/packages/core/src/QuestionGenerator.ts
index ae4c0feb0e94ebbe26ba03a50fcfce3cc398af01..5dc1fa9ae34257bb661ba754d08f9bf289685f80 100644
--- a/packages/core/src/QuestionGenerator.ts
+++ b/packages/core/src/QuestionGenerator.ts
@@ -7,22 +7,42 @@ import {
 import { BaseQuestionGenerator, SubQuestion } from "./engines/query/types";
 import { OpenAI } from "./llm/LLM";
 import { LLM } from "./llm/types";
+import { PromptMixin } from "./prompts";
 import { BaseOutputParser, StructuredOutput, ToolMetadata } from "./types";

 /**
  * LLMQuestionGenerator uses the LLM to generate new questions for the LLM using tools and a user query.
  */
-export class LLMQuestionGenerator implements BaseQuestionGenerator {
+export class LLMQuestionGenerator
+  extends PromptMixin
+  implements BaseQuestionGenerator
+{
   llm: LLM;
   prompt: SubQuestionPrompt;
   outputParser: BaseOutputParser<StructuredOutput<SubQuestion[]>>;

   constructor(init?: Partial<LLMQuestionGenerator>) {
+    super();
+
     this.llm = init?.llm ?? new OpenAI();
     this.prompt = init?.prompt ?? defaultSubQuestionPrompt;
     this.outputParser =
       init?.outputParser ??
new SubQuestionOutputParser(); } + protected _getPrompts(): { [x: string]: SubQuestionPrompt } { + return { + subQuestion: this.prompt, + }; + } + + protected _updatePrompts(promptsDict: { + subQuestion: SubQuestionPrompt; + }): void { + if ("subQuestion" in promptsDict) { + this.prompt = promptsDict.subQuestion; + } + } + async generate(tools: ToolMetadata[], query: string): Promise<SubQuestion[]> { const toolsStr = buildToolsText(tools); const queryStr = query; diff --git a/packages/core/src/engines/chat/CondenseQuestionChatEngine.ts b/packages/core/src/engines/chat/CondenseQuestionChatEngine.ts index 6d8e93b491c6ab692391539f4b12e32229b06fbc..0c73a1017f3b6fb9ddd69c78fe1f60b5faca06d9 100644 --- a/packages/core/src/engines/chat/CondenseQuestionChatEngine.ts +++ b/packages/core/src/engines/chat/CondenseQuestionChatEngine.ts @@ -11,6 +11,7 @@ import { } from "../../ServiceContext"; import { ChatMessage, LLM } from "../../llm"; import { extractText, streamReducer } from "../../llm/utils"; +import { PromptMixin } from "../../prompts"; import { BaseQueryEngine } from "../../types"; import { ChatEngine, @@ -29,7 +30,10 @@ import { * data, or are very referential to previous context. */ -export class CondenseQuestionChatEngine implements ChatEngine { +export class CondenseQuestionChatEngine + extends PromptMixin + implements ChatEngine +{ queryEngine: BaseQueryEngine; chatHistory: ChatHistory; llm: LLM; @@ -41,6 +45,8 @@ export class CondenseQuestionChatEngine implements ChatEngine { serviceContext?: ServiceContext; condenseMessagePrompt?: CondenseQuestionPrompt; }) { + super(); + this.queryEngine = init.queryEngine; this.chatHistory = getHistory(init?.chatHistory); this.llm = init?.serviceContext?.llm ?? serviceContextFromDefaults().llm; @@ -48,13 +54,27 @@ export class CondenseQuestionChatEngine implements ChatEngine { init?.condenseMessagePrompt ?? defaultCondenseQuestionPrompt; } + protected _getPrompts(): { condenseMessagePrompt: CondenseQuestionPrompt } { + return { + condenseMessagePrompt: this.condenseMessagePrompt, + }; + } + + protected _updatePrompts(promptsDict: { + condenseMessagePrompt: CondenseQuestionPrompt; + }): void { + if (promptsDict.condenseMessagePrompt) { + this.condenseMessagePrompt = promptsDict.condenseMessagePrompt; + } + } + private async condenseQuestion(chatHistory: ChatHistory, question: string) { const chatHistoryStr = messagesToHistoryStr( await chatHistory.requestMessages(), ); return this.llm.complete({ - prompt: defaultCondenseQuestionPrompt({ + prompt: this.condenseMessagePrompt({ question: question, chatHistory: chatHistoryStr, }), diff --git a/packages/core/src/engines/chat/ContextChatEngine.ts b/packages/core/src/engines/chat/ContextChatEngine.ts index 43cdd2af257e5e7d92fef98e3f20e86f16cc8ff1..c56ce0c50e35db7974ed8e2b47f0b3b55e4961e1 100644 --- a/packages/core/src/engines/chat/ContextChatEngine.ts +++ b/packages/core/src/engines/chat/ContextChatEngine.ts @@ -8,6 +8,7 @@ import { ChatMessage, ChatResponseChunk, LLM, OpenAI } from "../../llm"; import { MessageContent } from "../../llm/types"; import { extractText, streamConverter, streamReducer } from "../../llm/utils"; import { BaseNodePostprocessor } from "../../postprocessors"; +import { PromptMixin } from "../../prompts"; import { DefaultContextGenerator } from "./DefaultContextGenerator"; import { ChatEngine, @@ -21,7 +22,7 @@ import { * The context is stored in the system prompt, and the chat history is preserved, * ideally allowing the appropriate context to be surfaced for each query. 
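 * As the engine extends PromptMixin, the context generator's prompt can be inspected via `getPrompts()` and overridden via `updatePrompts()` under the `contextGenerator:contextSystemPrompt` key.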
*/ -export class ContextChatEngine implements ChatEngine { +export class ContextChatEngine extends PromptMixin implements ChatEngine { chatModel: LLM; chatHistory: ChatHistory; contextGenerator: ContextGenerator; @@ -33,6 +34,8 @@ export class ContextChatEngine implements ChatEngine { contextSystemPrompt?: ContextSystemPrompt; nodePostprocessors?: BaseNodePostprocessor[]; }) { + super(); + this.chatModel = init.chatModel ?? new OpenAI({ model: "gpt-3.5-turbo-16k" }); this.chatHistory = getHistory(init?.chatHistory); @@ -43,6 +46,12 @@ export class ContextChatEngine implements ChatEngine { }); } + protected _getPromptModules(): Record<string, ContextGenerator> { + return { + contextGenerator: this.contextGenerator, + }; + } + chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>; chat(params: ChatEngineParamsNonStreaming): Promise<Response>; async chat( diff --git a/packages/core/src/engines/chat/DefaultContextGenerator.ts b/packages/core/src/engines/chat/DefaultContextGenerator.ts index 11566b85dd6b896f0a4c6c40e7efd97aab1212a8..4b359ac0fea4de9a4c48e8284ecaab9fdbae0881 100644 --- a/packages/core/src/engines/chat/DefaultContextGenerator.ts +++ b/packages/core/src/engines/chat/DefaultContextGenerator.ts @@ -4,9 +4,13 @@ import { BaseRetriever } from "../../Retriever"; import { Event } from "../../callbacks/CallbackManager"; import { randomUUID } from "../../env"; import { BaseNodePostprocessor } from "../../postprocessors"; +import { PromptMixin } from "../../prompts"; import { Context, ContextGenerator } from "./types"; -export class DefaultContextGenerator implements ContextGenerator { +export class DefaultContextGenerator + extends PromptMixin + implements ContextGenerator +{ retriever: BaseRetriever; contextSystemPrompt: ContextSystemPrompt; nodePostprocessors: BaseNodePostprocessor[]; @@ -16,12 +20,28 @@ export class DefaultContextGenerator implements ContextGenerator { contextSystemPrompt?: ContextSystemPrompt; nodePostprocessors?: BaseNodePostprocessor[]; }) { + super(); + this.retriever = init.retriever; this.contextSystemPrompt = init?.contextSystemPrompt ?? defaultContextSystemPrompt; this.nodePostprocessors = init.nodePostprocessors || []; } + protected _getPrompts(): { contextSystemPrompt: ContextSystemPrompt } { + return { + contextSystemPrompt: this.contextSystemPrompt, + }; + } + + protected _updatePrompts(promptsDict: { + contextSystemPrompt: ContextSystemPrompt; + }): void { + if (promptsDict.contextSystemPrompt) { + this.contextSystemPrompt = promptsDict.contextSystemPrompt; + } + } + private async applyNodePostprocessors(nodes: NodeWithScore[], query: string) { let nodesWithScore = nodes; diff --git a/packages/core/src/engines/query/RetrieverQueryEngine.ts b/packages/core/src/engines/query/RetrieverQueryEngine.ts index cbec0f7e3d7cae19ecfca811e173419f79eb5c9e..94c5b86c452c8f483d0cdd3d8d3583c3ac732da6 100644 --- a/packages/core/src/engines/query/RetrieverQueryEngine.ts +++ b/packages/core/src/engines/query/RetrieverQueryEngine.ts @@ -5,6 +5,7 @@ import { ServiceContext } from "../../ServiceContext"; import { Event } from "../../callbacks/CallbackManager"; import { randomUUID } from "../../env"; import { BaseNodePostprocessor } from "../../postprocessors"; +import { PromptMixin } from "../../prompts"; import { BaseSynthesizer, ResponseSynthesizer } from "../../synthesizers"; import { BaseQueryEngine, @@ -15,7 +16,10 @@ import { /** * A query engine that uses a retriever to query an index and then synthesizes the response. 
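 * It extends PromptMixin, so the prompts of the wrapped response synthesizer are exposed under `responseSynthesizer:`-prefixed keys (e.g. `responseSynthesizer:textQATemplate`).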
*/ -export class RetrieverQueryEngine implements BaseQueryEngine { +export class RetrieverQueryEngine + extends PromptMixin + implements BaseQueryEngine +{ retriever: BaseRetriever; responseSynthesizer: BaseSynthesizer; nodePostprocessors: BaseNodePostprocessor[]; @@ -27,6 +31,8 @@ export class RetrieverQueryEngine implements BaseQueryEngine { preFilters?: unknown, nodePostprocessors?: BaseNodePostprocessor[], ) { + super(); + this.retriever = retriever; const serviceContext: ServiceContext | undefined = this.retriever.getServiceContext(); @@ -36,6 +42,12 @@ export class RetrieverQueryEngine implements BaseQueryEngine { this.nodePostprocessors = nodePostprocessors || []; } + _getPromptModules() { + return { + responseSynthesizer: this.responseSynthesizer, + }; + } + private async applyNodePostprocessors(nodes: NodeWithScore[], query: string) { let nodesWithScore = nodes; diff --git a/packages/core/src/engines/query/RouterQueryEngine.ts b/packages/core/src/engines/query/RouterQueryEngine.ts index babe01d14729e057e2e40948002ef657498dcfe3..d48a23cf4cc0bbe7fe32a46334b5732c69f9bdf7 100644 --- a/packages/core/src/engines/query/RouterQueryEngine.ts +++ b/packages/core/src/engines/query/RouterQueryEngine.ts @@ -1,8 +1,10 @@ +import { BaseNode } from "../../Node"; import { Response } from "../../Response"; import { ServiceContext, serviceContextFromDefaults, } from "../../ServiceContext"; +import { PromptMixin } from "../../prompts"; import { BaseSelector, LLMSingleSelector } from "../../selectors"; import { TreeSummarize } from "../../synthesizers"; import { @@ -31,8 +33,8 @@ async function combineResponses( console.log("Combining responses from multiple query engines."); } - const responseStrs = []; - const sourceNodes = []; + const responseStrs: string[] = []; + const sourceNodes: BaseNode[] = []; for (const response of responses) { if (response?.sourceNodes) { @@ -53,7 +55,7 @@ async function combineResponses( /** * A query engine that uses multiple query engines and selects the best one. */ -export class RouterQueryEngine implements BaseQueryEngine { +export class RouterQueryEngine extends PromptMixin implements BaseQueryEngine { serviceContext: ServiceContext; private selector: BaseSelector; @@ -69,6 +71,8 @@ export class RouterQueryEngine implements BaseQueryEngine { summarizer?: TreeSummarize; verbose?: boolean; }) { + super(); + this.serviceContext = init.serviceContext || serviceContextFromDefaults({}); this.selector = init.selector; this.queryEngines = init.queryEngineTools.map((tool) => tool.queryEngine); @@ -79,6 +83,13 @@ export class RouterQueryEngine implements BaseQueryEngine { this.verbose = init.verbose ?? 
false;
  }

+  _getPromptModules(): Record<string, any> {
+    return {
+      selector: this.selector,
+      summarizer: this.summarizer,
+    };
+  }
+
   static fromDefaults(init: {
     queryEngineTools: RouterQueryEngineTool[];
     selector?: BaseSelector;
@@ -119,7 +130,7 @@
     const result = await this.selector.select(this.metadatas, queryBundle);

     if (result.selections.length > 1) {
-      const responses = [];
+      const responses: Response[] = [];
       for (let i = 0; i < result.selections.length; i++) {
         const engineInd = result.selections[i];
         const logStr = `Selecting query engine ${engineInd}: ${result.selections[i]}.`;
diff --git a/packages/core/src/engines/query/SubQuestionQueryEngine.ts b/packages/core/src/engines/query/SubQuestionQueryEngine.ts
index 73d0ada6c52d0cd3edc19698f7d36ae867139e0a..0139ecd0a255ad240b24f31b4c55bdf60c98be48 100644
--- a/packages/core/src/engines/query/SubQuestionQueryEngine.ts
+++ b/packages/core/src/engines/query/SubQuestionQueryEngine.ts
@@ -7,6 +7,7 @@ import {
 } from "../../ServiceContext";
 import { Event } from "../../callbacks/CallbackManager";
 import { randomUUID } from "../../env";
+import { PromptMixin } from "../../prompts";
 import {
   BaseSynthesizer,
   CompactAndRefine,
@@ -26,7 +27,10 @@ import { BaseQuestionGenerator, SubQuestion } from "./types";

 /**
 * SubQuestionQueryEngine decomposes a question into subquestions, answers each subquestion with its corresponding query engine tool, and synthesizes the sub-answers into a final response.
 */
-export class SubQuestionQueryEngine implements BaseQueryEngine {
+export class SubQuestionQueryEngine
+  extends PromptMixin
+  implements BaseQueryEngine
+{
   responseSynthesizer: BaseSynthesizer;
   questionGen: BaseQuestionGenerator;
   queryEngines: BaseTool[];
@@ -37,6 +41,8 @@
     responseSynthesizer: BaseSynthesizer;
     queryEngineTools: BaseTool[];
   }) {
+    super();
+
     this.questionGen = init.questionGen;
     this.responseSynthesizer =
       init.responseSynthesizer ?? new ResponseSynthesizer();
     this.queryEngines = init.queryEngineTools.map((tool) => tool.queryEngine);
     this.metadatas = init.queryEngineTools.map((tool) => tool.metadata);
   }

+  protected _getPromptModules(): Record<string, any> {
+    return {
+      questionGen: this.questionGen,
+      responseSynthesizer: this.responseSynthesizer,
+    };
+  }
+
   static fromDefaults(init: {
     queryEngineTools: BaseTool[];
     questionGen?: BaseQuestionGenerator;
diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts
index f2449885b585d95ef189f61b090a6110e1d35168..0af9e53cb8764307e5596bbc5d4910de6432f219 100644
--- a/packages/core/src/index.ts
+++ b/packages/core/src/index.ts
@@ -22,6 +22,7 @@ export * from "./llm";
 export * from "./nodeParsers";
 export * from "./objects";
 export * from "./postprocessors";
+export * from "./prompts";
 export * from "./readers";
 export * from "./selectors";
 export * from "./storage";
diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
index 9515bab4471d35eb438c194d99ddcbf03a89cfb2..137d1073a8f0d857abda8610d4d581a90cc48c66 100644
--- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
+++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
@@ -270,7 +270,7 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
     responseSynthesizer?: BaseSynthesizer;
     preFilters?: MetadataFilters;
     nodePostprocessors?: BaseNodePostprocessor[];
-  }): BaseQueryEngine {
+  }): BaseQueryEngine & RetrieverQueryEngine {
    const { retriever, responseSynthesizer } = options ??
{};
    return new RetrieverQueryEngine(
      retriever ?? this.asRetriever(),
diff --git a/packages/core/src/prompts/Mixin.ts b/packages/core/src/prompts/Mixin.ts
new file mode 100644
index 0000000000000000000000000000000000000000..eac847b916f8f8cb1a2846eeb5e1b3ec54106e2e
--- /dev/null
+++ b/packages/core/src/prompts/Mixin.ts
@@ -0,0 +1,89 @@
+type PromptsDict = Record<string, any>;
+type ModuleDict = Record<string, any>;
+
+export class PromptMixin {
+  /**
+   * Validates that neither prompt keys nor module keys contain ":",
+   * which is reserved as the module/prompt key separator.
+   * @param promptsDict
+   * @param moduleDict
+   */
+  validatePrompts(promptsDict: PromptsDict, moduleDict: ModuleDict): void {
+    for (const key in promptsDict) {
+      if (key.includes(":")) {
+        throw new Error(`Prompt key ${key} cannot contain ':'.`);
+      }
+    }
+
+    for (const key in moduleDict) {
+      if (key.includes(":")) {
+        throw new Error(`Module key ${key} cannot contain ':'.`);
+      }
+    }
+  }
+
+  /**
+   * Returns all prompts from the mixin and its modules,
+   * with module prompts namespaced as `moduleName:promptKey`.
+   */
+  getPrompts(): PromptsDict {
+    const promptsDict: PromptsDict = this._getPrompts();
+
+    const moduleDict = this._getPromptModules();
+
+    this.validatePrompts(promptsDict, moduleDict);
+
+    const allPrompts: PromptsDict = { ...promptsDict };
+
+    for (const [moduleName, promptModule] of Object.entries(moduleDict)) {
+      for (const [key, prompt] of Object.entries(promptModule.getPrompts())) {
+        allPrompts[`${moduleName}:${key}`] = prompt;
+      }
+    }
+    return allPrompts;
+  }
+
+  /**
+   * Updates the prompts in the mixin and, for namespaced keys, in its modules.
+   * @param promptsDict
+   */
+  updatePrompts(promptsDict: PromptsDict): void {
+    const promptModules = this._getPromptModules();
+
+    this._updatePrompts(promptsDict);
+
+    const subPromptDicts: Record<string, PromptsDict> = {};
+
+    for (const key in promptsDict) {
+      if (key.includes(":")) {
+        const [moduleName, subKey] = key.split(":");
+
+        if (!subPromptDicts[moduleName]) {
+          subPromptDicts[moduleName] = {};
+        }
+        subPromptDicts[moduleName][subKey] = promptsDict[key];
+      }
+    }
+
+    for (const [moduleName, subPromptDict] of Object.entries(subPromptDicts)) {
+      if (!promptModules[moduleName]) {
+        throw new Error(`Module ${moduleName} not found.`);
+      }
+
+      const moduleToUpdate = promptModules[moduleName];
+
+      moduleToUpdate.updatePrompts(subPromptDict);
+    }
+  }
+
+  // Subclasses override these hooks to expose their prompts and prompt modules
+  protected _getPrompts(): PromptsDict {
+    return {};
+  }
+
+  protected _getPromptModules(): Record<string, any> {
+    return {};
+  }
+
+  protected _updatePrompts(promptsDict: PromptsDict): void {
+    return;
+  }
+}
diff --git a/packages/core/src/prompts/index.ts b/packages/core/src/prompts/index.ts
new file mode 100644
index 0000000000000000000000000000000000000000..fb7823314cd60123bade0d4203083942dc1be078
--- /dev/null
+++ b/packages/core/src/prompts/index.ts
@@ -0,0 +1 @@
+export * from "./Mixin";
diff --git a/packages/core/src/selectors/base.ts b/packages/core/src/selectors/base.ts
index 74bd89e372f1afb07cd095d9637bbcfa3ad0ab19..3268d55e1f0e61f6271732712a9cab3cedc7e065 100644
--- a/packages/core/src/selectors/base.ts
+++ b/packages/core/src/selectors/base.ts
@@ -1,3 +1,4 @@
+import { PromptMixin } from "../prompts";
 import { QueryBundle, ToolMetadataOnlyDescription } from "../types";

 export interface SingleSelection {
@@ -31,7 +32,7 @@ function wrapQuery(query: QueryType): QueryBundle {

 type MetadataType = string | ToolMetadataOnlyDescription;

-export abstract class BaseSelector {
+export abstract class BaseSelector extends PromptMixin {
   async select(choices: MetadataType[], query: QueryType) {
     const metadatas = choices.map((choice) =>
wrapChoice(choice)); const queryBundle = wrapQuery(query); diff --git a/packages/core/src/selectors/llmSelectors.ts b/packages/core/src/selectors/llmSelectors.ts index 74acdc87f8c8065913fe889759ff9ab95990a64f..c1a51c6e168bc6ade31bfb22ef0b33ae840f329d 100644 --- a/packages/core/src/selectors/llmSelectors.ts +++ b/packages/core/src/selectors/llmSelectors.ts @@ -1,4 +1,3 @@ -import { DefaultPromptTemplate } from "../extractors/prompts"; import { LLM } from "../llm"; import { Answer, SelectionOutputParser } from "../outputParsers/selectors"; import { @@ -8,7 +7,12 @@ import { ToolMetadataOnlyDescription, } from "../types"; import { BaseSelector, SelectorResult } from "./base"; -import { defaultSingleSelectPrompt } from "./prompts"; +import { + MultiSelectPrompt, + SingleSelectPrompt, + defaultMultiSelectPrompt, + defaultSingleSelectPrompt, +} from "./prompts"; function buildChoicesText(choices: ToolMetadataOnlyDescription[]): string { const texts: string[] = []; @@ -20,7 +24,7 @@ function buildChoicesText(choices: ToolMetadataOnlyDescription[]): string { return texts.join(""); } -function _structuredOutputToSelectorResult( +function structuredOutputToSelectorResult( output: StructuredOutput<Answer[]>, ): SelectorResult { const structuredOutput = output; @@ -40,32 +44,32 @@ type LLMPredictorType = LLM; * A selector that uses the LLM to select a single or multiple choices from a list of choices. */ export class LLMMultiSelector extends BaseSelector { - _llm: LLMPredictorType; - _prompt: DefaultPromptTemplate | undefined; - _maxOutputs: number | null; - _outputParser: BaseOutputParser<any> | null; + llm: LLMPredictorType; + prompt: MultiSelectPrompt; + maxOutputs: number; + outputParser: BaseOutputParser<StructuredOutput<Answer[]>> | null; constructor(init: { llm: LLMPredictorType; - prompt?: DefaultPromptTemplate; + prompt?: MultiSelectPrompt; maxOutputs?: number; - outputParser?: BaseOutputParser<any>; + outputParser?: BaseOutputParser<StructuredOutput<Answer[]>>; }) { super(); - this._llm = init.llm; - this._prompt = init.prompt; - this._maxOutputs = init.maxOutputs ?? null; + this.llm = init.llm; + this.prompt = init.prompt ?? defaultMultiSelectPrompt; + this.maxOutputs = init.maxOutputs ?? 10; - this._outputParser = init.outputParser ?? new SelectionOutputParser(); + this.outputParser = init.outputParser ?? new SelectionOutputParser(); } - _getPrompts(): Record<string, any> { - return { prompt: this._prompt }; + _getPrompts(): Record<string, MultiSelectPrompt> { + return { prompt: this.prompt }; } - _updatePrompts(prompts: Record<string, any>): void { + _updatePrompts(prompts: Record<string, MultiSelectPrompt>): void { if ("prompt" in prompts) { - this._prompt = prompts.prompt; + this.prompt = prompts.prompt; } } @@ -80,22 +84,26 @@ export class LLMMultiSelector extends BaseSelector { ): Promise<SelectorResult> { const choicesText = buildChoicesText(choices); - const prompt = - this._prompt?.contextStr ?? 
- defaultSingleSelectPrompt( - choicesText.length, - choicesText, - query.queryStr, - ); - const formattedPrompt = this._outputParser?.format(prompt); + const prompt = this.prompt( + choicesText.length, + choicesText, + query.queryStr, + this.maxOutputs, + ); + + const formattedPrompt = this.outputParser?.format(prompt); - const prediction = await this._llm.complete({ + const prediction = await this.llm.complete({ prompt: formattedPrompt, }); - const parsed = this._outputParser?.parse(prediction.text); + const parsed = this.outputParser?.parse(prediction.text); + + if (!parsed) { + throw new Error("Parsed output is undefined"); + } - return _structuredOutputToSelectorResult(parsed); + return structuredOutputToSelectorResult(parsed); } asQueryComponent(): unknown { @@ -107,28 +115,28 @@ export class LLMMultiSelector extends BaseSelector { * A selector that uses the LLM to select a single choice from a list of choices. */ export class LLMSingleSelector extends BaseSelector { - _llm: LLMPredictorType; - _prompt: DefaultPromptTemplate | undefined; - _outputParser: BaseOutputParser<any> | null; + llm: LLMPredictorType; + prompt: SingleSelectPrompt; + outputParser: BaseOutputParser<StructuredOutput<Answer[]>> | null; constructor(init: { llm: LLMPredictorType; - prompt?: DefaultPromptTemplate; - outputParser?: BaseOutputParser<any>; + prompt?: SingleSelectPrompt; + outputParser?: BaseOutputParser<StructuredOutput<Answer[]>>; }) { super(); - this._llm = init.llm; - this._prompt = init.prompt; - this._outputParser = init.outputParser ?? new SelectionOutputParser(); + this.llm = init.llm; + this.prompt = init.prompt ?? defaultSingleSelectPrompt; + this.outputParser = init.outputParser ?? new SelectionOutputParser(); } - _getPrompts(): Record<string, any> { - return { prompt: this._prompt }; + _getPrompts(): Record<string, SingleSelectPrompt> { + return { prompt: this.prompt }; } - _updatePrompts(prompts: Record<string, any>): void { + _updatePrompts(prompts: Record<string, SingleSelectPrompt>): void { if ("prompt" in prompts) { - this._prompt = prompts.prompt; + this.prompt = prompts.prompt; } } @@ -143,23 +151,21 @@ export class LLMSingleSelector extends BaseSelector { ): Promise<SelectorResult> { const choicesText = buildChoicesText(choices); - const prompt = - this._prompt?.contextStr ?? 
- defaultSingleSelectPrompt( - choicesText.length, - choicesText, - query.queryStr, - ); + const prompt = this.prompt(choicesText.length, choicesText, query.queryStr); - const formattedPrompt = this._outputParser?.format(prompt); + const formattedPrompt = this.outputParser?.format(prompt); - const prediction = await this._llm.complete({ + const prediction = await this.llm.complete({ prompt: formattedPrompt, }); - const parsed = this._outputParser?.parse(prediction.text); + const parsed = this.outputParser?.parse(prediction.text); + + if (!parsed) { + throw new Error("Parsed output is undefined"); + } - return _structuredOutputToSelectorResult(parsed); + return structuredOutputToSelectorResult(parsed); } asQueryComponent(): unknown { diff --git a/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts b/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts index 307138350605cb30b2180921138ed4752c317ab7..a6a442a9d30a69a79c35f9c5f7c052d2a23c11fe 100644 --- a/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts +++ b/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts @@ -3,6 +3,7 @@ import { Response } from "../Response"; import { ServiceContext, serviceContextFromDefaults } from "../ServiceContext"; import { imageToDataUrl } from "../embeddings"; import { MessageContentDetail } from "../llm/types"; +import { PromptMixin } from "../prompts"; import { TextQaPrompt, defaultTextQaPrompt } from "./../Prompt"; import { BaseSynthesizer, @@ -10,7 +11,10 @@ import { SynthesizeParamsStreaming, } from "./types"; -export class MultiModalResponseSynthesizer implements BaseSynthesizer { +export class MultiModalResponseSynthesizer + extends PromptMixin + implements BaseSynthesizer +{ serviceContext: ServiceContext; metadataMode: MetadataMode; textQATemplate: TextQaPrompt; @@ -20,11 +24,27 @@ export class MultiModalResponseSynthesizer implements BaseSynthesizer { textQATemplate, metadataMode, }: Partial<MultiModalResponseSynthesizer> = {}) { + super(); + this.serviceContext = serviceContext ?? serviceContextFromDefaults(); this.metadataMode = metadataMode ?? MetadataMode.NONE; this.textQATemplate = textQATemplate ?? defaultTextQaPrompt; } + protected _getPrompts(): { textQATemplate: TextQaPrompt } { + return { + textQATemplate: this.textQATemplate, + }; + } + + protected _updatePrompts(promptsDict: { + textQATemplate: TextQaPrompt; + }): void { + if (promptsDict.textQATemplate) { + this.textQATemplate = promptsDict.textQATemplate; + } + } + synthesize( params: SynthesizeParamsStreaming, ): Promise<AsyncIterable<Response>>; diff --git a/packages/core/src/synthesizers/ResponseSynthesizer.ts b/packages/core/src/synthesizers/ResponseSynthesizer.ts index 61ed392fe85c7a7ea3c2682a41b8722a22a656b5..7b28677b484a818f2455766af22331cee9188d2e 100644 --- a/packages/core/src/synthesizers/ResponseSynthesizer.ts +++ b/packages/core/src/synthesizers/ResponseSynthesizer.ts @@ -2,7 +2,8 @@ import { MetadataMode } from "../Node"; import { Response } from "../Response"; import { ServiceContext, serviceContextFromDefaults } from "../ServiceContext"; import { streamConverter } from "../llm/utils"; -import { getResponseBuilder } from "./builders"; +import { PromptMixin } from "../prompts"; +import { ResponseBuilderPrompts, getResponseBuilder } from "./builders"; import { BaseSynthesizer, ResponseBuilder, @@ -13,7 +14,10 @@ import { /** * A ResponseSynthesizer is used to generate a response from a query and a list of nodes. 
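 * As a PromptMixin, it forwards its response builder's prompts (e.g. `textQATemplate`, `refineTemplate`, or `summaryTemplate`) through `getPrompts`/`updatePrompts`.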
 */
-export class ResponseSynthesizer implements BaseSynthesizer {
+export class ResponseSynthesizer
+  extends PromptMixin
+  implements BaseSynthesizer
+{
   responseBuilder: ResponseBuilder;
   serviceContext: ServiceContext;
   metadataMode: MetadataMode;
@@ -27,12 +31,31 @@
     serviceContext?: ServiceContext;
     metadataMode?: MetadataMode;
   } = {}) {
+    super();
+
     this.serviceContext = serviceContext ?? serviceContextFromDefaults();
     this.responseBuilder =
       responseBuilder ?? getResponseBuilder(this.serviceContext);
     this.metadataMode = metadataMode;
   }

+  _getPromptModules() {
+    return {};
+  }
+
+  protected _getPrompts(): { [x: string]: ResponseBuilderPrompts } {
+    const prompts = this.responseBuilder.getPrompts?.();
+    return {
+      ...prompts,
+    };
+  }
+
+  protected _updatePrompts(promptsDict: {
+    [x: string]: ResponseBuilderPrompts;
+  }): void {
+    this.responseBuilder.updatePrompts?.(promptsDict);
+  }
+
   synthesize(
     params: SynthesizeParamsStreaming,
   ): Promise<AsyncIterable<Response>>;
diff --git a/packages/core/src/synthesizers/builders.ts b/packages/core/src/synthesizers/builders.ts
index 67f1173b2dff982251ceab7826678cfd49ff33af..edec2cfdb0351fcb6cafa93c38df502d9684c5db 100644
--- a/packages/core/src/synthesizers/builders.ts
+++ b/packages/core/src/synthesizers/builders.ts
@@ -11,6 +11,7 @@ import {
   TreeSummarizePrompt,
 } from "../Prompt";
 import { getBiggestPrompt, PromptHelper } from "../PromptHelper";
+import { PromptMixin } from "../prompts";
 import { ServiceContext } from "../ServiceContext";
 import {
   ResponseBuilder,
@@ -73,7 +74,7 @@
 /**
  * A response builder that uses the query to ask the LLM to generate a better response using multiple text chunks.
  */
-export class Refine implements ResponseBuilder {
+export class Refine extends PromptMixin implements ResponseBuilder {
   llm: LLM;
   promptHelper: PromptHelper;
   textQATemplate: TextQaPrompt;
@@ -84,12 +85,37 @@
     textQATemplate?: TextQaPrompt,
     refineTemplate?: RefinePrompt,
   ) {
+    super();
+
     this.llm = serviceContext.llm;
     this.promptHelper = serviceContext.promptHelper;
     this.textQATemplate = textQATemplate ?? defaultTextQaPrompt;
     this.refineTemplate = refineTemplate ?? defaultRefinePrompt;
   }

+  protected _getPrompts(): {
+    textQATemplate: TextQaPrompt;
+    refineTemplate: RefinePrompt;
+  } {
+    return {
+      textQATemplate: this.textQATemplate,
+      refineTemplate: this.refineTemplate,
+    };
+  }
+
+  protected _updatePrompts(prompts: {
+    textQATemplate: TextQaPrompt;
+    refineTemplate: RefinePrompt;
+  }): void {
+    if (prompts.textQATemplate) {
+      this.textQATemplate = prompts.textQATemplate;
+    }
+
+    if (prompts.refineTemplate) {
+      this.refineTemplate = prompts.refineTemplate;
+    }
+  }
+
   getResponse(
     params: ResponseBuilderParamsStreaming,
   ): Promise<AsyncIterable<string>>;
@@ -258,7 +284,7 @@ export class CompactAndRefine extends Refine {
 /**
 * TreeSummarize repacks the text chunks into the smallest possible number of chunks, summarizes them, and repeats recursively until a single chunk remains.
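 * The summary prompt can be inspected and overridden through the PromptMixin interface under the `summaryTemplate` key.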
*/ -export class TreeSummarize implements ResponseBuilder { +export class TreeSummarize extends PromptMixin implements ResponseBuilder { llm: LLM; promptHelper: PromptHelper; summaryTemplate: TreeSummarizePrompt; @@ -267,11 +293,27 @@ export class TreeSummarize implements ResponseBuilder { serviceContext: ServiceContext, summaryTemplate?: TreeSummarizePrompt, ) { + super(); + this.llm = serviceContext.llm; this.promptHelper = serviceContext.promptHelper; this.summaryTemplate = summaryTemplate ?? defaultTreeSummarizePrompt; } + protected _getPrompts(): { summaryTemplate: TreeSummarizePrompt } { + return { + summaryTemplate: this.summaryTemplate, + }; + } + + protected _updatePrompts(prompts: { + summaryTemplate: TreeSummarizePrompt; + }): void { + if (prompts.summaryTemplate) { + this.summaryTemplate = prompts.summaryTemplate; + } + } + getResponse( params: ResponseBuilderParamsStreaming, ): Promise<AsyncIterable<string>>; @@ -352,3 +394,8 @@ export function getResponseBuilder( return new CompactAndRefine(serviceContext); } } + +export type ResponseBuilderPrompts = + | TextQaPrompt + | TreeSummarizePrompt + | RefinePrompt; diff --git a/packages/core/src/synthesizers/types.ts b/packages/core/src/synthesizers/types.ts index a188a6e7657d87ca8fc44a4212d6e840cf3328bc..30798283649ea98e7a8a214f844253ed0b21daa8 100644 --- a/packages/core/src/synthesizers/types.ts +++ b/packages/core/src/synthesizers/types.ts @@ -1,5 +1,6 @@ import { Event } from "../callbacks/CallbackManager"; import { NodeWithScore } from "../Node"; +import { PromptMixin } from "../prompts"; import { Response } from "../Response"; export interface SynthesizeParamsBase { @@ -46,7 +47,7 @@ export interface ResponseBuilderParamsNonStreaming /** * A ResponseBuilder is used in a response synthesizer to generate a response from multiple response chunks. */ -export interface ResponseBuilder { +export interface ResponseBuilder extends Partial<PromptMixin> { /** * Get the response from a query and a list of text chunks. 
* @param params diff --git a/packages/core/src/tests/prompts/Mixin.test.ts b/packages/core/src/tests/prompts/Mixin.test.ts new file mode 100644 index 0000000000000000000000000000000000000000..ba9293ce619695f1accab8629776bf7a92c8cbd6 --- /dev/null +++ b/packages/core/src/tests/prompts/Mixin.test.ts @@ -0,0 +1,138 @@ +import { PromptMixin } from "../../prompts"; + +type MockPrompt = { + context: string; + query: string; +}; + +const mockPrompt = ({ context, query }: MockPrompt) => + `context: ${context} query: ${query}`; + +const mockPrompt2 = ({ context, query }: MockPrompt) => + `query: ${query} context: ${context}`; + +type MockPromptFunction = typeof mockPrompt; + +class MockObject2 extends PromptMixin { + _prompt_dict_2: MockPromptFunction; + + constructor() { + super(); + this._prompt_dict_2 = mockPrompt; + } + + protected _getPrompts(): { [x: string]: MockPromptFunction } { + return { + abc: this._prompt_dict_2, + }; + } + + _updatePrompts(promptsDict: { [x: string]: MockPromptFunction }): void { + if ("abc" in promptsDict) { + this._prompt_dict_2 = promptsDict["abc"]; + } + } +} + +class MockObject1 extends PromptMixin { + mockObject2: MockObject2; + + fooPrompt: MockPromptFunction; + barPrompt: MockPromptFunction; + + constructor() { + super(); + this.mockObject2 = new MockObject2(); + this.fooPrompt = mockPrompt; + this.barPrompt = mockPrompt; + } + + protected _getPrompts(): { [x: string]: any } { + return { + bar: this.barPrompt, + foo: this.fooPrompt, + }; + } + + protected _getPromptModules(): { [x: string]: any } { + return { mock_object_2: this.mockObject2 }; + } + + _updatePrompts(promptsDict: { [x: string]: any }): void { + if ("bar" in promptsDict) { + this.barPrompt = promptsDict["bar"]; + } + if ("foo" in promptsDict) { + this.fooPrompt = promptsDict["foo"]; + } + } +} + +describe("PromptMixin", () => { + it("should return prompts", () => { + const mockObj1 = new MockObject1(); + + const prompts = mockObj1.getPrompts(); + + expect( + mockObj1.fooPrompt({ + context: "{foo}", + query: "{foo}", + }), + ).toEqual("context: {foo} query: {foo}"); + + expect( + mockObj1.barPrompt({ + context: "{foo} {bar}", + query: "{foo} {bar}", + }), + ).toEqual("context: {foo} {bar} query: {foo} {bar}"); + + expect(mockObj1.fooPrompt).toEqual(prompts.foo); + expect(mockObj1.barPrompt).toEqual(prompts.bar); + + expect(mockObj1.getPrompts()).toEqual({ + bar: mockPrompt, + foo: mockPrompt, + "mock_object_2:abc": mockPrompt, + }); + }); + + it("should update prompts", () => { + const mockObj1 = new MockObject1(); + + expect( + mockObj1.barPrompt({ + context: "{foo} {bar}", + query: "{foo} {bar}", + }), + ).toEqual(mockPrompt({ context: "{foo} {bar}", query: "{foo} {bar}" })); + + expect( + mockObj1.mockObject2._prompt_dict_2({ + context: "{bar} testing", + query: "{bar} testing", + }), + ).toEqual(mockPrompt({ context: "{bar} testing", query: "{bar} testing" })); + + mockObj1.updatePrompts({ + bar: mockPrompt2, + "mock_object_2:abc": mockPrompt2, + }); + + expect( + mockObj1.barPrompt({ + context: "{foo} {bar}", + query: "{bar} {foo}", + }), + ).toEqual(mockPrompt2({ context: "{foo} {bar}", query: "{bar} {foo}" })); + expect( + mockObj1.mockObject2._prompt_dict_2({ + context: "{bar} testing", + query: "{bar} testing", + }), + ).toEqual( + mockPrompt2({ context: "{bar} testing", query: "{bar} testing" }), + ); + }); +});