From 3489e7de849cf5dc7e9ce13f604fcbf9a0a58229 Mon Sep 17 00:00:00 2001 From: Alex Yang <himself65@outlook.com> Date: Sun, 6 Oct 2024 18:19:05 -0700 Subject: [PATCH] fix: num output incorrect in prompt helper (#1303) --- .changeset/gold-jokes-deny.md | 5 + packages/core/src/indices/prompt-helper.ts | 131 ++++++----- packages/core/src/node-parser/index.ts | 2 + .../src/node-parser/token-text-splitter.ts | 206 ++++++++++++++++++ packages/core/src/node-parser/utils.ts | 5 +- packages/core/src/prompts/prompt.ts | 6 +- 6 files changed, 297 insertions(+), 58 deletions(-) create mode 100644 .changeset/gold-jokes-deny.md create mode 100644 packages/core/src/node-parser/token-text-splitter.ts diff --git a/.changeset/gold-jokes-deny.md b/.changeset/gold-jokes-deny.md new file mode 100644 index 000000000..792ffc5d2 --- /dev/null +++ b/.changeset/gold-jokes-deny.md @@ -0,0 +1,5 @@ +--- +"@llamaindex/core": patch +--- + +fix: num output incorrect in prompt helper diff --git a/packages/core/src/indices/prompt-helper.ts b/packages/core/src/indices/prompt-helper.ts index 3d1b5adcc..960ffe31f 100644 --- a/packages/core/src/indices/prompt-helper.ts +++ b/packages/core/src/indices/prompt-helper.ts @@ -8,18 +8,16 @@ import { Settings, } from "../global"; import type { LLMMetadata } from "../llms"; -import { SentenceSplitter } from "../node-parser"; -import type { PromptTemplate } from "../prompts"; +import { TextSplitter, TokenTextSplitter, truncateText } from "../node-parser"; +import { BasePromptTemplate, PromptTemplate } from "../prompts"; /** * Get the empty prompt text given a prompt. */ -function getEmptyPromptTxt(prompt: PromptTemplate) { - return prompt.format({ - ...Object.fromEntries( - [...prompt.templateVars.keys()].map((key) => [key, ""]), - ), - }); +function getEmptyPromptTxt(prompt: PromptTemplate): string { + return prompt.format( + Object.fromEntries([...prompt.templateVars.keys()].map((key) => [key, ""])), + ); } /** @@ -35,24 +33,24 @@ export function getBiggestPrompt(prompts: PromptTemplate[]): PromptTemplate { } export type PromptHelperOptions = { - contextWindow?: number; - numOutput?: number; - chunkOverlapRatio?: number; - chunkSizeLimit?: number; - tokenizer?: Tokenizer; - separator?: string; + contextWindow?: number | undefined; + numOutput?: number | undefined; + chunkOverlapRatio?: number | undefined; + chunkSizeLimit?: number | undefined; + tokenizer?: Tokenizer | undefined; + separator?: string | undefined; }; /** * A collection of helper functions for working with prompts. */ export class PromptHelper { - contextWindow = DEFAULT_CONTEXT_WINDOW; - numOutput = DEFAULT_NUM_OUTPUTS; - chunkOverlapRatio = DEFAULT_CHUNK_OVERLAP_RATIO; + contextWindow: number; + numOutput: number; + chunkOverlapRatio: number; chunkSizeLimit: number | undefined; tokenizer: Tokenizer; - separator = " "; + separator: string; constructor(options: PromptHelperOptions = {}) { const { @@ -72,71 +70,93 @@ export class PromptHelper { } /** - * Given a prompt, return the maximum size of the inputs to the prompt. - * @param prompt - * @returns + * Calculate the available context size based on the number of prompt tokens. */ - #getAvailableContextSize(prompt: PromptTemplate) { - const emptyPromptText = getEmptyPromptTxt(prompt); - const promptTokens = this.tokenizer.encode(emptyPromptText); - const numPromptTokens = promptTokens.length; - - return this.contextWindow - numPromptTokens - this.numOutput; + #getAvailableContextSize(numPromptTokens: number): number { + const contextSizeTokens = + this.contextWindow - numPromptTokens - this.numOutput; + if (contextSizeTokens < 0) { + throw new Error( + `Calculated available context size ${contextSizeTokens} is not non-negative.`, + ); + } + return contextSizeTokens; } /** - * Find the maximum size of each chunk given a prompt. + * Calculate the available chunk size based on the prompt and other parameters. */ - #getAvailableChunkSize( - prompt: PromptTemplate, - numChunks = 1, - padding = 5, + #getAvailableChunkSize<Template extends BasePromptTemplate>( + prompt: Template, + numChunks: number = 1, + padding: number = 5, ): number { - const availableContextSize = this.#getAvailableContextSize(prompt); + let numPromptTokens = 0; + + if (prompt instanceof PromptTemplate) { + numPromptTokens = this.tokenizer.encode(getEmptyPromptTxt(prompt)).length; + } - const result = Math.floor(availableContextSize / numChunks) - padding; + const availableContextSize = this.#getAvailableContextSize(numPromptTokens); + let result = Math.floor(availableContextSize / numChunks) - padding; - if (this.chunkSizeLimit) { - return Math.min(this.chunkSizeLimit, result); - } else { - return result; + if (this.chunkSizeLimit !== undefined) { + result = Math.min(this.chunkSizeLimit, result); } + + return result; } /** - * Creates a text splitter with the correct chunk sizes and overlaps given a prompt. + * Creates a text splitter configured to maximally pack the available context window. */ getTextSplitterGivenPrompt( - prompt: PromptTemplate, - numChunks = 1, - padding = DEFAULT_PADDING, - ) { + prompt: BasePromptTemplate, + numChunks: number = 1, + padding: number = DEFAULT_PADDING, + ): TextSplitter { const chunkSize = this.#getAvailableChunkSize(prompt, numChunks, padding); if (chunkSize <= 0) { - /** - * If you see this error, it means that the input is larger than LLM context window. - */ throw new TypeError(`Chunk size ${chunkSize} is not positive.`); } - const chunkOverlap = this.chunkOverlapRatio * chunkSize; - return new SentenceSplitter({ + const chunkOverlap = Math.floor(this.chunkOverlapRatio * chunkSize); + return new TokenTextSplitter({ + separator: this.separator, chunkSize, chunkOverlap, - separator: this.separator, tokenizer: this.tokenizer, }); } /** - * Repack resplits the strings based on the optimal text splitter. + * Truncate text chunks to fit within the available context window. + */ + truncate( + prompt: BasePromptTemplate, + textChunks: string[], + padding: number = DEFAULT_PADDING, + ): string[] { + const textSplitter = this.getTextSplitterGivenPrompt( + prompt, + textChunks.length, + padding, + ); + return textChunks.map((chunk) => truncateText(chunk, textSplitter)); + } + + /** + * Repack text chunks to better utilize the available context window. */ repack( - prompt: PromptTemplate, + prompt: BasePromptTemplate, textChunks: string[], - padding = DEFAULT_PADDING, - ) { + padding: number = DEFAULT_PADDING, + ): string[] { const textSplitter = this.getTextSplitterGivenPrompt(prompt, 1, padding); - const combinedStr = textChunks.join("\n\n"); + const combinedStr = textChunks + .map((c) => c.trim()) + .filter((c) => c.length > 0) + .join("\n\n"); return textSplitter.splitText(combinedStr); } @@ -157,7 +177,8 @@ export class PromptHelper { } = options ?? {}; return new PromptHelper({ contextWindow: metadata.contextWindow, - numOutput: metadata.maxTokens ?? DEFAULT_NUM_OUTPUTS, + // fixme: numOutput is not in LLMMetadata + numOutput: DEFAULT_NUM_OUTPUTS, chunkOverlapRatio, chunkSizeLimit, tokenizer, diff --git a/packages/core/src/node-parser/index.ts b/packages/core/src/node-parser/index.ts index 6ecc98005..48a5c1902 100644 --- a/packages/core/src/node-parser/index.ts +++ b/packages/core/src/node-parser/index.ts @@ -13,6 +13,7 @@ export { MetadataAwareTextSplitter, NodeParser, TextSplitter } from "./base"; export { MarkdownNodeParser } from "./markdown"; export { SentenceSplitter } from "./sentence-splitter"; export { SentenceWindowNodeParser } from "./sentence-window"; +export { TokenTextSplitter } from "./token-text-splitter"; export type { SplitterParams } from "./type"; export { splitByChar, @@ -20,5 +21,6 @@ export { splitByRegex, splitBySentenceTokenizer, splitBySep, + truncateText, } from "./utils"; export type { TextSplitterFn } from "./utils"; diff --git a/packages/core/src/node-parser/token-text-splitter.ts b/packages/core/src/node-parser/token-text-splitter.ts new file mode 100644 index 000000000..e4f7b8dd8 --- /dev/null +++ b/packages/core/src/node-parser/token-text-splitter.ts @@ -0,0 +1,206 @@ +import type { Tokenizer } from "@llamaindex/env"; +import { z } from "zod"; +import { DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE, Settings } from "../global"; +import { MetadataAwareTextSplitter } from "./base"; +import type { SplitterParams } from "./type"; +import { splitByChar, splitBySep } from "./utils"; + +const DEFAULT_METADATA_FORMAT_LEN = 2; + +const tokenTextSplitterSchema = z.object({ + chunkSize: z.number().positive().default(DEFAULT_CHUNK_SIZE), + chunkOverlap: z.number().nonnegative().default(DEFAULT_CHUNK_OVERLAP), + separator: z.string().default(" "), + backupSeparators: z.array(z.string()).default(["\n"]), +}); + +export class TokenTextSplitter extends MetadataAwareTextSplitter { + chunkSize: number = DEFAULT_CHUNK_SIZE; + chunkOverlap: number = DEFAULT_CHUNK_OVERLAP; + separator: string = " "; + backupSeparators: string[] = ["\n"]; + #tokenizer: Tokenizer; + #splitFns: Array<(text: string) => string[]> = []; + + constructor( + params?: SplitterParams & Partial<z.infer<typeof tokenTextSplitterSchema>>, + ) { + super(); + + if (params) { + const parsedParams = tokenTextSplitterSchema.parse(params); + this.chunkSize = parsedParams.chunkSize; + this.chunkOverlap = parsedParams.chunkOverlap; + this.separator = parsedParams.separator; + this.backupSeparators = parsedParams.backupSeparators; + } + + if (this.chunkOverlap > this.chunkSize) { + throw new Error( + `Got a larger chunk overlap (${this.chunkOverlap}) than chunk size (${this.chunkSize}), should be smaller.`, + ); + } + + this.#tokenizer = params?.tokenizer ?? Settings.tokenizer; + + const allSeparators = [this.separator, ...this.backupSeparators]; + this.#splitFns = allSeparators.map((sep) => splitBySep(sep)); + this.#splitFns.push(splitByChar()); + } + + /** + * Split text into chunks, reserving space required for metadata string. + * @param text The text to split. + * @param metadata The metadata string. + * @returns An array of text chunks. + */ + splitTextMetadataAware(text: string, metadata: string): string[] { + const metadataLength = + this.tokenSize(metadata) + DEFAULT_METADATA_FORMAT_LEN; + const effectiveChunkSize = this.chunkSize - metadataLength; + + if (effectiveChunkSize <= 0) { + throw new Error( + `Metadata length (${metadataLength}) is longer than chunk size (${this.chunkSize}). ` + + `Consider increasing the chunk size or decreasing the size of your metadata to avoid this.`, + ); + } else if (effectiveChunkSize < 50) { + console.warn( + `Metadata length (${metadataLength}) is close to chunk size (${this.chunkSize}). ` + + `Resulting chunks are less than 50 tokens. Consider increasing the chunk size or decreasing the size of your metadata to avoid this.`, + ); + } + + return this._splitText(text, effectiveChunkSize); + } + + /** + * Split text into chunks. + * @param text The text to split. + * @returns An array of text chunks. + */ + splitText(text: string): string[] { + return this._splitText(text, this.chunkSize); + } + + /** + * Internal method to split text into chunks up to a specified size. + * @param text The text to split. + * @param chunkSize The maximum size of each chunk. + * @returns An array of text chunks. + */ + private _splitText(text: string, chunkSize: number): string[] { + if (text === "") return [text]; + + // Dispatch chunking start event + Settings.callbackManager.dispatchEvent("chunking-start", { text: [text] }); + + const splits = this._split(text, chunkSize); + const chunks = this._merge(splits, chunkSize); + + Settings.callbackManager.dispatchEvent("chunking-end", { chunks }); + + return chunks; + } + + /** + * Break text into splits that are smaller than the chunk size. + * @param text The text to split. + * @param chunkSize The maximum size of each split. + * @returns An array of text splits. + */ + private _split(text: string, chunkSize: number): string[] { + if (this.tokenSize(text) <= chunkSize) { + return [text]; + } + + for (const splitFn of this.#splitFns) { + const splits = splitFn(text); + if (splits.length > 1) { + const newSplits: string[] = []; + for (const split of splits) { + const splitLen = this.tokenSize(split); + if (splitLen <= chunkSize) { + newSplits.push(split); + } else { + newSplits.push(...this._split(split, chunkSize)); + } + } + return newSplits; + } + } + + return [text]; + } + + /** + * Merge splits into chunks with overlap. + * @param splits The array of text splits. + * @param chunkSize The maximum size of each chunk. + * @returns An array of merged text chunks. + */ + private _merge(splits: string[], chunkSize: number): string[] { + const chunks: string[] = []; + let currentChunk: string[] = []; + let currentLength = 0; + + for (const split of splits) { + const splitLength = this.tokenSize(split); + + if (splitLength > chunkSize) { + console.warn( + `Got a split of size ${splitLength}, larger than chunk size ${chunkSize}.`, + ); + } + + if (currentLength + splitLength > chunkSize) { + const chunk = currentChunk.join("").trim(); + if (chunk) { + chunks.push(chunk); + } + + currentChunk = []; + currentLength = 0; + + const overlapTokens = this.chunkOverlap; + const overlapSplits: string[] = []; + + let overlapLength = 0; + while ( + overlapSplits.length < splits.length && + overlapLength < overlapTokens + ) { + const overlapSplit = currentChunk.shift(); + if (!overlapSplit) break; + overlapSplits.push(overlapSplit); + overlapLength += this.tokenSize(overlapSplit); + } + + for (const overlapSplit of overlapSplits.reverse()) { + currentChunk.push(overlapSplit); + currentLength += this.tokenSize(overlapSplit); + if (currentLength >= overlapTokens) break; + } + } + + currentChunk.push(split); + currentLength += splitLength; + } + + const finalChunk = currentChunk.join("").trim(); + if (finalChunk) { + chunks.push(finalChunk); + } + + return chunks; + } + + /** + * Calculate the number of tokens in the text using the tokenizer. + * @param text The text to tokenize. + * @returns The number of tokens. + */ + private tokenSize(text: string): number { + return this.#tokenizer.encode(text).length; + } +} diff --git a/packages/core/src/node-parser/utils.ts b/packages/core/src/node-parser/utils.ts index eb5a7cbf3..4d6de3657 100644 --- a/packages/core/src/node-parser/utils.ts +++ b/packages/core/src/node-parser/utils.ts @@ -3,7 +3,10 @@ import SentenceTokenizer from "./sentence_tokenizer"; export type TextSplitterFn = (text: string) => string[]; -const truncateText = (text: string, textSplitter: TextSplitter): string => { +export const truncateText = ( + text: string, + textSplitter: TextSplitter, +): string => { const chunks = textSplitter.splitText(text); return chunks[0] ?? text; }; diff --git a/packages/core/src/prompts/prompt.ts b/packages/core/src/prompts/prompt.ts index 53851f25f..aeb6a7e64 100644 --- a/packages/core/src/prompts/prompt.ts +++ b/packages/core/src/prompts/prompt.ts @@ -64,11 +64,13 @@ export const defaultRefinePrompt: RefinePrompt = new PromptTemplate({ templateVars: ["query", "existingAnswer", "context"], template: `The original query is as follows: {query} We have provided an existing answer: {existingAnswer} -We have the opportunity to refine the existing answer (only if needed) with some more context below. +We have the opportunity to refine the existing answer +(only if needed) with some more context below. ------------ {context} ------------ -Given the new context, refine the original answer to better answer the query. If the context isn't useful, return the original answer. +Given the new context, refine the original answer to better answer the query. +If the context isn't useful, return the original answer. Refined Answer:`, }); -- GitLab