From 259fe63ceb958d21134464fc8bab5eacd6b7fc5e Mon Sep 17 00:00:00 2001
From: Yi Ding <yi.s.ding@gmail.com>
Date: Tue, 29 Aug 2023 12:33:46 -0700
Subject: [PATCH] strong types for prompts

---
 .changeset/fair-pets-leave.md                 |  5 ++
 apps/simple/csv.ts                            |  5 +-
 apps/simple/openai.ts                         |  9 +---
 apps/simple/vectorIndexCustomize.ts           |  2 +-
 examples/csv.ts                               |  5 +-
 examples/openai.ts                            |  6 ++-
 examples/vectorIndexCustomize.ts              |  2 +-
 packages/core/src/ChatEngine.ts               | 31 ++++++-----
 packages/core/src/Prompt.ts                   | 53 +++++++++++--------
 packages/core/src/QuestionGenerator.ts        |  6 +--
 packages/core/src/ResponseSynthesizer.ts      | 38 ++++++++-----
 .../indices/summary/SummaryIndexRetriever.ts  |  6 +--
 12 files changed, 92 insertions(+), 76 deletions(-)
 create mode 100644 .changeset/fair-pets-leave.md

diff --git a/.changeset/fair-pets-leave.md b/.changeset/fair-pets-leave.md
new file mode 100644
index 000000000..9b74c1555
--- /dev/null
+++ b/.changeset/fair-pets-leave.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+Strong types for prompts.
diff --git a/apps/simple/csv.ts b/apps/simple/csv.ts
index d1c413ce9..1e0a11237 100644
--- a/apps/simple/csv.ts
+++ b/apps/simple/csv.ts
@@ -4,7 +4,6 @@ import {
   PapaCSVReader,
   ResponseSynthesizer,
   serviceContextFromDefaults,
-  SimplePrompt,
   VectorStoreIndex,
 } from "llamaindex";
 
@@ -23,9 +22,7 @@ async function main() {
     serviceContext,
   });
 
