diff --git a/examples/package.json b/examples/package.json
index 0a0570e74f6cf7306751031e7a54f3221b3326f0..dc78593516f0460b0b73c999e4da9bea54fad204 100644
--- a/examples/package.json
+++ b/examples/package.json
@@ -11,6 +11,7 @@
     "chromadb": "^1.8.1",
     "commander": "^11.1.0",
     "dotenv": "^16.4.1",
+    "js-tiktoken": "^1.0.10",
     "llamaindex": "latest",
     "mongodb": "^6.2.0",
     "pathe": "^1.1.2"
diff --git a/examples/recipes/cost-analysis.ts b/examples/recipes/cost-analysis.ts
index 9e5e5f999eb8d33c320a5b3caa516bfce9150de1..c27a8c6c8949a99d6fc04249a6b4c89f57fde084 100644
--- a/examples/recipes/cost-analysis.ts
+++ b/examples/recipes/cost-analysis.ts
@@ -1,6 +1,9 @@
+import { encodingForModel } from "js-tiktoken";
 import { OpenAI } from "llamaindex";
 import { Settings } from "llamaindex/Settings";
 
+const encoding = encodingForModel("gpt-4-0125-preview");
+
 const llm = new OpenAI({
   model: "gpt-4-0125-preview",
 });
@@ -9,11 +12,21 @@ let tokenCount = 0;
 
 Settings.callbackManager.on("llm-start", (event) => {
   const { messages } = event.detail.payload;
-  tokenCount += llm.tokens(messages);
+  tokenCount += messages.reduce((count, message) => {
+    return count + encoding.encode(message.content).length;
+  }, 0);
   console.log("Token count:", tokenCount);
   // https://openai.com/pricing
   // $10.00 / 1M tokens
-  console.log(`Price: $${(tokenCount / 1000000) * 10}`);
+  console.log(`Price: $${(tokenCount / 1_000_000) * 10}`);
+});
+Settings.callbackManager.on("llm-end", (event) => {
+  const { response } = event.detail.payload;
+  tokenCount += encoding.encode(response.message.content).length;
+  console.log("Token count:", tokenCount);
+  // https://openai.com/pricing
+  // $30.00 / 1M tokens
+  console.log(`Price: $${(tokenCount / 1_000_000) * 30}`);
 });
 
 const question = "Hello, how are you?";
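
The recipe above now estimates token usage client-side with js-tiktoken instead of the removed `llm.tokens()` helper. Note that it keeps a single running `tokenCount` and prices the combined total at the output rate in the `llm-end` handler; tracking prompt and completion tokens separately matches OpenAI's split pricing more closely ($10 / 1M input tokens and $30 / 1M output tokens for gpt-4-0125-preview at the time of writing). A minimal sketch of that variant, using only the js-tiktoken API this PR already adds; the helper names are illustrative and not part of the change:

```ts
import { encodingForModel } from "js-tiktoken";

const enc = encodingForModel("gpt-4-0125-preview");

let promptTokens = 0;
let completionTokens = 0;

// Count every message sent to the model at the input rate.
export function addPromptMessages(messages: { content: string }[]) {
  for (const message of messages) {
    promptTokens += enc.encode(message.content).length;
  }
}

// Count the model's reply at the output rate.
export function addCompletion(text: string) {
  completionTokens += enc.encode(text).length;
}

// Assumed gpt-4-0125-preview pricing: $10 / 1M input, $30 / 1M output.
export function estimatedCostUSD(): number {
  return (promptTokens / 1_000_000) * 10 + (completionTokens / 1_000_000) * 30;
}
```
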
diff --git a/packages/core/src/ChatHistory.ts b/packages/core/src/ChatHistory.ts
index ff8155d9c873242ffbef4f87397c40a92f78c266..d3f2615159273d0c20a98121ad38efea8557ea2d 100644
--- a/packages/core/src/ChatHistory.ts
+++ b/packages/core/src/ChatHistory.ts
@@ -1,7 +1,8 @@
-import { OpenAI } from "./llm/LLM.js";
-import type { ChatMessage, LLM, MessageType } from "./llm/types.js";
+import { globalsHelper } from "./GlobalsHelper.js";
 import type { SummaryPrompt } from "./Prompt.js";
 import { defaultSummaryPrompt, messagesToHistoryStr } from "./Prompt.js";
+import { OpenAI } from "./llm/LLM.js";
+import type { ChatMessage, LLM, MessageType } from "./llm/types.js";
 
 /**
  * A ChatHistory is used to keep the state of back and forth chat messages
@@ -62,6 +63,12 @@ export class SimpleChatHistory extends ChatHistory {
 }
 
 export class SummaryChatHistory extends ChatHistory {
+  /**
+   * Tokenizer function that converts text to tokens;
+   * used to calculate the number of tokens in a message.
+   */
+  tokenizer: (text: string) => Uint32Array =
+    globalsHelper.defaultTokenizer.encode;
   tokensToSummarize: number;
   messages: ChatMessage[];
   summaryPrompt: SummaryPrompt;
@@ -104,7 +111,9 @@
       ];
       // remove oldest message until the chat history is short enough for the context window
       messagesToSummarize.shift();
-    } while (this.llm.tokens(promptMessages) > this.tokensToSummarize);
+    } while (
+      this.tokenizer(promptMessages[0].content).length > this.tokensToSummarize
+    );
 
     const response = await this.llm.chat({ messages: promptMessages });
     return { content: response.message.content, role: "memory" };
@@ -178,7 +187,10 @@
     const requestMessages = this.calcCurrentRequestMessages(transientMessages);
 
     // get tokens of current request messages and the transient messages
-    const tokens = this.llm.tokens(requestMessages);
+    const tokens = requestMessages.reduce(
+      (count, message) => count + this.tokenizer(message.content).length,
+      0,
+    );
     if (tokens > this.tokensToSummarize) {
       // if there are too many tokens for the next request, call summarize
       const memoryMessage = await this.summarize();
diff --git a/packages/core/src/GlobalsHelper.ts b/packages/core/src/GlobalsHelper.ts
index c41c68753f61d2aa40db880da8cc0657b5d8d898..1051f9b07f79b56566818bc64ce03695a7c02a1c 100644
--- a/packages/core/src/GlobalsHelper.ts
+++ b/packages/core/src/GlobalsHelper.ts
@@ -18,9 +18,9 @@ class GlobalsHelper {
   defaultTokenizer: {
     encode: (text: string) => Uint32Array;
     decode: (tokens: Uint32Array) => string;
-  } | null = null;
+  };
 
-  private initDefaultTokenizer() {
+  constructor() {
     const encoding = encodingForModel("text-embedding-ada-002"); // cl100k_base
 
     this.defaultTokenizer = {
@@ -40,9 +40,6 @@
     if (encoding && encoding !== Tokenizers.CL100K_BASE) {
      throw new Error(`Tokenizer encoding ${encoding} not yet supported`);
     }
-    if (!this.defaultTokenizer) {
-      this.initDefaultTokenizer();
-    }
 
     return this.defaultTokenizer!.encode.bind(this.defaultTokenizer);
   }
@@ -51,9 +48,6 @@
     if (encoding && encoding !== Tokenizers.CL100K_BASE) {
       throw new Error(`Tokenizer encoding ${encoding} not yet supported`);
     }
-    if (!this.defaultTokenizer) {
-      this.initDefaultTokenizer();
-    }
 
     return this.defaultTokenizer!.decode.bind(this.defaultTokenizer);
   }
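
With `GlobalsHelper` now building its cl100k_base tokenizer eagerly in the constructor, `SummaryChatHistory` can default its new `tokenizer` field to `globalsHelper.defaultTokenizer.encode` and count tokens without going through the LLM. The check it performs is essentially the reduce shown in the hunks above; here is a standalone sketch of the same idea, using js-tiktoken directly as a stand-in for `globalsHelper.defaultTokenizer` (the `ChatMessageLike` type and helper names are illustrative, not part of this PR):

```ts
import { getEncoding } from "js-tiktoken";

// cl100k_base is the encoding behind text-embedding-ada-002 and the
// GPT-3.5/GPT-4 chat models, matching GlobalsHelper's default tokenizer.
const cl100k = getEncoding("cl100k_base");

type ChatMessageLike = { role: string; content: string };

// Sum the encoded length of every message body, as SummaryChatHistory now
// does in place of the removed llm.tokens().
function countTokens(messages: ChatMessageLike[]): number {
  return messages.reduce(
    (count, message) => count + cl100k.encode(message.content).length,
    0,
  );
}

// Example: trigger a summarization pass once the history grows too large.
function needsSummary(
  messages: ChatMessageLike[],
  tokensToSummarize: number,
): boolean {
  return countTokens(messages) > tokensToSummarize;
}
```
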
diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts
index 8d20169c303aa6563fca21b6543df35fc5b6cbd9..2f18ecc6245aeabcd2bdb01d9b46e728b95dda5f 100644
--- a/packages/core/src/llm/LLM.ts
+++ b/packages/core/src/llm/LLM.ts
@@ -9,7 +9,7 @@ import {
 import type { ChatCompletionMessageParam } from "openai/resources/index.js";
 import type { LLMOptions } from "portkey-ai";
 
-import { Tokenizers, globalsHelper } from "../GlobalsHelper.js";
+import { Tokenizers } from "../GlobalsHelper.js";
 import { getCallbackManager } from "../internal/settings/CallbackManager.js";
 import type { AnthropicSession } from "./anthropic.js";
 import { getAnthropicSession } from "./anthropic.js";
@@ -171,21 +171,6 @@
     };
   }
 
-  tokens(messages: ChatMessage[]): number {
-    // for latest OpenAI models, see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
-    const tokenizer = globalsHelper.tokenizer(this.metadata.tokenizer);
-    const tokensPerMessage = 3;
-    let numTokens = 0;
-    for (const message of messages) {
-      numTokens += tokensPerMessage;
-      for (const value of Object.values(message)) {
-        numTokens += tokenizer(value).length;
-      }
-    }
-    numTokens += 3; // every reply is primed with <|im_start|>assistant<|im_sep|>
-    return numTokens;
-  }
-
   mapMessageType(
     messageType: MessageType,
   ): "user" | "assistant" | "system" | "function" | "tool" {
@@ -414,10 +399,6 @@ export class LlamaDeuce extends BaseLLM {
     this.replicateSession = init?.replicateSession ?? new ReplicateSession();
   }
 
-  tokens(messages: ChatMessage[]): number {
-    throw new Error("Method not implemented.");
-  }
-
   get metadata() {
     return {
       model: this.model,
@@ -667,10 +648,6 @@ export class Anthropic extends BaseLLM {
     });
   }
 
-  tokens(messages: ChatMessage[]): number {
-    throw new Error("Method not implemented.");
-  }
-
   get metadata() {
     return {
       model: this.model,
@@ -797,10 +774,6 @@ export class Portkey extends BaseLLM {
     });
   }
 
-  tokens(messages: ChatMessage[]): number {
-    throw new Error("Method not implemented.");
-  }
-
   get metadata(): LLMMetadata {
     throw new Error("metadata not implemented for Portkey");
   }
diff --git a/packages/core/src/llm/base.ts b/packages/core/src/llm/base.ts
index 36b62b6c77deb2a9e92acf420fe66ea158e4b76a..6dce23cb8c0817e876a483000bb7277127ab3580 100644
--- a/packages/core/src/llm/base.ts
+++ b/packages/core/src/llm/base.ts
@@ -1,5 +1,4 @@
 import type {
-  ChatMessage,
   ChatResponse,
   ChatResponseChunk,
   CompletionResponse,
@@ -48,6 +47,4 @@ export abstract class BaseLLM implements LLM {
     params: LLMChatParamsStreaming,
   ): Promise<AsyncIterable<ChatResponseChunk>>;
   abstract chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>;
-
-  abstract tokens(messages: ChatMessage[]): number;
 }
diff --git a/packages/core/src/llm/mistral.ts b/packages/core/src/llm/mistral.ts
index 7ee798030d1e85b8e98e6d3a8ffb626edb7b2a80..ce3fd5b0eef0a41406e2ec5e8821d2fec9062e64 100644
--- a/packages/core/src/llm/mistral.ts
+++ b/packages/core/src/llm/mistral.ts
@@ -81,10 +81,6 @@
     };
   }
 
-  tokens(messages: ChatMessage[]): number {
-    throw new Error("Method not implemented.");
-  }
-
   private buildParams(messages: ChatMessage[]): any {
     return {
       model: this.model,
diff --git a/packages/core/src/llm/ollama.ts b/packages/core/src/llm/ollama.ts
index 12fdbc9285bb17b0062693fd04c0ada8fd1b6570..5b7f39e7c4780bf13b387f5c3739add330fee205 100644
--- a/packages/core/src/llm/ollama.ts
+++ b/packages/core/src/llm/ollama.ts
@@ -2,7 +2,6 @@ import { ok } from "@llamaindex/env";
 import type { Event } from "../callbacks/CallbackManager.js";
 import { BaseEmbedding } from "../embeddings/types.js";
 import type {
-  ChatMessage,
   ChatResponse,
   ChatResponseChunk,
   CompletionResponse,
@@ -182,10 +181,6 @@ export class Ollama extends BaseEmbedding implements LLM {
     }
   }
 
-  tokens(messages: ChatMessage[]): number {
-    throw new Error("Method not implemented.");
-  }
-
   private async getEmbedding(prompt: string): Promise<number[]> {
     const payload = {
       model: this.model,
diff --git a/packages/core/src/llm/types.ts b/packages/core/src/llm/types.ts
index 7c52bcc8b3d9fbcccc05c3490015a60d0b7cc40c..7bd2ebad5541bb76a0816b6c66e1c56da7615707 100644
--- a/packages/core/src/llm/types.ts
+++ b/packages/core/src/llm/types.ts
@@ -59,11 +59,6 @@ export interface LLM extends LLMChat {
   complete(
     params: LLMCompletionParamsNonStreaming,
   ): Promise<CompletionResponse>;
-
-  /**
-   * Calculates the number of tokens needed for the given chat messages
-   */
-  tokens(messages: ChatMessage[]): number;
 }
 
 export type MessageType =
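
With `tokens()` removed from the `LLM` interface, `BaseLLM`, and every implementation, callers that still need OpenAI-style chat token accounting have to do it themselves. A rough user-side equivalent of the removed `OpenAI.tokens()` built directly on js-tiktoken is sketched below; the per-message overhead follows the OpenAI cookbook's tiktoken example and is only an estimate, and the function name is illustrative rather than part of this change:

```ts
import { encodingForModel } from "js-tiktoken";

type ChatMessageLike = { role: string; content: string };

// Approximates the removed OpenAI.tokens(): every message carries a small
// fixed overhead, plus its encoded role and content, plus a reply primer.
export function countChatTokens(messages: ChatMessageLike[]): number {
  const enc = encodingForModel("gpt-4-0125-preview");
  const tokensPerMessage = 3;
  let numTokens = 0;
  for (const message of messages) {
    numTokens += tokensPerMessage;
    numTokens += enc.encode(message.role).length;
    numTokens += enc.encode(message.content).length;
  }
  numTokens += 3; // every reply is primed with <|im_start|>assistant<|im_sep|>
  return numTokens;
}
```
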
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 4b69ec290f33e90ad56049f36ee9b96cb330bed9..1e5b389f883585ddbbcd031842ab4fbe17eda362 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -130,6 +130,9 @@ importers:
       dotenv:
         specifier: ^16.4.1
         version: 16.4.1
+      js-tiktoken:
+        specifier: ^1.0.10
+        version: 1.0.10
       llamaindex:
         specifier: latest
         version: link:../packages/core