From 3489e7de849cf5dc7e9ce13f604fcbf9a0a58229 Mon Sep 17 00:00:00 2001
From: Alex Yang <himself65@outlook.com>
Date: Sun, 6 Oct 2024 18:19:05 -0700
Subject: [PATCH] fix: num output incorrect in prompt helper (#1303)

---
 .changeset/gold-jokes-deny.md                 |   5 +
 packages/core/src/indices/prompt-helper.ts    | 131 ++++++-----
 packages/core/src/node-parser/index.ts        |   2 +
 .../src/node-parser/token-text-splitter.ts    | 206 ++++++++++++++++++
 packages/core/src/node-parser/utils.ts        |   5 +-
 packages/core/src/prompts/prompt.ts           |   6 +-
 6 files changed, 297 insertions(+), 58 deletions(-)
 create mode 100644 .changeset/gold-jokes-deny.md
 create mode 100644 packages/core/src/node-parser/token-text-splitter.ts

diff --git a/.changeset/gold-jokes-deny.md b/.changeset/gold-jokes-deny.md
new file mode 100644
index 000000000..792ffc5d2
--- /dev/null
+++ b/.changeset/gold-jokes-deny.md
@@ -0,0 +1,5 @@
+---
+"@llamaindex/core": patch
+---
+
+fix: num output incorrect in prompt helper
diff --git a/packages/core/src/indices/prompt-helper.ts b/packages/core/src/indices/prompt-helper.ts
index 3d1b5adcc..960ffe31f 100644
--- a/packages/core/src/indices/prompt-helper.ts
+++ b/packages/core/src/indices/prompt-helper.ts
@@ -8,18 +8,16 @@ import {
   Settings,
 } from "../global";
 import type { LLMMetadata } from "../llms";
-import { SentenceSplitter } from "../node-parser";
-import type { PromptTemplate } from "../prompts";
+import { TextSplitter, TokenTextSplitter, truncateText } from "../node-parser";
+import { BasePromptTemplate, PromptTemplate } from "../prompts";
 
 /**
  * Get the empty prompt text given a prompt.
  */
-function getEmptyPromptTxt(prompt: PromptTemplate) {
-  return prompt.format({
-    ...Object.fromEntries(
-      [...prompt.templateVars.keys()].map((key) => [key, ""]),
-    ),
-  });
+function getEmptyPromptTxt(prompt: PromptTemplate): string {
+  return prompt.format(
+    Object.fromEntries([...prompt.templateVars.keys()].map((key) => [key, ""])),
+  );
 }
 
 /**
@@ -35,24 +33,24 @@ export function getBiggestPrompt(prompts: PromptTemplate[]): PromptTemplate {
 }
 
 export type PromptHelperOptions = {
-  contextWindow?: number;
-  numOutput?: number;
-  chunkOverlapRatio?: number;
-  chunkSizeLimit?: number;
-  tokenizer?: Tokenizer;
-  separator?: string;
+  contextWindow?: number | undefined;
+  numOutput?: number | undefined;
+  chunkOverlapRatio?: number | undefined;
+  chunkSizeLimit?: number | undefined;
+  tokenizer?: Tokenizer | undefined;
+  separator?: string | undefined;
 };
 
 /**
  * A collection of helper functions for working with prompts.
  */
 export class PromptHelper {
-  contextWindow = DEFAULT_CONTEXT_WINDOW;
-  numOutput = DEFAULT_NUM_OUTPUTS;
-  chunkOverlapRatio = DEFAULT_CHUNK_OVERLAP_RATIO;
+  contextWindow: number;
+  numOutput: number;
+  chunkOverlapRatio: number;
   chunkSizeLimit: number | undefined;
   tokenizer: Tokenizer;
-  separator = " ";
+  separator: string;
 
   constructor(options: PromptHelperOptions = {}) {
     const {
@@ -72,71 +70,93 @@ export class PromptHelper {
   }
 
   /**
-   * Given a prompt, return the maximum size of the inputs to the prompt.
-   * @param prompt
-   * @returns
+   * Calculate the available context size based on the number of prompt tokens.
    */
-  #getAvailableContextSize(prompt: PromptTemplate) {
-    const emptyPromptText = getEmptyPromptTxt(prompt);
-    const promptTokens = this.tokenizer.encode(emptyPromptText);
-    const numPromptTokens = promptTokens.length;
-
-    return this.contextWindow - numPromptTokens - this.numOutput;
+  #getAvailableContextSize(numPromptTokens: number): number {
+    const contextSizeTokens =
+      this.contextWindow - numPromptTokens - this.numOutput;
+    if (contextSizeTokens < 0) {
+      throw new Error(
+        `Calculated available context size ${contextSizeTokens} is not non-negative.`,
+      );
+    }
+    return contextSizeTokens;
   }
 
   /**
-   * Find the maximum size of each chunk given a prompt.
+   * Calculate the available chunk size based on the prompt and other parameters.
    */
-  #getAvailableChunkSize(
-    prompt: PromptTemplate,
-    numChunks = 1,
-    padding = 5,
+  #getAvailableChunkSize<Template extends BasePromptTemplate>(
+    prompt: Template,
+    numChunks: number = 1,
+    padding: number = 5,
   ): number {
-    const availableContextSize = this.#getAvailableContextSize(prompt);
+    let numPromptTokens = 0;
+
+    if (prompt instanceof PromptTemplate) {
+      numPromptTokens = this.tokenizer.encode(getEmptyPromptTxt(prompt)).length;
+    }
 
-    const result = Math.floor(availableContextSize / numChunks) - padding;
+    const availableContextSize = this.#getAvailableContextSize(numPromptTokens);
+    let result = Math.floor(availableContextSize / numChunks) - padding;
 
-    if (this.chunkSizeLimit) {
-      return Math.min(this.chunkSizeLimit, result);
-    } else {
-      return result;
+    if (this.chunkSizeLimit !== undefined) {
+      result = Math.min(this.chunkSizeLimit, result);
     }
+
+    return result;
   }
 
   /**
-   * Creates a text splitter with the correct chunk sizes and overlaps given a prompt.
+   * Creates a text splitter configured to maximally pack the available context window.
    */
   getTextSplitterGivenPrompt(
-    prompt: PromptTemplate,
-    numChunks = 1,
-    padding = DEFAULT_PADDING,
-  ) {
+    prompt: BasePromptTemplate,
+    numChunks: number = 1,
+    padding: number = DEFAULT_PADDING,
+  ): TextSplitter {
     const chunkSize = this.#getAvailableChunkSize(prompt, numChunks, padding);
     if (chunkSize <= 0) {
-      /**
-       * If you see this error, it means that the input is larger than LLM context window.
-       */
       throw new TypeError(`Chunk size ${chunkSize} is not positive.`);
     }
-    const chunkOverlap = this.chunkOverlapRatio * chunkSize;
-    return new SentenceSplitter({
+    const chunkOverlap = Math.floor(this.chunkOverlapRatio * chunkSize);
+    return new TokenTextSplitter({
+      separator: this.separator,
       chunkSize,
       chunkOverlap,
-      separator: this.separator,
       tokenizer: this.tokenizer,
     });
   }
 
   /**
-   * Repack resplits the strings based on the optimal text splitter.
+   * Truncate text chunks to fit within the available context window.
+   */
+  truncate(
+    prompt: BasePromptTemplate,
+    textChunks: string[],
+    padding: number = DEFAULT_PADDING,
+  ): string[] {
+    const textSplitter = this.getTextSplitterGivenPrompt(
+      prompt,
+      textChunks.length,
+      padding,
+    );
+    return textChunks.map((chunk) => truncateText(chunk, textSplitter));
+  }
+
+  /**
+   * Repack text chunks to better utilize the available context window.
    */
   repack(
-    prompt: PromptTemplate,
+    prompt: BasePromptTemplate,
     textChunks: string[],
-    padding = DEFAULT_PADDING,
-  ) {
+    padding: number = DEFAULT_PADDING,
+  ): string[] {
     const textSplitter = this.getTextSplitterGivenPrompt(prompt, 1, padding);
-    const combinedStr = textChunks.join("\n\n");
+    const combinedStr = textChunks
+      .map((c) => c.trim())
+      .filter((c) => c.length > 0)
+      .join("\n\n");
     return textSplitter.splitText(combinedStr);
   }
 
@@ -157,7 +177,8 @@ export class PromptHelper {
     } = options ?? {};
     return new PromptHelper({
       contextWindow: metadata.contextWindow,
-      numOutput: metadata.maxTokens ?? DEFAULT_NUM_OUTPUTS,
+      // fixme: numOutput is not in LLMMetadata
+      numOutput: DEFAULT_NUM_OUTPUTS,
       chunkOverlapRatio,
       chunkSizeLimit,
       tokenizer,
diff --git a/packages/core/src/node-parser/index.ts b/packages/core/src/node-parser/index.ts
index 6ecc98005..48a5c1902 100644
--- a/packages/core/src/node-parser/index.ts
+++ b/packages/core/src/node-parser/index.ts
@@ -13,6 +13,7 @@ export { MetadataAwareTextSplitter, NodeParser, TextSplitter } from "./base";
 export { MarkdownNodeParser } from "./markdown";
 export { SentenceSplitter } from "./sentence-splitter";
 export { SentenceWindowNodeParser } from "./sentence-window";
+export { TokenTextSplitter } from "./token-text-splitter";
 export type { SplitterParams } from "./type";
 export {
   splitByChar,
@@ -20,5 +21,6 @@ export {
   splitByRegex,
   splitBySentenceTokenizer,
   splitBySep,
+  truncateText,
 } from "./utils";
 export type { TextSplitterFn } from "./utils";
diff --git a/packages/core/src/node-parser/token-text-splitter.ts b/packages/core/src/node-parser/token-text-splitter.ts
new file mode 100644
index 000000000..e4f7b8dd8
--- /dev/null
+++ b/packages/core/src/node-parser/token-text-splitter.ts
@@ -0,0 +1,206 @@
+import type { Tokenizer } from "@llamaindex/env";
+import { z } from "zod";
+import { DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE, Settings } from "../global";
+import { MetadataAwareTextSplitter } from "./base";
+import type { SplitterParams } from "./type";
+import { splitByChar, splitBySep } from "./utils";
+
+const DEFAULT_METADATA_FORMAT_LEN = 2;
+
+const tokenTextSplitterSchema = z.object({
+  chunkSize: z.number().positive().default(DEFAULT_CHUNK_SIZE),
+  chunkOverlap: z.number().nonnegative().default(DEFAULT_CHUNK_OVERLAP),
+  separator: z.string().default(" "),
+  backupSeparators: z.array(z.string()).default(["\n"]),
+});
+
+export class TokenTextSplitter extends MetadataAwareTextSplitter {
+  chunkSize: number = DEFAULT_CHUNK_SIZE;
+  chunkOverlap: number = DEFAULT_CHUNK_OVERLAP;
+  separator: string = " ";
+  backupSeparators: string[] = ["\n"];
+  #tokenizer: Tokenizer;
+  #splitFns: Array<(text: string) => string[]> = [];
+
+  constructor(
+    params?: SplitterParams & Partial<z.infer<typeof tokenTextSplitterSchema>>,
+  ) {
+    super();
+
+    if (params) {
+      const parsedParams = tokenTextSplitterSchema.parse(params);
+      this.chunkSize = parsedParams.chunkSize;
+      this.chunkOverlap = parsedParams.chunkOverlap;
+      this.separator = parsedParams.separator;
+      this.backupSeparators = parsedParams.backupSeparators;
+    }
+
+    if (this.chunkOverlap > this.chunkSize) {
+      throw new Error(
+        `Got a larger chunk overlap (${this.chunkOverlap}) than chunk size (${this.chunkSize}), should be smaller.`,
+      );
+    }
+
+    this.#tokenizer = params?.tokenizer ?? Settings.tokenizer;
+
+    const allSeparators = [this.separator, ...this.backupSeparators];
+    this.#splitFns = allSeparators.map((sep) => splitBySep(sep));
+    this.#splitFns.push(splitByChar());
+  }
+
+  /**
+   * Split text into chunks, reserving space required for metadata string.
+   * @param text The text to split.
+   * @param metadata The metadata string.
+   * @returns An array of text chunks.
+   */
+  splitTextMetadataAware(text: string, metadata: string): string[] {
+    const metadataLength =
+      this.tokenSize(metadata) + DEFAULT_METADATA_FORMAT_LEN;
+    const effectiveChunkSize = this.chunkSize - metadataLength;
+
+    if (effectiveChunkSize <= 0) {
+      throw new Error(
+        `Metadata length (${metadataLength}) is longer than chunk size (${this.chunkSize}). ` +
+          `Consider increasing the chunk size or decreasing the size of your metadata to avoid this.`,
+      );
+    } else if (effectiveChunkSize < 50) {
+      console.warn(
+        `Metadata length (${metadataLength}) is close to chunk size (${this.chunkSize}). ` +
+          `Resulting chunks are less than 50 tokens. Consider increasing the chunk size or decreasing the size of your metadata to avoid this.`,
+      );
+    }
+
+    return this._splitText(text, effectiveChunkSize);
+  }
+
+  /**
+   * Split text into chunks.
+   * @param text The text to split.
+   * @returns An array of text chunks.
+   */
+  splitText(text: string): string[] {
+    return this._splitText(text, this.chunkSize);
+  }
+
+  /**
+   * Internal method to split text into chunks up to a specified size.
+   * @param text The text to split.
+   * @param chunkSize The maximum size of each chunk.
+   * @returns An array of text chunks.
+   */
+  private _splitText(text: string, chunkSize: number): string[] {
+    if (text === "") return [text];
+
+    // Dispatch chunking start event
+    Settings.callbackManager.dispatchEvent("chunking-start", { text: [text] });
+
+    const splits = this._split(text, chunkSize);
+    const chunks = this._merge(splits, chunkSize);
+
+    Settings.callbackManager.dispatchEvent("chunking-end", { chunks });
+
+    return chunks;
+  }
+
+  /**
+   * Break text into splits that are smaller than the chunk size.
+   * @param text The text to split.
+   * @param chunkSize The maximum size of each split.
+   * @returns An array of text splits.
+   */
+  private _split(text: string, chunkSize: number): string[] {
+    if (this.tokenSize(text) <= chunkSize) {
+      return [text];
+    }
+
+    for (const splitFn of this.#splitFns) {
+      const splits = splitFn(text);
+      if (splits.length > 1) {
+        const newSplits: string[] = [];
+        for (const split of splits) {
+          const splitLen = this.tokenSize(split);
+          if (splitLen <= chunkSize) {
+            newSplits.push(split);
+          } else {
+            newSplits.push(...this._split(split, chunkSize));
+          }
+        }
+        return newSplits;
+      }
+    }
+
+    return [text];
+  }
+
+  /**
+   * Merge splits into chunks with overlap.
+   * @param splits The array of text splits.
+   * @param chunkSize The maximum size of each chunk.
+   * @returns An array of merged text chunks.
+   */
+  private _merge(splits: string[], chunkSize: number): string[] {
+    const chunks: string[] = [];
+    let currentChunk: string[] = [];
+    let currentLength = 0;
+
+    for (const split of splits) {
+      const splitLength = this.tokenSize(split);
+
+      if (splitLength > chunkSize) {
+        console.warn(
+          `Got a split of size ${splitLength}, larger than chunk size ${chunkSize}.`,
+        );
+      }
+
+      if (currentLength + splitLength > chunkSize) {
+        // Emit the accumulated chunk before starting a new one.
+        const chunk = currentChunk.join("").trim();
+        if (chunk) {
+          chunks.push(chunk);
+        }
+
+        // Seed the next chunk with the tail splits of the chunk just
+        // emitted, up to chunkOverlap tokens. The previous code cleared
+        // currentChunk and then shifted from the now-empty array, so the
+        // configured overlap was silently never applied.
+        const overlapTokens = this.chunkOverlap;
+        const overlapSplits: string[] = [];
+        let overlapLength = 0;
+        while (currentChunk.length > 0 && overlapLength < overlapTokens) {
+          const tail = currentChunk[currentChunk.length - 1]!;
+          const tailLength = this.tokenSize(tail);
+          if (overlapLength + tailLength > overlapTokens) {
+            break;
+          }
+          currentChunk.pop();
+          overlapSplits.unshift(tail);
+          overlapLength += tailLength;
+        }
+
+        currentChunk = overlapSplits;
+        currentLength = overlapLength;
+      }
+
+      currentChunk.push(split);
+      currentLength += splitLength;
+    }
+
+    // Flush whatever remains as the final chunk.
+    const finalChunk = currentChunk.join("").trim();
+    if (finalChunk) {
+      chunks.push(finalChunk);
+    }
+
+    return chunks;
+  }
+
+  /**
+   * Calculate the number of tokens in the text using the tokenizer.
+   * @param text The text to tokenize.
+   * @returns The number of tokens.
+   */
+  private tokenSize(text: string): number {
+    return this.#tokenizer.encode(text).length;
+  }
+}
diff --git a/packages/core/src/node-parser/utils.ts b/packages/core/src/node-parser/utils.ts
index eb5a7cbf3..4d6de3657 100644
--- a/packages/core/src/node-parser/utils.ts
+++ b/packages/core/src/node-parser/utils.ts
@@ -3,7 +3,10 @@ import SentenceTokenizer from "./sentence_tokenizer";
 
 export type TextSplitterFn = (text: string) => string[];
 
-const truncateText = (text: string, textSplitter: TextSplitter): string => {
+export const truncateText = (
+  text: string,
+  textSplitter: TextSplitter,
+): string => {
   const chunks = textSplitter.splitText(text);
   return chunks[0] ?? text;
 };
diff --git a/packages/core/src/prompts/prompt.ts b/packages/core/src/prompts/prompt.ts
index 53851f25f..aeb6a7e64 100644
--- a/packages/core/src/prompts/prompt.ts
+++ b/packages/core/src/prompts/prompt.ts
@@ -64,11 +64,13 @@ export const defaultRefinePrompt: RefinePrompt = new PromptTemplate({
   templateVars: ["query", "existingAnswer", "context"],
   template: `The original query is as follows: {query}
 We have provided an existing answer: {existingAnswer}
-We have the opportunity to refine the existing answer (only if needed) with some more context below.
+We have the opportunity to refine the existing answer
+(only if needed) with some more context below.
 ------------
 {context}
 ------------
-Given the new context, refine the original answer to better answer the query. If the context isn't useful, return the original answer.
+Given the new context, refine the original answer to better answer the query.
+If the context isn't useful, return the original answer.
 Refined Answer:`,
 });
 
-- 
GitLab