-  const csvPrompt: SimplePrompt = (input) => {
-    const { context = "", query = "" } = input;
-
+  const csvPrompt = ({ context = "", query = "" }) => {
     return `The following CSV file is loaded from ${path}
 \`\`\`csv
 ${context}
diff --git a/apps/simple/openai.ts b/apps/simple/openai.ts
index 1c40fb9ba..4c7856be0 100644
--- a/apps/simple/openai.ts
+++ b/apps/simple/openai.ts
@@ -1,14 +1,7 @@
 import { OpenAI } from "llamaindex";
 
 (async () => {
-  const llm = new OpenAI({
-    model: "gpt-3.5-turbo",
-    temperature: 0.1,
-    additionalChatOptions: { frequency_penalty: 0.1 },
-    additionalSessionOptions: {
-      defaultHeaders: { "X-Test-Header-Please-Ignore": "true" },
-    },
-  });
+  const llm = new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.0 });
 
   // complete api
   const response1 = await llm.complete("How are you?");
diff --git a/apps/simple/vectorIndexCustomize.ts b/apps/simple/vectorIndexCustomize.ts
index b9dbe8d8b..5ad55cff6 100644
--- a/apps/simple/vectorIndexCustomize.ts
+++ b/apps/simple/vectorIndexCustomize.ts
@@ -12,7 +12,7 @@ async function main() {
   const document = new Document({ text: essay, id_: "essay" });
 
   const serviceContext = serviceContextFromDefaults({
-    llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.0 }),
+    llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.1 }),
   });
 
   const index = await VectorStoreIndex.fromDocuments([document], {
diff --git a/examples/csv.ts b/examples/csv.ts
index d1c413ce9..1e0a11237 100644
--- a/examples/csv.ts
+++ b/examples/csv.ts
@@ -4,7 +4,6 @@ import {
   PapaCSVReader,
   ResponseSynthesizer,
   serviceContextFromDefaults,
-  SimplePrompt,
   VectorStoreIndex,
 } from "llamaindex";
 
@@ -23,9 +22,7 @@ async function main() {
     serviceContext,
   });
 
-  const csvPrompt: SimplePrompt = (input) => {
-    const { context = "", query = "" } = input;
-
+  const csvPrompt = ({ context = "", query = "" }) => {
     return `The following CSV file is loaded from ${path}
 \`\`\`csv
 ${context}
diff --git a/examples/openai.ts b/examples/openai.ts
index f53709c64..4c7856be0 100644
--- a/examples/openai.ts
+++ b/examples/openai.ts
@@ -2,12 +2,14 @@ import { OpenAI } from "llamaindex";
 
 (async () => {
   const llm = new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.0 });
-  
+
   // complete api
   const response1 = await llm.complete("How are you?");
   console.log(response1.message.content);
 
   // chat api
-  const response2 = await llm.chat([{ content: "Tell me a joke!", role: "user" }]);
+  const response2 = await llm.chat([
+    { content: "Tell me a joke!", role: "user" },
+  ]);
   console.log(response2.message.content);
 })();
diff --git a/examples/vectorIndexCustomize.ts b/examples/vectorIndexCustomize.ts
index b9dbe8d8b..5ad55cff6 100644
--- a/examples/vectorIndexCustomize.ts
+++ b/examples/vectorIndexCustomize.ts
@@ -12,7 +12,7 @@ async function main() {
   const document = new Document({ text: essay, id_: "essay" });
 
   const serviceContext = serviceContextFromDefaults({
-    llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.0 }),
+    llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.1 }),
   });
 
   const index = await VectorStoreIndex.fromDocuments([document], {
diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts
index 1cc847569..9ba53f5fe 100644
--- a/packages/core/src/ChatEngine.ts
+++ b/packages/core/src/ChatEngine.ts
@@ -1,17 +1,18 @@
-import { ChatMessage, OpenAI, ChatResponse, LLM } from "./llm/LLM";
+import { v4 as uuidv4 } from "uuid";
 import { TextNode } from "./Node";
 import {
-  SimplePrompt,
-  contextSystemPrompt,
+  CondenseQuestionPrompt,
+  ContextSystemPrompt,
   defaultCondenseQuestionPrompt,
+  defaultContextSystemPrompt,
   messagesToHistoryStr,
 } from "./Prompt";
 import { BaseQueryEngine } from "./QueryEngine";
 import { Response } from "./Response";
 import { BaseRetriever } from "./Retriever";
 import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
-import { v4 as uuidv4 } from "uuid";
 import { Event } from "./callbacks/CallbackManager";
+import { ChatMessage, LLM, OpenAI } from "./llm/LLM";
 
 /**
  * A ChatEngine is used to handle back and forth chats between the application and the LLM.
@@ -70,13 +71,13 @@ export class CondenseQuestionChatEngine implements ChatEngine {
   queryEngine: BaseQueryEngine;
   chatHistory: ChatMessage[];
   serviceContext: ServiceContext;
-  condenseMessagePrompt: SimplePrompt;
+  condenseMessagePrompt: CondenseQuestionPrompt;
 
   constructor(init: {
     queryEngine: BaseQueryEngine;
     chatHistory: ChatMessage[];
     serviceContext?: ServiceContext;
-    condenseMessagePrompt?: SimplePrompt;
+    condenseMessagePrompt?: CondenseQuestionPrompt;
   }) {
     this.queryEngine = init.queryEngine;
     this.chatHistory = init?.chatHistory ?? [];
@@ -92,14 +93,14 @@ export class CondenseQuestionChatEngine implements ChatEngine {
     return this.serviceContext.llm.complete(
       defaultCondenseQuestionPrompt({
         question: question,
-        chat_history: chatHistoryStr,
-      })
+        chatHistory: chatHistoryStr,
+      }),
     );
   }
 
   async chat(
     message: string,
-    chatHistory?: ChatMessage[] | undefined
+    chatHistory?: ChatMessage[] | undefined,
   ): Promise<Response> {
     chatHistory = chatHistory ?? this.chatHistory;
 
@@ -129,16 +130,20 @@ export class ContextChatEngine implements ChatEngine {
   retriever: BaseRetriever;
   chatModel: OpenAI;
   chatHistory: ChatMessage[];
+  contextSystemPrompt: ContextSystemPrompt;
 
   constructor(init: {
     retriever: BaseRetriever;
     chatModel?: OpenAI;
     chatHistory?: ChatMessage[];
+    contextSystemPrompt?: ContextSystemPrompt;
   }) {
     this.retriever = init.retriever;
     this.chatModel =
       init.chatModel ?? new OpenAI({ model: "gpt-3.5-turbo-16k" });
     this.chatHistory = init?.chatHistory ?? [];
+    this.contextSystemPrompt =
+      init?.contextSystemPrompt ?? defaultContextSystemPrompt;
   }
 
   async chat(message: string, chatHistory?: ChatMessage[] | undefined) {
@@ -151,11 +156,11 @@ export class ContextChatEngine implements ChatEngine {
     };
     const sourceNodesWithScore = await this.retriever.retrieve(
       message,
-      parentEvent
+      parentEvent,
     );
 
     const systemMessage: ChatMessage = {
-      content: contextSystemPrompt({
+      content: this.contextSystemPrompt({
         context: sourceNodesWithScore
           .map((r) => (r.node as TextNode).text)
           .join("\n\n"),
@@ -167,7 +172,7 @@ export class ContextChatEngine implements ChatEngine {
 
     const response = await this.chatModel.chat(
       [systemMessage, ...chatHistory],
-      parentEvent
+      parentEvent,
     );
     chatHistory.push(response.message);
 
@@ -175,7 +180,7 @@ export class ContextChatEngine implements ChatEngine {
 
     return new Response(
       response.message.content,
-      sourceNodesWithScore.map((r) => r.node)
+      sourceNodesWithScore.map((r) => r.node),
     );
   }
 
diff --git a/packages/core/src/Prompt.ts b/packages/core/src/Prompt.ts
index d6ccde84e..86e1a4fd4 100644
--- a/packages/core/src/Prompt.ts
+++ b/packages/core/src/Prompt.ts
@@ -22,9 +22,7 @@ DEFAULT_TEXT_QA_PROMPT_TMPL = (
 )
 */
 
-export const defaultTextQaPrompt: SimplePrompt = (input) => {
-  const { context = "", query = "" } = input;
-
+export const defaultTextQaPrompt = ({ context = "", query = "" }) => {
   return `Context information is below.
 ---------------------
 ${context}
@@ -34,6 +32,8 @@ Query: ${query}
 Answer:`;
 };
 
+export type TextQaPrompt = typeof defaultTextQaPrompt;
+
 /*
 DEFAULT_SUMMARY_PROMPT_TMPL = (
     "Write a summary of the following. Try to use only the "
@@ -48,9 +48,7 @@ DEFAULT_SUMMARY_PROMPT_TMPL = (
 )
 */
 
-export const defaultSummaryPrompt: SimplePrompt = (input) => {
-  const { context = "" } = input;
-
+export const defaultSummaryPrompt = ({ context = "" }) => {
   return `Write a summary of the following. Try to use only the information provided. Try to include as many key details as possible.
 
 
@@ -61,6 +59,8 @@ SUMMARY:"""
 `;
 };
 
+export type SummaryPrompt = typeof defaultSummaryPrompt;
+
 /*
 DEFAULT_REFINE_PROMPT_TMPL = (
     "The original query is as follows: {query_str}\n"
@@ -77,9 +77,11 @@ DEFAULT_REFINE_PROMPT_TMPL = (
 )
 */
 
-export const defaultRefinePrompt: SimplePrompt = (input) => {
-  const { query = "", existingAnswer = "", context = "" } = input;
-
+export const defaultRefinePrompt = ({
+  query = "",
+  existingAnswer = "",
+  context = "",
+}) => {
   return `The original query is as follows: ${query}
 We have provided an existing answer: ${existingAnswer}
 We have the opportunity to refine the existing answer (only if needed) with some more context below.
@@ -90,6 +92,8 @@ Given the new context, refine the original answer to better answer the query. If
 Refined Answer:`;
 };
 
+export type RefinePrompt = typeof defaultRefinePrompt;
+
 /*
 DEFAULT_TREE_SUMMARIZE_TMPL = (
   "Context information from multiple sources is below.\n"
@@ -103,9 +107,7 @@ DEFAULT_TREE_SUMMARIZE_TMPL = (
 )
 */
 
-export const defaultTreeSummarizePrompt: SimplePrompt = (input) => {
-  const { context = "", query = "" } = input;
-
+export const defaultTreeSummarizePrompt = ({ context = "", query = "" }) => {
   return `Context information from multiple sources is below.
 ---------------------
 ${context}
@@ -115,9 +117,9 @@ Query: ${query}
 Answer:`;
 };
 
-export const defaultChoiceSelectPrompt: SimplePrompt = (input) => {
-  const { context = "", query = "" } = input;
+export type TreeSummarizePrompt = typeof defaultTreeSummarizePrompt;
 
+export const defaultChoiceSelectPrompt = ({ context = "", query = "" }) => {
   return `A list of documents is shown below. Each document has a number next to it along 
 with a summary of the document. A question is also provided.
 Respond with the numbers of the documents
@@ -149,6 +151,8 @@ Question: ${query}
 Answer:`;
 };
 
+export type ChoiceSelectPrompt = typeof defaultChoiceSelectPrompt;
+
 /*
 PREFIX = """\
 Given a user question, and a list of tools, output a list of relevant sub-questions \
@@ -266,9 +270,7 @@ const exampleOutput: SubQuestion[] = [
   },
 ];
 
-export const defaultSubQuestionPrompt: SimplePrompt = (input) => {
-  const { toolsStr, queryStr } = input;
-
+export const defaultSubQuestionPrompt = ({ toolsStr = "", queryStr = "" }) => {
   return `Given a user question, and a list of tools, output a list of relevant sub-questions that when composed can help answer the full user question:
 
 # Example 1
@@ -298,6 +300,8 @@ ${queryStr}
 `;
 };
 
+export type SubQuestionPrompt = typeof defaultSubQuestionPrompt;
+
 // DEFAULT_TEMPLATE = """\
 // Given a conversation (between Human and Assistant) and a follow up message from Human, \
 // rewrite the message to be a standalone question that captures all relevant context \
@@ -312,9 +316,10 @@ ${queryStr}
 // <Standalone question>
 // """
 
-export const defaultCondenseQuestionPrompt: SimplePrompt = (input) => {
-  const { chatHistory, question } = input;
-
+export const defaultCondenseQuestionPrompt = ({
+  chatHistory = "",
+  question = "",
+}) => {
   return `Given a conversation (between Human and Assistant) and a follow up message from Human, rewrite the message to be a standalone question that captures all relevant context from the conversation.
 
 <Chat History>
@@ -327,6 +332,8 @@ ${question}
 `;
 };
 
+export type CondenseQuestionPrompt = typeof defaultCondenseQuestionPrompt;
+
 export function messagesToHistoryStr(messages: ChatMessage[]) {
   return messages.reduce((acc, message) => {
     acc += acc ? "\n" : "";
@@ -339,11 +346,11 @@ export function messagesToHistoryStr(messages: ChatMessage[]) {
   }, "");
 }
 
-export const contextSystemPrompt: SimplePrompt = (input) => {
-  const { context } = input;
-
+export const defaultContextSystemPrompt = ({ context = "" }) => {
   return `Context information is below.
 ---------------------
 ${context}
 ---------------------`;
 };
+
+export type ContextSystemPrompt = typeof defaultContextSystemPrompt;
diff --git a/packages/core/src/QuestionGenerator.ts b/packages/core/src/QuestionGenerator.ts
index 24d46e954..ada522b57 100644
--- a/packages/core/src/QuestionGenerator.ts
+++ b/packages/core/src/QuestionGenerator.ts
@@ -4,7 +4,7 @@ import {
   SubQuestionOutputParser,
 } from "./OutputParser";
 import {
-  SimplePrompt,
+  SubQuestionPrompt,
   buildToolsText,
   defaultSubQuestionPrompt,
 } from "./Prompt";
@@ -28,7 +28,7 @@ export interface BaseQuestionGenerator {
  */
 export class LLMQuestionGenerator implements BaseQuestionGenerator {
   llm: LLM;
-  prompt: SimplePrompt;
+  prompt: SubQuestionPrompt;
   outputParser: BaseOutputParser<StructuredOutput<SubQuestion[]>>;
 
   constructor(init?: Partial<LLMQuestionGenerator>) {
@@ -45,7 +45,7 @@ export class LLMQuestionGenerator implements BaseQuestionGenerator {
         this.prompt({
           toolsStr,
           queryStr,
-        })
+        }),
       )
     ).message.content;
 
diff --git a/packages/core/src/ResponseSynthesizer.ts b/packages/core/src/ResponseSynthesizer.ts
index d781d5a19..912c02516 100644
--- a/packages/core/src/ResponseSynthesizer.ts
+++ b/packages/core/src/ResponseSynthesizer.ts
@@ -1,6 +1,9 @@
 import { MetadataMode, NodeWithScore } from "./Node";
 import {
+  RefinePrompt,
   SimplePrompt,
+  TextQaPrompt,
+  TreeSummarizePrompt,
   defaultRefinePrompt,
   defaultTextQaPrompt,
   defaultTreeSummarizePrompt,
@@ -73,13 +76,13 @@ export class SimpleResponseBuilder implements BaseResponseBuilder {
  */
 export class Refine implements BaseResponseBuilder {
   serviceContext: ServiceContext;
-  textQATemplate: SimplePrompt;
-  refineTemplate: SimplePrompt;
+  textQATemplate: TextQaPrompt;
+  refineTemplate: RefinePrompt;
 
   constructor(
     serviceContext: ServiceContext,
-    textQATemplate?: SimplePrompt,
-    refineTemplate?: SimplePrompt,
+    textQATemplate?: TextQaPrompt,
+    refineTemplate?: RefinePrompt,
   ) {
     this.serviceContext = serviceContext;
     this.textQATemplate = textQATemplate ?? defaultTextQaPrompt;
@@ -209,9 +212,14 @@ export class CompactAndRefine extends Refine {
  */
 export class TreeSummarize implements BaseResponseBuilder {
   serviceContext: ServiceContext;
+  summaryTemplate: TreeSummarizePrompt;
 
-  constructor(serviceContext: ServiceContext) {
+  constructor(
+    serviceContext: ServiceContext,
+    summaryTemplate?: TreeSummarizePrompt,
+  ) {
     this.serviceContext = serviceContext;
+    this.summaryTemplate = summaryTemplate ?? defaultTreeSummarizePrompt;
   }
 
   async getResponse(
@@ -219,21 +227,19 @@ export class TreeSummarize implements BaseResponseBuilder {
     textChunks: string[],
     parentEvent?: Event,
   ): Promise<string> {
-    const summaryTemplate: SimplePrompt = defaultTreeSummarizePrompt;
-
     if (!textChunks || textChunks.length === 0) {
       throw new Error("Must have at least one text chunk");
     }
 
     const packedTextChunks = this.serviceContext.promptHelper.repack(
-      summaryTemplate,
+      this.summaryTemplate,
       textChunks,
     );
 
     if (packedTextChunks.length === 1) {
       return (
         await this.serviceContext.llm.complete(
-          summaryTemplate({
+          this.summaryTemplate({
             context: packedTextChunks[0],
           }),
           parentEvent,
@@ -243,7 +249,7 @@ export class TreeSummarize implements BaseResponseBuilder {
       const summaries = await Promise.all(
         packedTextChunks.map((chunk) =>
           this.serviceContext.llm.complete(
-            summaryTemplate({
+            this.summaryTemplate({
               context: chunk,
             }),
             parentEvent,
@@ -298,9 +304,13 @@ export class ResponseSynthesizer {
     this.metadataMode = metadataMode;
   }
 
-  async synthesize(query: string, nodes: NodeWithScore[], parentEvent?: Event) {
-    let textChunks: string[] = nodes.map((node) =>
-      node.node.getContent(this.metadataMode)
+  async synthesize(
+    query: string,
+    nodesWithScore: NodeWithScore[],
+    parentEvent?: Event,
+  ) {
+    let textChunks: string[] = nodesWithScore.map(({ node }) =>
+      node.getContent(this.metadataMode),
     );
     const response = await this.responseBuilder.getResponse(
       query,
@@ -309,7 +319,7 @@ export class ResponseSynthesizer {
     );
     return new Response(
       response,
-      nodes.map((node) => node.node),
+      nodesWithScore.map(({ node }) => node),
     );
   }
 }
diff --git a/packages/core/src/indices/summary/SummaryIndexRetriever.ts b/packages/core/src/indices/summary/SummaryIndexRetriever.ts
index 61d9f2180..c7259ed8f 100644
--- a/packages/core/src/indices/summary/SummaryIndexRetriever.ts
+++ b/packages/core/src/indices/summary/SummaryIndexRetriever.ts
@@ -1,7 +1,7 @@
 import _ from "lodash";
 import { globalsHelper } from "../../GlobalsHelper";
 import { NodeWithScore } from "../../Node";
-import { SimplePrompt, defaultChoiceSelectPrompt } from "../../Prompt";
+import { ChoiceSelectPrompt, defaultChoiceSelectPrompt } from "../../Prompt";
 import { BaseRetriever } from "../../Retriever";
 import { ServiceContext } from "../../ServiceContext";
 import { Event } from "../../callbacks/CallbackManager";
@@ -55,7 +55,7 @@ export class SummaryIndexRetriever implements BaseRetriever {
  */
 export class SummaryIndexLLMRetriever implements BaseRetriever {
   index: SummaryIndex;
-  choiceSelectPrompt: SimplePrompt;
+  choiceSelectPrompt: ChoiceSelectPrompt;
   choiceBatchSize: number;
   formatNodeBatchFn: NodeFormatterFunction;
   parseChoiceSelectAnswerFn: ChoiceSelectParserFunction;
@@ -63,7 +63,7 @@ export class SummaryIndexLLMRetriever implements BaseRetriever {
 
   constructor(
     index: SummaryIndex,
-    choiceSelectPrompt?: SimplePrompt,
+    choiceSelectPrompt?: ChoiceSelectPrompt,
     choiceBatchSize: number = 10,
     formatNodeBatchFn?: NodeFormatterFunction,
     parseChoiceSelectAnswerFn?: ChoiceSelectParserFunction,
-- 
GitLab