diff --git a/.changeset/fair-pets-leave.md b/.changeset/fair-pets-leave.md
new file mode 100644
index 0000000000000000000000000000000000000000..9b74c155568ba80dccb1c00486b5d055e5e6c1c7
--- /dev/null
+++ b/.changeset/fair-pets-leave.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+Strong types for prompts.
diff --git a/apps/simple/csv.ts b/apps/simple/csv.ts
index d1c413ce97ea06dcc8e824d472e69a8e2c5775ef..1e0a11237297b086fbc0c2d20b38ef6bea68c818 100644
--- a/apps/simple/csv.ts
+++ b/apps/simple/csv.ts
@@ -4,7 +4,6 @@ import {
   PapaCSVReader,
   ResponseSynthesizer,
   serviceContextFromDefaults,
-  SimplePrompt,
   VectorStoreIndex,
 } from "llamaindex";
 
@@ -23,9 +22,7 @@ async function main() {
     serviceContext,
   });
 
-  const csvPrompt: SimplePrompt = (input) => {
-    const { context = "", query = "" } = input;
-
+  const csvPrompt = ({ context = "", query = "" }) => {
     return `The following CSV file is loaded from ${path}
 \`\`\`csv
 ${context}
diff --git a/apps/simple/openai.ts b/apps/simple/openai.ts
index 1c40fb9ba5a556af8d135449369f0044ef51b5ea..4c7856be0ab9e5912cac3ca119416c6157c750ad 100644
--- a/apps/simple/openai.ts
+++ b/apps/simple/openai.ts
@@ -1,14 +1,7 @@
 import { OpenAI } from "llamaindex";
 
 (async () => {
-  const llm = new OpenAI({
-    model: "gpt-3.5-turbo",
-    temperature: 0.1,
-    additionalChatOptions: { frequency_penalty: 0.1 },
-    additionalSessionOptions: {
-      defaultHeaders: { "X-Test-Header-Please-Ignore": "true" },
-    },
-  });
+  const llm = new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.0 });
 
   // complete api
   const response1 = await llm.complete("How are you?");
diff --git a/apps/simple/vectorIndexCustomize.ts b/apps/simple/vectorIndexCustomize.ts
index b9dbe8d8bbeb1f3619a139fc9f295d79bb934200..5ad55cff6c50defaca130c1e1cefe43857564741 100644
--- a/apps/simple/vectorIndexCustomize.ts
+++ b/apps/simple/vectorIndexCustomize.ts
@@ -12,7 +12,7 @@ async function main() {
   const document = new Document({ text: essay, id_: "essay" });
 
   const serviceContext = serviceContextFromDefaults({
-    llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.0 }),
+    llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.1 }),
   });
 
   const index = await VectorStoreIndex.fromDocuments([document], {
diff --git a/examples/csv.ts b/examples/csv.ts
index d1c413ce97ea06dcc8e824d472e69a8e2c5775ef..1e0a11237297b086fbc0c2d20b38ef6bea68c818 100644
--- a/examples/csv.ts
+++ b/examples/csv.ts
@@ -4,7 +4,6 @@ import {
   PapaCSVReader,
   ResponseSynthesizer,
   serviceContextFromDefaults,
-  SimplePrompt,
   VectorStoreIndex,
 } from "llamaindex";
 
@@ -23,9 +22,7 @@ async function main() {
     serviceContext,
   });
 
-  const csvPrompt: SimplePrompt = (input) => {
-    const { context = "", query = "" } = input;
-
+  const csvPrompt = ({ context = "", query = "" }) => {
     return `The following CSV file is loaded from ${path}
 \`\`\`csv
 ${context}
diff --git a/examples/openai.ts b/examples/openai.ts
index f53709c6495d2a098a439e97b1a1f54321ec6ca3..4c7856be0ab9e5912cac3ca119416c6157c750ad 100644
--- a/examples/openai.ts
+++ b/examples/openai.ts
@@ -2,12 +2,14 @@ import { OpenAI } from "llamaindex";
 
 (async () => {
   const llm = new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.0 });
-
+
   // complete api
   const response1 = await llm.complete("How are you?");
   console.log(response1.message.content);
 
   // chat api
-  const response2 = await llm.chat([{ content: "Tell me a joke!", role: "user" }]);
+  const response2 = await llm.chat([
+    { content: "Tell me a joke!", role: "user" },
+  ]);
   console.log(response2.message.content);
 })();
diff --git a/examples/vectorIndexCustomize.ts b/examples/vectorIndexCustomize.ts
index b9dbe8d8bbeb1f3619a139fc9f295d79bb934200..5ad55cff6c50defaca130c1e1cefe43857564741 100644
--- a/examples/vectorIndexCustomize.ts
+++ b/examples/vectorIndexCustomize.ts
@@ -12,7 +12,7 @@ async function main() {
   const document = new Document({ text: essay, id_: "essay" });
 
   const serviceContext = serviceContextFromDefaults({
-    llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.0 }),
+    llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.1 }),
   });
 
   const index = await VectorStoreIndex.fromDocuments([document], {
diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts
index 1cc847569e1fc2b630bf189981acf87416b3ab0c..9ba53f5fe5545c04d8560b2d14e892aee969ef2b 100644
--- a/packages/core/src/ChatEngine.ts
+++ b/packages/core/src/ChatEngine.ts
@@ -1,17 +1,18 @@
-import { ChatMessage, OpenAI, ChatResponse, LLM } from "./llm/LLM";
+import { v4 as uuidv4 } from "uuid";
 import { TextNode } from "./Node";
 import {
-  SimplePrompt,
-  contextSystemPrompt,
+  CondenseQuestionPrompt,
+  ContextSystemPrompt,
   defaultCondenseQuestionPrompt,
+  defaultContextSystemPrompt,
   messagesToHistoryStr,
 } from "./Prompt";
 import { BaseQueryEngine } from "./QueryEngine";
 import { Response } from "./Response";
 import { BaseRetriever } from "./Retriever";
 import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
-import { v4 as uuidv4 } from "uuid";
 import { Event } from "./callbacks/CallbackManager";
+import { ChatMessage, LLM, OpenAI } from "./llm/LLM";
 
 /**
  * A ChatEngine is used to handle back and forth chats between the application and the LLM.
@@ -70,13 +71,13 @@ export class CondenseQuestionChatEngine implements ChatEngine {
   queryEngine: BaseQueryEngine;
   chatHistory: ChatMessage[];
   serviceContext: ServiceContext;
-  condenseMessagePrompt: SimplePrompt;
+  condenseMessagePrompt: CondenseQuestionPrompt;
 
   constructor(init: {
     queryEngine: BaseQueryEngine;
     chatHistory: ChatMessage[];
    serviceContext?: ServiceContext;
-    condenseMessagePrompt?: SimplePrompt;
+    condenseMessagePrompt?: CondenseQuestionPrompt;
   }) {
     this.queryEngine = init.queryEngine;
     this.chatHistory = init?.chatHistory ?? [];
@@ -92,14 +93,14 @@ export class CondenseQuestionChatEngine implements ChatEngine {
     return this.serviceContext.llm.complete(
       defaultCondenseQuestionPrompt({
         question: question,
-        chat_history: chatHistoryStr,
-      })
+        chatHistory: chatHistoryStr,
+      }),
     );
   }
 
   async chat(
     message: string,
-    chatHistory?: ChatMessage[] | undefined
+    chatHistory?: ChatMessage[] | undefined,
   ): Promise<Response> {
     chatHistory = chatHistory ?? this.chatHistory;
 
@@ -129,16 +130,20 @@ export class ContextChatEngine implements ChatEngine {
   retriever: BaseRetriever;
   chatModel: OpenAI;
   chatHistory: ChatMessage[];
+  contextSystemPrompt: ContextSystemPrompt;
 
   constructor(init: {
     retriever: BaseRetriever;
     chatModel?: OpenAI;
     chatHistory?: ChatMessage[];
+    contextSystemPrompt?: ContextSystemPrompt;
   }) {
     this.retriever = init.retriever;
     this.chatModel =
       init.chatModel ?? new OpenAI({ model: "gpt-3.5-turbo-16k" });
     this.chatHistory = init?.chatHistory ?? [];
+    this.contextSystemPrompt =
+      init?.contextSystemPrompt ?? defaultContextSystemPrompt;
   }
 
   async chat(message: string, chatHistory?: ChatMessage[] | undefined) {
@@ -151,11 +156,11 @@ export class ContextChatEngine implements ChatEngine {
     };
     const sourceNodesWithScore = await this.retriever.retrieve(
       message,
-      parentEvent
+      parentEvent,
     );
 
     const systemMessage: ChatMessage = {
-      content: contextSystemPrompt({
+      content: this.contextSystemPrompt({
         context: sourceNodesWithScore
           .map((r) => (r.node as TextNode).text)
           .join("\n\n"),
@@ -167,7 +172,7 @@ export class ContextChatEngine implements ChatEngine {
 
     const response = await this.chatModel.chat(
       [systemMessage, ...chatHistory],
-      parentEvent
+      parentEvent,
     );
     chatHistory.push(response.message);
 
@@ -175,7 +180,7 @@ export class ContextChatEngine implements ChatEngine {
 
     return new Response(
       response.message.content,
-      sourceNodesWithScore.map((r) => r.node)
+      sourceNodesWithScore.map((r) => r.node),
     );
   }
 
diff --git a/packages/core/src/Prompt.ts b/packages/core/src/Prompt.ts
index d6ccde84ec4703629773c20b58cf391c193b1f0b..86e1a4fd4c8a9d0c96c8c4547567e48f135698f8 100644
--- a/packages/core/src/Prompt.ts
+++ b/packages/core/src/Prompt.ts
@@ -22,9 +22,7 @@ DEFAULT_TEXT_QA_PROMPT_TMPL = (
 )
 */
 
-export const defaultTextQaPrompt: SimplePrompt = (input) => {
-  const { context = "", query = "" } = input;
-
+export const defaultTextQaPrompt = ({ context = "", query = "" }) => {
   return `Context information is below.
 ---------------------
 ${context}
@@ -34,6 +32,8 @@ Query: ${query}
 Answer:`;
 };
 
+export type TextQaPrompt = typeof defaultTextQaPrompt;
+
 /*
 DEFAULT_SUMMARY_PROMPT_TMPL = (
     "Write a summary of the following. Try to use only the "
@@ -48,9 +48,7 @@ DEFAULT_SUMMARY_PROMPT_TMPL = (
 )
 */
 
-export const defaultSummaryPrompt: SimplePrompt = (input) => {
-  const { context = "" } = input;
-
+export const defaultSummaryPrompt = ({ context = "" }) => {
   return `Write a summary of the following. Try to use only the information provided. Try to include as many key details as possible.
 
 
@@ -61,6 +59,8 @@ SUMMARY:"""
 `;
 };
 
+export type SummaryPrompt = typeof defaultSummaryPrompt;
+
 /*
 DEFAULT_REFINE_PROMPT_TMPL = (
     "The original query is as follows: {query_str}\n"
@@ -77,9 +77,11 @@ DEFAULT_REFINE_PROMPT_TMPL = (
 )
 */
 
-export const defaultRefinePrompt: SimplePrompt = (input) => {
-  const { query = "", existingAnswer = "", context = "" } = input;
-
+export const defaultRefinePrompt = ({
+  query = "",
+  existingAnswer = "",
+  context = "",
+}) => {
   return `The original query is as follows: ${query}
 We have provided an existing answer: ${existingAnswer}
 We have the opportunity to refine the existing answer (only if needed) with some more context below.
@@ -90,6 +92,8 @@ Given the new context, refine the original answer to better answer the query. If
 Refined Answer:`;
 };
 
+export type RefinePrompt = typeof defaultRefinePrompt;
+
 /*
 DEFAULT_TREE_SUMMARIZE_TMPL = (
     "Context information from multiple sources is below.\n"
@@ -103,9 +107,7 @@ DEFAULT_TREE_SUMMARIZE_TMPL = (
 )
 */
 
-export const defaultTreeSummarizePrompt: SimplePrompt = (input) => {
-  const { context = "", query = "" } = input;
-
+export const defaultTreeSummarizePrompt = ({ context = "", query = "" }) => {
   return `Context information from multiple sources is below.
 ---------------------
 ${context}
@@ -115,9 +117,9 @@ Query: ${query}
 Answer:`;
 };
 
-export const defaultChoiceSelectPrompt: SimplePrompt = (input) => {
-  const { context = "", query = "" } = input;
+export type TreeSummarizePrompt = typeof defaultTreeSummarizePrompt;
 
+export const defaultChoiceSelectPrompt = ({ context = "", query = "" }) => {
   return `A list of documents is shown below. Each document has a number next to it
 along with a summary of the document. A question is also provided.
 Respond with the numbers of the documents
@@ -149,6 +151,8 @@ Question: ${query}
 Answer:`;
 };
 
+export type ChoiceSelectPrompt = typeof defaultChoiceSelectPrompt;
+
 /*
 PREFIX = """\
 Given a user question, and a list of tools, output a list of relevant sub-questions \
@@ -266,9 +270,7 @@ const exampleOutput: SubQuestion[] = [
   },
 ];
 
-export const defaultSubQuestionPrompt: SimplePrompt = (input) => {
-  const { toolsStr, queryStr } = input;
-
+export const defaultSubQuestionPrompt = ({ toolsStr = "", queryStr = "" }) => {
   return `Given a user question, and a list of tools, output a list of relevant sub-questions that when composed can help answer the full user question:
 
 # Example 1
@@ -298,6 +300,8 @@ ${queryStr}
 `;
 };
 
+export type SubQuestionPrompt = typeof defaultSubQuestionPrompt;
+
 // DEFAULT_TEMPLATE = """\
 // Given a conversation (between Human and Assistant) and a follow up message from Human, \
 // rewrite the message to be a standalone question that captures all relevant context \
@@ -312,9 +316,10 @@ ${queryStr}
 // <Standalone question>
 // """
 
-export const defaultCondenseQuestionPrompt: SimplePrompt = (input) => {
-  const { chatHistory, question } = input;
-
+export const defaultCondenseQuestionPrompt = ({
+  chatHistory = "",
+  question = "",
+}) => {
   return `Given a conversation (between Human and Assistant) and a follow up message from Human, rewrite the message to be a standalone question that captures all relevant context from the conversation.
 
 <Chat History>
@@ -327,6 +332,8 @@ ${question}
 `;
 };
 
+export type CondenseQuestionPrompt = typeof defaultCondenseQuestionPrompt;
+
 export function messagesToHistoryStr(messages: ChatMessage[]) {
   return messages.reduce((acc, message) => {
     acc += acc ? "\n" : "";
@@ -339,11 +346,11 @@ export function messagesToHistoryStr(messages: ChatMessage[]) {
   }, "");
 }
 
-export const contextSystemPrompt: SimplePrompt = (input) => {
-  const { context } = input;
-
+export const defaultContextSystemPrompt = ({ context = "" }) => {
   return `Context information is below.
 ---------------------
 ${context}
 ---------------------`;
 };
+
+export type ContextSystemPrompt = typeof defaultContextSystemPrompt;
diff --git a/packages/core/src/QuestionGenerator.ts b/packages/core/src/QuestionGenerator.ts
index 24d46e954938be822d9ca8bac2dc984d6508d2bc..ada522b5715441c9f03567fc66b73b4d51901301 100644
--- a/packages/core/src/QuestionGenerator.ts
+++ b/packages/core/src/QuestionGenerator.ts
@@ -4,7 +4,7 @@ import {
   SubQuestionOutputParser,
 } from "./OutputParser";
 import {
-  SimplePrompt,
+  SubQuestionPrompt,
   buildToolsText,
   defaultSubQuestionPrompt,
 } from "./Prompt";
@@ -28,7 +28,7 @@ export interface BaseQuestionGenerator {
 */
 export class LLMQuestionGenerator implements BaseQuestionGenerator {
   llm: LLM;
-  prompt: SimplePrompt;
+  prompt: SubQuestionPrompt;
   outputParser: BaseOutputParser<StructuredOutput<SubQuestion[]>>;
 
   constructor(init?: Partial<LLMQuestionGenerator>) {
@@ -45,7 +45,7 @@ export class LLMQuestionGenerator implements BaseQuestionGenerator {
         this.prompt({
           toolsStr,
           queryStr,
-        })
+        }),
       )
     ).message.content;
 
diff --git a/packages/core/src/ResponseSynthesizer.ts b/packages/core/src/ResponseSynthesizer.ts
index d781d5a198b1252f5b80bd033c8a7dbce0d7d3d2..912c02516eb2f7fc1f6cb2ac6b3424352af5a41f 100644
--- a/packages/core/src/ResponseSynthesizer.ts
+++ b/packages/core/src/ResponseSynthesizer.ts
@@ -1,6 +1,9 @@
 import { MetadataMode, NodeWithScore } from "./Node";
 import {
+  RefinePrompt,
   SimplePrompt,
+  TextQaPrompt,
+  TreeSummarizePrompt,
   defaultRefinePrompt,
   defaultTextQaPrompt,
   defaultTreeSummarizePrompt,
@@ -73,13 +76,13 @@ export class SimpleResponseBuilder implements BaseResponseBuilder {
 */
 export class Refine implements BaseResponseBuilder {
   serviceContext: ServiceContext;
-  textQATemplate: SimplePrompt;
-  refineTemplate: SimplePrompt;
+  textQATemplate: TextQaPrompt;
+  refineTemplate: RefinePrompt;
 
   constructor(
     serviceContext: ServiceContext,
-    textQATemplate?: SimplePrompt,
-    refineTemplate?: SimplePrompt,
+    textQATemplate?: TextQaPrompt,
+    refineTemplate?: RefinePrompt,
   ) {
     this.serviceContext = serviceContext;
     this.textQATemplate = textQATemplate ?? defaultTextQaPrompt;
@@ -209,9 +212,14 @@ export class CompactAndRefine extends Refine {
 */
 export class TreeSummarize implements BaseResponseBuilder {
   serviceContext: ServiceContext;
+  summaryTemplate: TreeSummarizePrompt;
 
-  constructor(serviceContext: ServiceContext) {
+  constructor(
+    serviceContext: ServiceContext,
+    summaryTemplate?: TreeSummarizePrompt,
+  ) {
     this.serviceContext = serviceContext;
+    this.summaryTemplate = summaryTemplate ?? defaultTreeSummarizePrompt;
   }
 
   async getResponse(
@@ -219,21 +227,19 @@ export class TreeSummarize implements BaseResponseBuilder {
     textChunks: string[],
     parentEvent?: Event,
   ): Promise<string> {
-    const summaryTemplate: SimplePrompt = defaultTreeSummarizePrompt;
-
     if (!textChunks || textChunks.length === 0) {
       throw new Error("Must have at least one text chunk");
     }
 
     const packedTextChunks = this.serviceContext.promptHelper.repack(
-      summaryTemplate,
+      this.summaryTemplate,
       textChunks,
     );
 
     if (packedTextChunks.length === 1) {
       return (
         await this.serviceContext.llm.complete(
-          summaryTemplate({
+          this.summaryTemplate({
             context: packedTextChunks[0],
           }),
           parentEvent,
@@ -243,7 +249,7 @@ export class TreeSummarize implements BaseResponseBuilder {
     const summaries = await Promise.all(
       packedTextChunks.map((chunk) =>
         this.serviceContext.llm.complete(
-          summaryTemplate({
+          this.summaryTemplate({
             context: chunk,
           }),
           parentEvent,
@@ -298,9 +304,13 @@ export class ResponseSynthesizer {
     this.metadataMode = metadataMode;
   }
 
-  async synthesize(query: string, nodes: NodeWithScore[], parentEvent?: Event) {
-    let textChunks: string[] = nodes.map((node) =>
-      node.node.getContent(this.metadataMode)
+  async synthesize(
+    query: string,
+    nodesWithScore: NodeWithScore[],
+    parentEvent?: Event,
+  ) {
+    let textChunks: string[] = nodesWithScore.map(({ node }) =>
+      node.getContent(this.metadataMode),
     );
     const response = await this.responseBuilder.getResponse(
       query,
@@ -309,7 +319,7 @@ export class ResponseSynthesizer {
     );
     return new Response(
       response,
-      nodes.map((node) => node.node),
+      nodesWithScore.map(({ node }) => node),
     );
   }
 }
diff --git a/packages/core/src/indices/summary/SummaryIndexRetriever.ts b/packages/core/src/indices/summary/SummaryIndexRetriever.ts
index 61d9f2180e26643ebe34ef6f23088d31a1365d98..c7259ed8fd6af6fa97cb8d924ca8f3437109931a 100644
--- a/packages/core/src/indices/summary/SummaryIndexRetriever.ts
+++ b/packages/core/src/indices/summary/SummaryIndexRetriever.ts
@@ -1,7 +1,7 @@
 import _ from "lodash";
 import { globalsHelper } from "../../GlobalsHelper";
 import { NodeWithScore } from "../../Node";
-import { SimplePrompt, defaultChoiceSelectPrompt } from "../../Prompt";
+import { ChoiceSelectPrompt, defaultChoiceSelectPrompt } from "../../Prompt";
 import { BaseRetriever } from "../../Retriever";
 import { ServiceContext } from "../../ServiceContext";
 import { Event } from "../../callbacks/CallbackManager";
@@ -55,7 +55,7 @@ export class SummaryIndexRetriever implements BaseRetriever {
 */
 export class SummaryIndexLLMRetriever implements BaseRetriever {
   index: SummaryIndex;
-  choiceSelectPrompt: SimplePrompt;
+  choiceSelectPrompt: ChoiceSelectPrompt;
   choiceBatchSize: number;
   formatNodeBatchFn: NodeFormatterFunction;
   parseChoiceSelectAnswerFn: ChoiceSelectParserFunction;
@@ -63,7 +63,7 @@ export class SummaryIndexLLMRetriever implements BaseRetriever {
 
   constructor(
     index: SummaryIndex,
-    choiceSelectPrompt?: SimplePrompt,
+    choiceSelectPrompt?: ChoiceSelectPrompt,
     choiceBatchSize: number = 10,
     formatNodeBatchFn?: NodeFormatterFunction,
     parseChoiceSelectAnswerFn?: ChoiceSelectParserFunction,
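Taken together, the patch replaces the loosely typed `SimplePrompt` with one exported type per prompt, each derived from its default implementation. The following is a rough sketch of what that buys a caller; it is not part of the patch, and it assumes `TreeSummarize` and the `TreeSummarizePrompt` type are re-exported from the package root the way `serviceContextFromDefaults` already is:

```ts
import {
  serviceContextFromDefaults,
  TreeSummarize,
  TreeSummarizePrompt, // assumption: re-exported from the package root
} from "llamaindex";

// A custom prompt must now match the default's signature: it receives the
// destructured { context, query } fields and returns the prompt string.
const briefSummaryPrompt: TreeSummarizePrompt = ({
  context = "",
  query = "",
}) => {
  return `Summarize the sources below in at most two sentences, then answer the query.
---------------------
${context}
---------------------
Query: ${query}
Answer:`;
};

const serviceContext = serviceContextFromDefaults({});

// TreeSummarize now takes the template in its constructor instead of
// hard-coding defaultTreeSummarizePrompt inside getResponse.
const responseBuilder = new TreeSummarize(serviceContext, briefSummaryPrompt);
```

A prompt written against the wrong set of fields should now surface as a type error at the call site rather than silently interpolating `undefined` into the template.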