diff --git a/.changeset/giant-buses-breathe.md b/.changeset/giant-buses-breathe.md new file mode 100644 index 0000000000000000000000000000000000000000..a5571ac6587c1cd587298ad01cf2b4cfa76708a1 --- /dev/null +++ b/.changeset/giant-buses-breathe.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +Truncate text to embed for OpenAI if it exceeds maxTokens diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 739f021ea02cc26a1a390631e16a1cfe249c27f6..3d21e74e801ec544e15e7b9552aaaa7915bfd179 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -91,7 +91,7 @@ jobs: - cloudflare-worker-agent - nextjs-agent - nextjs-edge-runtime - - waku-query-engine + # - waku-query-engine runs-on: ubuntu-latest name: Build Core Example (${{ matrix.packages }}) steps: diff --git a/packages/core/e2e/fixtures/embeddings/OpenAIEmbedding.ts b/packages/core/e2e/fixtures/embeddings/OpenAIEmbedding.ts index 85c2963d159beb9b20c90a12f8b2efdc7dddb0ac..2efe159afba5b407eaf072c96d30f4f1304921d6 100644 --- a/packages/core/e2e/fixtures/embeddings/OpenAIEmbedding.ts +++ b/packages/core/e2e/fixtures/embeddings/OpenAIEmbedding.ts @@ -2,10 +2,12 @@ import { BaseNode, SimilarityType, type BaseEmbedding, + type EmbeddingInfo, type MessageContentDetail, } from "llamaindex"; export class OpenAIEmbedding implements BaseEmbedding { + embedInfo?: EmbeddingInfo | undefined; embedBatchSize = 512; async getQueryEmbedding(query: MessageContentDetail) { @@ -36,4 +38,8 @@ export class OpenAIEmbedding implements BaseEmbedding { nodes.forEach((node) => (node.embedding = [0])); return nodes; } + + truncateMaxTokens(input: string[]): string[] { + return input; + } } diff --git a/packages/core/package.json b/packages/core/package.json index 11bbf858f07e6839a799e49e911cf17311a730d0..e409777ee53d96cac835027f6b8d3ee9f4f62a25 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -57,6 +57,7 @@ "portkey-ai": "^0.1.16", "rake-modified": "^1.0.8", "string-strip-html": "^13.4.8", + "tiktoken": "^1.0.15", "unpdf": "^0.10.1", "wikipedia": "^2.1.2", "wink-nlp": "^2.3.0" diff --git a/packages/core/src/ChatHistory.ts b/packages/core/src/ChatHistory.ts index 1da4e612ba19ab967d1cbd919b0e520bea1047f8..7d0093e51e8c084f4d43892866969d7b3e6f22f4 100644 --- a/packages/core/src/ChatHistory.ts +++ b/packages/core/src/ChatHistory.ts @@ -1,4 +1,4 @@ -import { globalsHelper } from "./GlobalsHelper.js"; +import { tokenizers, type Tokenizer } from "@llamaindex/env"; import type { SummaryPrompt } from "./Prompt.js"; import { defaultSummaryPrompt, messagesToHistoryStr } from "./Prompt.js"; import { OpenAI } from "./llm/openai.js"; @@ -70,8 +70,7 @@ export class SummaryChatHistory extends ChatHistory { * Tokenizer function that converts text to tokens, * this is used to calculate the number of tokens in a message. */ - tokenizer: (text: string) => Uint32Array = - globalsHelper.defaultTokenizer.encode; + tokenizer: Tokenizer; tokensToSummarize: number; messages: ChatMessage[]; summaryPrompt: SummaryPrompt; @@ -89,6 +88,7 @@ export class SummaryChatHistory extends ChatHistory { "LLM maxTokens is not set. Needed so the summarizer ensures the context window size of the LLM.", ); } + this.tokenizer = init?.tokenizer ?? 
tokenizers.tokenizer(); this.tokensToSummarize = this.llm.metadata.contextWindow - this.llm.metadata.maxTokens; if (this.tokensToSummarize < this.llm.metadata.contextWindow * 0.25) { @@ -116,7 +116,8 @@ export class SummaryChatHistory extends ChatHistory { // remove oldest message until the chat history is short enough for the context window messagesToSummarize.shift(); } while ( - this.tokenizer(promptMessages[0].content).length > this.tokensToSummarize + this.tokenizer.encode(promptMessages[0].content).length > + this.tokensToSummarize ); const response = await this.llm.chat({ @@ -195,7 +196,7 @@ export class SummaryChatHistory extends ChatHistory { // get tokens of current request messages and the transient messages const tokens = requestMessages.reduce( (count, message) => - count + this.tokenizer(extractText(message.content)).length, + count + this.tokenizer.encode(extractText(message.content)).length, 0, ); if (tokens > this.tokensToSummarize) { diff --git a/packages/core/src/GlobalsHelper.ts b/packages/core/src/GlobalsHelper.ts deleted file mode 100644 index 2df512ea2a0511f609d92473593793f1fedee6b8..0000000000000000000000000000000000000000 --- a/packages/core/src/GlobalsHelper.ts +++ /dev/null @@ -1,49 +0,0 @@ -import { encodingForModel } from "js-tiktoken"; - -export enum Tokenizers { - CL100K_BASE = "cl100k_base", -} - -/** - * @internal Helper class singleton - */ -class GlobalsHelper { - defaultTokenizer: { - encode: (text: string) => Uint32Array; - decode: (tokens: Uint32Array) => string; - }; - - constructor() { - const encoding = encodingForModel("text-embedding-ada-002"); // cl100k_base - - this.defaultTokenizer = { - encode: (text: string) => { - return new Uint32Array(encoding.encode(text)); - }, - decode: (tokens: Uint32Array) => { - const numberArray = Array.from(tokens); - const text = encoding.decode(numberArray); - const uint8Array = new TextEncoder().encode(text); - return new TextDecoder().decode(uint8Array); - }, - }; - } - - tokenizer(encoding?: Tokenizers) { - if (encoding && encoding !== Tokenizers.CL100K_BASE) { - throw new Error(`Tokenizer encoding ${encoding} not yet supported`); - } - - return this.defaultTokenizer!.encode.bind(this.defaultTokenizer); - } - - tokenizerDecoder(encoding?: Tokenizers) { - if (encoding && encoding !== Tokenizers.CL100K_BASE) { - throw new Error(`Tokenizer encoding ${encoding} not yet supported`); - } - - return this.defaultTokenizer!.decode.bind(this.defaultTokenizer); - } -} - -export const globalsHelper = new GlobalsHelper(); diff --git a/packages/core/src/PromptHelper.ts b/packages/core/src/PromptHelper.ts index 289e29835873deaef9b001a714708958ad1b01e0..809a75e6b4c6c368d872f3c89518cc6d17963206 100644 --- a/packages/core/src/PromptHelper.ts +++ b/packages/core/src/PromptHelper.ts @@ -1,4 +1,4 @@ -import { globalsHelper } from "./GlobalsHelper.js"; +import { tokenizers, type Tokenizer } from "@llamaindex/env"; import type { SimplePrompt } from "./Prompt.js"; import { SentenceSplitter } from "./TextSplitter.js"; import { @@ -34,7 +34,7 @@ export class PromptHelper { numOutput = DEFAULT_NUM_OUTPUTS; chunkOverlapRatio = DEFAULT_CHUNK_OVERLAP_RATIO; chunkSizeLimit?: number; - tokenizer: (text: string) => Uint32Array; + tokenizer: Tokenizer; separator = " "; // eslint-disable-next-line max-params @@ -43,14 +43,14 @@ export class PromptHelper { numOutput = DEFAULT_NUM_OUTPUTS, chunkOverlapRatio = DEFAULT_CHUNK_OVERLAP_RATIO, chunkSizeLimit?: number, - tokenizer?: (text: string) => Uint32Array, + tokenizer?: Tokenizer, separator = " 
", ) { this.contextWindow = contextWindow; this.numOutput = numOutput; this.chunkOverlapRatio = chunkOverlapRatio; this.chunkSizeLimit = chunkSizeLimit; - this.tokenizer = tokenizer || globalsHelper.tokenizer(); + this.tokenizer = tokenizer ?? tokenizers.tokenizer(); this.separator = separator; } @@ -61,7 +61,7 @@ export class PromptHelper { */ private getAvailableContextSize(prompt: SimplePrompt) { const emptyPromptText = getEmptyPromptTxt(prompt); - const promptTokens = this.tokenizer(emptyPromptText); + const promptTokens = this.tokenizer.encode(emptyPromptText); const numPromptTokens = promptTokens.length; return this.contextWindow - numPromptTokens - this.numOutput; diff --git a/packages/core/src/TextSplitter.ts b/packages/core/src/TextSplitter.ts index c8594237f5f3349edf317ee7cb40df95d12b4d9d..f8e01121b74ba739565c6bf4c14ca7651fd0d11e 100644 --- a/packages/core/src/TextSplitter.ts +++ b/packages/core/src/TextSplitter.ts @@ -1,6 +1,5 @@ -import { EOL } from "@llamaindex/env"; +import { EOL, tokenizers, type Tokenizer } from "@llamaindex/env"; // GitHub translated -import { globalsHelper } from "./GlobalsHelper.js"; import { DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE } from "./constants.js"; class TextSplit { @@ -69,8 +68,7 @@ export class SentenceSplitter { public chunkSize: number; public chunkOverlap: number; - private tokenizer: any; - private tokenizerDecoder: any; + private tokenizer: Tokenizer; private paragraphSeparator: string; private chunkingTokenizerFn: (text: string) => string[]; private splitLongSentences: boolean; @@ -78,8 +76,7 @@ export class SentenceSplitter { constructor(options?: { chunkSize?: number; chunkOverlap?: number; - tokenizer?: any; - tokenizerDecoder?: any; + tokenizer?: Tokenizer; paragraphSeparator?: string; chunkingTokenizerFn?: (text: string) => string[]; splitLongSentences?: boolean; @@ -88,7 +85,6 @@ export class SentenceSplitter { chunkSize = DEFAULT_CHUNK_SIZE, chunkOverlap = DEFAULT_CHUNK_OVERLAP, tokenizer = null, - tokenizerDecoder = null, paragraphSeparator = defaultParagraphSeparator, chunkingTokenizerFn, splitLongSentences = false, @@ -102,9 +98,7 @@ export class SentenceSplitter { this.chunkSize = chunkSize; this.chunkOverlap = chunkOverlap; - this.tokenizer = tokenizer ?? globalsHelper.tokenizer(); - this.tokenizerDecoder = - tokenizerDecoder ?? globalsHelper.tokenizerDecoder(); + this.tokenizer = tokenizer ?? tokenizers.tokenizer(); this.paragraphSeparator = paragraphSeparator; this.chunkingTokenizerFn = chunkingTokenizerFn ?? 
defaultSentenceTokenizer; @@ -115,7 +109,8 @@ export class SentenceSplitter { // get "effective" chunk size by removing the metadata let effectiveChunkSize; if (extraInfoStr != undefined) { - const numExtraTokens = this.tokenizer(`${extraInfoStr}\n\n`).length + 1; + const numExtraTokens = + this.tokenizer.encode(`${extraInfoStr}\n\n`).length + 1; effectiveChunkSize = this.chunkSize - numExtraTokens; if (effectiveChunkSize <= 0) { throw new Error( @@ -190,19 +185,19 @@ export class SentenceSplitter { if (!this.splitLongSentences) { return sentenceSplits.map((split) => ({ text: split, - numTokens: this.tokenizer(split).length, + numTokens: this.tokenizer.encode(split).length, })); } const newSplits: SplitRep[] = []; for (const split of sentenceSplits) { - const splitTokens = this.tokenizer(split); + const splitTokens = this.tokenizer.encode(split); const splitLen = splitTokens.length; if (splitLen <= effectiveChunkSize) { newSplits.push({ text: split, numTokens: splitLen }); } else { for (let i = 0; i < splitLen; i += effectiveChunkSize) { - const cur_split = this.tokenizerDecoder( + const cur_split = this.tokenizer.decode( splitTokens.slice(i, i + effectiveChunkSize), ); newSplits.push({ text: cur_split, numTokens: effectiveChunkSize }); diff --git a/packages/core/src/embeddings/OpenAIEmbedding.ts b/packages/core/src/embeddings/OpenAIEmbedding.ts index a2d173a3f1c6b54f7ccd81fd6cb1b8da46839d99..03d4eadde68986466cfe1caf0c1c4aab97da7733 100644 --- a/packages/core/src/embeddings/OpenAIEmbedding.ts +++ b/packages/core/src/embeddings/OpenAIEmbedding.ts @@ -1,3 +1,4 @@ +import { Tokenizers } from "@llamaindex/env"; import type { ClientOptions as OpenAIClientOptions } from "openai"; import type { AzureOpenAIConfig } from "../llm/azure.js"; import { @@ -12,20 +13,25 @@ import { BaseEmbedding } from "./types.js"; export const ALL_OPENAI_EMBEDDING_MODELS = { "text-embedding-ada-002": { dimensions: 1536, - maxTokens: 8191, + maxTokens: 8192, + tokenizer: Tokenizers.CL100K_BASE, }, "text-embedding-3-small": { dimensions: 1536, dimensionOptions: [512, 1536], - maxTokens: 8191, + maxTokens: 8192, + tokenizer: Tokenizers.CL100K_BASE, }, "text-embedding-3-large": { dimensions: 3072, dimensionOptions: [256, 1024, 3072], - maxTokens: 8191, + maxTokens: 8192, + tokenizer: Tokenizers.CL100K_BASE, }, }; +type ModelKeys = keyof typeof ALL_OPENAI_EMBEDDING_MODELS; + export class OpenAIEmbedding extends BaseEmbedding { /** embeddding model. defaults to "text-embedding-ada-002" */ model: string; @@ -65,6 +71,14 @@ export class OpenAIEmbedding extends BaseEmbedding { this.timeout = init?.timeout ?? 
60 * 1000; // Default is 60 seconds this.additionalSessionOptions = init?.additionalSessionOptions; + // find metadata for model + const key = Object.keys(ALL_OPENAI_EMBEDDING_MODELS).find( + (key) => key === this.model, + ) as ModelKeys | undefined; + if (key) { + this.embedInfo = ALL_OPENAI_EMBEDDING_MODELS[key]; + } + if (init?.azure || shouldUseAzure()) { const azureConfig = { ...getAzureConfigFromEnv({ @@ -102,6 +116,9 @@ * @param options */ private async getOpenAIEmbedding(input: string[]): Promise<number[][]> { + // TODO: ensure this for every subclass by calling it in the base class + input = this.truncateMaxTokens(input); + const { data } = await this.session.openai.embeddings.create({ model: this.model, dimensions: this.dimensions, // only sent to OpenAI if set by user diff --git a/packages/core/src/embeddings/tokenizer.ts b/packages/core/src/embeddings/tokenizer.ts new file mode 100644 index 0000000000000000000000000000000000000000..42fba032a3f4b2173c42d28e89c357542b04c61b --- /dev/null +++ b/packages/core/src/embeddings/tokenizer.ts @@ -0,0 +1,20 @@ +import { Tokenizers, tokenizers } from "@llamaindex/env"; + +export function truncateMaxTokens( + tokenizer: Tokenizers, + value: string, + maxTokens: number, +): string { + // the maximum number of tokens per one character is 2 (e.g. 爨) + if (value.length * 2 < maxTokens) return value; + const t = tokenizers.tokenizer(tokenizer); + let tokens = t.encode(value); + if (tokens.length > maxTokens) { + // truncate tokens + tokens = tokens.slice(0, maxTokens); + value = t.decode(tokens); + // if truncation splits a character that is encoded as more than one token, tiktoken's decode returns a � replacement character - remove it + return value.replace("�", ""); + } + return value; +} diff --git a/packages/core/src/embeddings/types.ts b/packages/core/src/embeddings/types.ts index c0a237854006d068fe9a78964aaa69240270f395..5f622c5b8120c063924bfd31bc5ff98966fbd7e5 100644 --- a/packages/core/src/embeddings/types.ts +++ b/packages/core/src/embeddings/types.ts @@ -1,16 +1,25 @@ +import { type Tokenizers } from "@llamaindex/env"; import type { BaseNode } from "../Node.js"; import { MetadataMode } from "../Node.js"; import type { TransformComponent } from "../ingestion/types.js"; import type { MessageContentDetail } from "../llm/types.js"; import { extractSingleText } from "../llm/utils.js"; +import { truncateMaxTokens } from "./tokenizer.js"; import { SimilarityType, similarity } from "./utils.js"; const DEFAULT_EMBED_BATCH_SIZE = 10; type EmbedFunc<T> = (values: T[]) => Promise<Array<number[]>>; +export type EmbeddingInfo = { + dimensions?: number; + maxTokens?: number; + tokenizer?: Tokenizers; +}; + export abstract class BaseEmbedding implements TransformComponent { embedBatchSize = DEFAULT_EMBED_BATCH_SIZE; + embedInfo?: EmbeddingInfo; similarity( embedding1: number[], @@ -77,6 +86,18 @@ return nodes; } + + truncateMaxTokens(input: string[]): string[] { + return input.map((s) => { + // truncate to max tokens + if (!(this.embedInfo?.tokenizer && this.embedInfo?.maxTokens)) return s; + return truncateMaxTokens( + this.embedInfo.tokenizer, + s, + this.embedInfo.maxTokens, + ); + }); + } } export async function batchEmbeddings<T>( diff --git a/packages/core/src/index.edge.ts b/packages/core/src/index.edge.ts index 4c2d806a564b1cf464919ad2d07178c329a53314..37a9f0993e1877b45874ced6b323ebea2e9a7341 100644 --- a/packages/core/src/index.edge.ts 
+++ b/packages/core/src/index.edge.ts @@ -1,5 +1,4 @@ export * from "./ChatHistory.js"; -export * from "./GlobalsHelper.js"; export * from "./Node.js"; export * from "./OutputParser.js"; export * from "./Prompt.js"; diff --git a/packages/core/src/llm/openai.ts b/packages/core/src/llm/openai.ts index 2bc3604b4ae3dae26f5871545425604ac6bf107b..e46df173618caae9eb97e53da8e37643b2e08625 100644 --- a/packages/core/src/llm/openai.ts +++ b/packages/core/src/llm/openai.ts @@ -7,6 +7,7 @@ import type { } from "openai"; import { AzureOpenAI, OpenAI as OrigOpenAI } from "openai"; +import { Tokenizers } from "@llamaindex/env"; import type { ChatCompletionAssistantMessageParam, ChatCompletionMessageToolCall, @@ -17,7 +18,6 @@ import type { ChatCompletionUserMessageParam, } from "openai/resources/chat/completions"; import type { ChatCompletionMessageParam } from "openai/resources/index.js"; -import { Tokenizers } from "../GlobalsHelper.js"; import { wrapEventCaller } from "../internal/context/EventCaller.js"; import { getCallbackManager } from "../internal/settings/CallbackManager.js"; import type { BaseTool } from "../types.js"; diff --git a/packages/core/src/llm/types.ts b/packages/core/src/llm/types.ts index fec84cf3168c2745713b9d144b5f18b3ed2333bc..e5f54a07d821f4f6ad71c7adb6e886fd879ed206 100644 --- a/packages/core/src/llm/types.ts +++ b/packages/core/src/llm/types.ts @@ -1,4 +1,4 @@ -import type { Tokenizers } from "../GlobalsHelper.js"; +import type { Tokenizers } from "@llamaindex/env"; import type { NodeWithScore } from "../Node.js"; import type { BaseEvent } from "../internal/type.js"; import type { BaseTool, JSONObject, ToolOutput, UUID } from "../types.js"; diff --git a/packages/core/tests/embeddings/tokenizer.test.ts b/packages/core/tests/embeddings/tokenizer.test.ts new file mode 100644 index 0000000000000000000000000000000000000000..1edf50faad3c491269dfaa23c88c6f0a21d970b7 --- /dev/null +++ b/packages/core/tests/embeddings/tokenizer.test.ts @@ -0,0 +1,29 @@ +import { Tokenizers, tokenizers } from "@llamaindex/env"; +import { describe, expect, test } from "vitest"; +import { truncateMaxTokens } from "../../src/embeddings/tokenizer.js"; + +describe("truncateMaxTokens", () => { + const tokenizer = tokenizers.tokenizer(Tokenizers.CL100K_BASE); + + test("should not truncate if less or equal to max tokens", () => { + const text = "Hello".repeat(40); + const t = truncateMaxTokens(Tokenizers.CL100K_BASE, text, 40); + expect(t.length).toEqual(text.length); + }); + + test("should truncate if more than max tokens", () => { + const text = "Hello".repeat(40); + const t = truncateMaxTokens(Tokenizers.CL100K_BASE, text, 20); + expect(tokenizer.encode(t).length).toBe(20); + }); + + test("should work with UTF8-boundaries", () => { + // "爨" has two tokens in CL100K_BASE + const text = "爨".repeat(40); + // truncate at utf-8 boundary + const t = truncateMaxTokens(Tokenizers.CL100K_BASE, text, 39); + // has to remove one token to keep the boundary + expect(tokenizer.encode(t).length).toBe(38); + expect(t.includes("�")).toBe(false); + }); +}); diff --git a/packages/env/package.json b/packages/env/package.json index 9938ca6ee80578e922b199b108fe13e387775976..cefe4001092a07d3a52be9d3b6f8c7746d2b34a9 100644 --- a/packages/env/package.json +++ b/packages/env/package.json @@ -80,7 +80,9 @@ }, "peerDependencies": { "@aws-crypto/sha256-js": "^5.2.0", - "pathe": "^1.1.2" + "pathe": "^1.1.2", + "js-tiktoken": "^1.0.12", + "tiktoken": "^1.0.15" }, "peerDependenciesMeta": { "@aws-crypto/sha256-js": { diff --git 
a/packages/env/src/index.edge-light.ts b/packages/env/src/index.edge-light.ts index 3d08a3cdfea5a6fb8aa680addf2724545268d183..7f22a82daac62cabda4802f8ecf385d0218282ca 100644 --- a/packages/env/src/index.edge-light.ts +++ b/packages/env/src/index.edge-light.ts @@ -4,3 +4,5 @@ * @module */ export * from "./polyfill.js"; + +export { Tokenizers, tokenizers, type Tokenizer } from "./tokenizers/js.js"; diff --git a/packages/env/src/index.ts b/packages/env/src/index.ts index 672182904e38ef5ebcd2c8161e5146c8b36e7217..a572741cd9c3e0c1f369c84a622b5a679c48c46f 100644 --- a/packages/env/src/index.ts +++ b/packages/env/src/index.ts @@ -35,6 +35,7 @@ export function createSHA256(): SHA256 { }; } +export { Tokenizers, tokenizers, type Tokenizer } from "./tokenizers/node.js"; export { AsyncLocalStorage, CustomEvent, getEnv, setEnvs } from "./utils.js"; export { EOL, diff --git a/packages/env/src/index.workerd.ts b/packages/env/src/index.workerd.ts index 02aceb29bf7f6b11509e728f14ddbcb1b3c853ac..1b0d683862ae150bc646dcd79552826f85f2d9db 100644 --- a/packages/env/src/index.workerd.ts +++ b/packages/env/src/index.workerd.ts @@ -12,3 +12,5 @@ export * from "./polyfill.js"; export function getEnv(name: string): string | undefined { return INTERNAL_ENV[name]; } + +export { Tokenizers, tokenizers, type Tokenizer } from "./tokenizers/node.js"; diff --git a/packages/env/src/tokenizers/js.ts b/packages/env/src/tokenizers/js.ts new file mode 100644 index 0000000000000000000000000000000000000000..c17860fab6e1e902b81e3596a33430c657b44603 --- /dev/null +++ b/packages/env/src/tokenizers/js.ts @@ -0,0 +1,35 @@ +// Note: js-tiktoken is 60x slower than the WASM implementation - use it only for unsupported environments +import { getEncoding } from "js-tiktoken"; +import type { Tokenizer } from "./types.js"; +import { Tokenizers } from "./types.js"; + +class TokenizerSingleton { + private defaultTokenizer: Tokenizer; + + constructor() { + const encoding = getEncoding("cl100k_base"); + + this.defaultTokenizer = { + encode: (text: string) => { + return new Uint32Array(encoding.encode(text)); + }, + decode: (tokens: Uint32Array) => { + const numberArray = Array.from(tokens); + const text = encoding.decode(numberArray); + const uint8Array = new TextEncoder().encode(text); + return new TextDecoder().decode(uint8Array); + }, + }; + } + + tokenizer(encoding?: Tokenizers) { + if (encoding && encoding !== Tokenizers.CL100K_BASE) { + throw new Error(`Tokenizer encoding ${encoding} not yet supported`); + } + + return this.defaultTokenizer; + } +} + +export const tokenizers = new TokenizerSingleton(); +export { Tokenizers, type Tokenizer }; diff --git a/packages/env/src/tokenizers/node.ts b/packages/env/src/tokenizers/node.ts new file mode 100644 index 0000000000000000000000000000000000000000..3f7b3207c3d14a46bf61139ce0e8386d45a52597 --- /dev/null +++ b/packages/env/src/tokenizers/node.ts @@ -0,0 +1,38 @@ +// Note: This is using the WASM implementation of tiktoken, which is 60x faster +import cl100k_base from "tiktoken/encoders/cl100k_base.json"; +import { Tiktoken } from "tiktoken/lite"; +import type { Tokenizer } from "./types.js"; +import { Tokenizers } from "./types.js"; + +class TokenizerSingleton { + private defaultTokenizer: Tokenizer; + + constructor() { + const encoding = new Tiktoken( + cl100k_base.bpe_ranks, + cl100k_base.special_tokens, + cl100k_base.pat_str, + ); + + this.defaultTokenizer = { + encode: (text: string) => { + return encoding.encode(text); + }, + decode: (tokens: Uint32Array) => { + const text = 
encoding.decode(tokens); + return new TextDecoder().decode(text); + }, + }; + } + + tokenizer(encoding?: Tokenizers) { + if (encoding && encoding !== Tokenizers.CL100K_BASE) { + throw new Error(`Tokenizer encoding ${encoding} not yet supported`); + } + + return this.defaultTokenizer; + } +} + +export const tokenizers = new TokenizerSingleton(); +export { Tokenizers, type Tokenizer }; diff --git a/packages/env/src/tokenizers/types.ts b/packages/env/src/tokenizers/types.ts new file mode 100644 index 0000000000000000000000000000000000000000..b884df74acb53ccd6eab91c1d08eea45ecf8726b --- /dev/null +++ b/packages/env/src/tokenizers/types.ts @@ -0,0 +1,8 @@ +export enum Tokenizers { + CL100K_BASE = "cl100k_base", +} + +export interface Tokenizer { + encode: (text: string) => Uint32Array; + decode: (tokens: Uint32Array) => string; +} diff --git a/packages/env/tsconfig.json b/packages/env/tsconfig.json index 616056b31e46e894257e47ba19d60c1a379dca93..a7628c0ee5e06525fe9ccf2772d308483418008e 100644 --- a/packages/env/tsconfig.json +++ b/packages/env/tsconfig.json @@ -7,7 +7,8 @@ "emitDeclarationOnly": true, "module": "node16", "moduleResolution": "node16", - "types": ["node"] + "types": ["node"], + "resolveJsonModule": true }, "include": ["./src"], "exclude": ["node_modules"] diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index c8dcb037f164f5c6590b40b0dcfad11a0ebf97b9..5828373ce026b5b61d631d9b119613fedcafcdb4 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -467,6 +467,9 @@ importers: string-strip-html: specifier: ^13.4.8 version: 13.4.8 + tiktoken: + specifier: ^1.0.15 + version: 1.0.15 unpdf: specifier: ^0.10.1 version: 0.10.1(encoding@0.1.13) @@ -664,6 +667,12 @@ importers: '@types/node': specifier: ^20.12.11 version: 20.12.11 + js-tiktoken: + specifier: ^1.0.12 + version: 1.0.12 + tiktoken: + specifier: ^1.0.15 + version: 1.0.15 devDependencies: '@aws-crypto/sha256-js': specifier: ^5.2.0 @@ -9703,6 +9712,9 @@ packages: thunky@1.1.0: resolution: {integrity: sha512-eHY7nBftgThBqOyHGVN+l8gF0BucP09fMo0oO/Lb0w1OF80dJv+lDVpXG60WMQvkcxAkNybKsrEIE3ZtKGmPrA==} + tiktoken@1.0.15: + resolution: {integrity: sha512-sCsrq/vMWUSEW29CJLNmPvWxlVp7yh2tlkAjpJltIKqp5CKf98ZNpdeHRmAlPVFlGEbswDc6SmI8vz64W/qErw==} + tiny-invariant@1.3.3: resolution: {integrity: sha512-+FbBPE1o9QAYvviau/qC5SE3caw21q3xkvWKBtja5vgqOWIHHJ3ioaq1VPfn/Szqctz2bU/oYeKd9/z5BL+PVg==} @@ -15675,7 +15687,7 @@ snapshots: chokidar@3.6.0: dependencies: anymatch: 3.1.3 - braces: 3.0.2 + braces: 3.0.3 glob-parent: 5.1.2 is-binary-path: 2.1.0 is-glob: 4.0.3 @@ -21795,6 +21807,8 @@ snapshots: thunky@1.1.0: {} + tiktoken@1.0.15: {} + tiny-invariant@1.3.3: {} tiny-warning@1.0.3: {}
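
For reference, a minimal usage sketch of the truncation helper introduced above. It is illustrative only: the relative import path mirrors the one used by the new unit test (the helper is not necessarily part of the public `llamaindex` export), and the 8192-token limit matches the `maxTokens` declared for the OpenAI embedding models in this diff.

```ts
import { Tokenizers, tokenizers } from "@llamaindex/env";
// Import path mirrors packages/core/tests/embeddings/tokenizer.test.ts; adjust to your setup.
import { truncateMaxTokens } from "../../src/embeddings/tokenizer.js";

const tokenizer = tokenizers.tokenizer(Tokenizers.CL100K_BASE);

// An input that is well over the 8192-token limit of the OpenAI embedding models.
const longText = "Hello world. ".repeat(4000);
console.log(tokenizer.encode(longText).length); // > 8192

// Truncate to the model's token budget before sending it to the embeddings API.
const truncated = truncateMaxTokens(Tokenizers.CL100K_BASE, longText, 8192);
console.log(tokenizer.encode(truncated).length); // <= 8192
```

`OpenAIEmbedding` applies the same truncation automatically: `getOpenAIEmbedding` calls `BaseEmbedding.truncateMaxTokens`, which is a no-op unless `embedInfo` provides both a `tokenizer` and a `maxTokens` value.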