From aa41432bbb3d50b4ecbde8c6f06a81c2d1cc1c2f Mon Sep 17 00:00:00 2001 From: Alex Yang <himself65@outlook.com> Date: Mon, 1 Apr 2024 14:12:17 -0500 Subject: [PATCH] refactor: remove `llm.tokens` api (#679) --- examples/package.json | 1 + examples/recipes/cost-analysis.ts | 17 +++++++++++++++-- packages/core/src/ChatHistory.ts | 20 ++++++++++++++++---- packages/core/src/GlobalsHelper.ts | 10 ++-------- packages/core/src/llm/LLM.ts | 29 +---------------------------- packages/core/src/llm/base.ts | 3 --- packages/core/src/llm/mistral.ts | 4 ---- packages/core/src/llm/ollama.ts | 5 ----- packages/core/src/llm/types.ts | 5 ----- pnpm-lock.yaml | 3 +++ 10 files changed, 38 insertions(+), 59 deletions(-) diff --git a/examples/package.json b/examples/package.json index 0a0570e74..dc7859351 100644 --- a/examples/package.json +++ b/examples/package.json @@ -11,6 +11,7 @@ "chromadb": "^1.8.1", "commander": "^11.1.0", "dotenv": "^16.4.1", + "js-tiktoken": "^1.0.10", "llamaindex": "latest", "mongodb": "^6.2.0", "pathe": "^1.1.2" diff --git a/examples/recipes/cost-analysis.ts b/examples/recipes/cost-analysis.ts index 9e5e5f999..c27a8c6c8 100644 --- a/examples/recipes/cost-analysis.ts +++ b/examples/recipes/cost-analysis.ts @@ -1,6 +1,9 @@ +import { encodingForModel } from "js-tiktoken"; import { OpenAI } from "llamaindex"; import { Settings } from "llamaindex/Settings"; +const encoding = encodingForModel("gpt-4-0125-preview"); + const llm = new OpenAI({ model: "gpt-4-0125-preview", }); @@ -9,11 +12,21 @@ let tokenCount = 0; Settings.callbackManager.on("llm-start", (event) => { const { messages } = event.detail.payload; - tokenCount += llm.tokens(messages); + tokenCount += messages.reduce((count, message) => { + return count + encoding.encode(message.content).length; + }, 0); console.log("Token count:", tokenCount); // https://openai.com/pricing // $10.00 / 1M tokens - console.log(`Price: $${(tokenCount / 1000000) * 10}`); + console.log(`Price: $${(tokenCount / 1_000_000) * 10}`); +}); +Settings.callbackManager.on("llm-end", (event) => { + const { response } = event.detail.payload; + tokenCount += encoding.encode(response.message.content).length; + console.log("Token count:", tokenCount); + // https://openai.com/pricing + // $30.00 / 1M tokens + console.log(`Price: $${(tokenCount / 1_000_000) * 30}`); }); const question = "Hello, how are you?"; diff --git a/packages/core/src/ChatHistory.ts b/packages/core/src/ChatHistory.ts index ff8155d9c..d3f261515 100644 --- a/packages/core/src/ChatHistory.ts +++ b/packages/core/src/ChatHistory.ts @@ -1,7 +1,8 @@ -import { OpenAI } from "./llm/LLM.js"; -import type { ChatMessage, LLM, MessageType } from "./llm/types.js"; +import { globalsHelper } from "./GlobalsHelper.js"; import type { SummaryPrompt } from "./Prompt.js"; import { defaultSummaryPrompt, messagesToHistoryStr } from "./Prompt.js"; +import { OpenAI } from "./llm/LLM.js"; +import type { ChatMessage, LLM, MessageType } from "./llm/types.js"; /** * A ChatHistory is used to keep the state of back and forth chat messages @@ -62,6 +63,12 @@ export class SimpleChatHistory extends ChatHistory { } export class SummaryChatHistory extends ChatHistory { + /** + * Tokenizer function that converts text to tokens, + * this is used to calculate the number of tokens in a message. 
+ */ + tokenizer: (text: string) => Uint32Array = + globalsHelper.defaultTokenizer.encode; tokensToSummarize: number; messages: ChatMessage[]; summaryPrompt: SummaryPrompt; @@ -104,7 +111,9 @@ export class SummaryChatHistory extends ChatHistory { ]; // remove oldest message until the chat history is short enough for the context window messagesToSummarize.shift(); - } while (this.llm.tokens(promptMessages) > this.tokensToSummarize); + } while ( + this.tokenizer(promptMessages[0].content).length > this.tokensToSummarize + ); const response = await this.llm.chat({ messages: promptMessages }); return { content: response.message.content, role: "memory" }; @@ -178,7 +187,10 @@ export class SummaryChatHistory extends ChatHistory { const requestMessages = this.calcCurrentRequestMessages(transientMessages); // get tokens of current request messages and the transient messages - const tokens = this.llm.tokens(requestMessages); + const tokens = requestMessages.reduce( + (count, message) => count + this.tokenizer(message.content).length, + 0, + ); if (tokens > this.tokensToSummarize) { // if there are too many tokens for the next request, call summarize const memoryMessage = await this.summarize(); diff --git a/packages/core/src/GlobalsHelper.ts b/packages/core/src/GlobalsHelper.ts index c41c68753..1051f9b07 100644 --- a/packages/core/src/GlobalsHelper.ts +++ b/packages/core/src/GlobalsHelper.ts @@ -18,9 +18,9 @@ class GlobalsHelper { defaultTokenizer: { encode: (text: string) => Uint32Array; decode: (tokens: Uint32Array) => string; - } | null = null; + }; - private initDefaultTokenizer() { + constructor() { const encoding = encodingForModel("text-embedding-ada-002"); // cl100k_base this.defaultTokenizer = { @@ -40,9 +40,6 @@ class GlobalsHelper { if (encoding && encoding !== Tokenizers.CL100K_BASE) { throw new Error(`Tokenizer encoding ${encoding} not yet supported`); } - if (!this.defaultTokenizer) { - this.initDefaultTokenizer(); - } return this.defaultTokenizer!.encode.bind(this.defaultTokenizer); } @@ -51,9 +48,6 @@ class GlobalsHelper { if (encoding && encoding !== Tokenizers.CL100K_BASE) { throw new Error(`Tokenizer encoding ${encoding} not yet supported`); } - if (!this.defaultTokenizer) { - this.initDefaultTokenizer(); - } return this.defaultTokenizer!.decode.bind(this.defaultTokenizer); } diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts index 8d20169c3..2f18ecc62 100644 --- a/packages/core/src/llm/LLM.ts +++ b/packages/core/src/llm/LLM.ts @@ -9,7 +9,7 @@ import { import type { ChatCompletionMessageParam } from "openai/resources/index.js"; import type { LLMOptions } from "portkey-ai"; -import { Tokenizers, globalsHelper } from "../GlobalsHelper.js"; +import { Tokenizers } from "../GlobalsHelper.js"; import { getCallbackManager } from "../internal/settings/CallbackManager.js"; import type { AnthropicSession } from "./anthropic.js"; import { getAnthropicSession } from "./anthropic.js"; @@ -171,21 +171,6 @@ export class OpenAI extends BaseLLM { }; } - tokens(messages: ChatMessage[]): number { - // for latest OpenAI models, see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb - const tokenizer = globalsHelper.tokenizer(this.metadata.tokenizer); - const tokensPerMessage = 3; - let numTokens = 0; - for (const message of messages) { - numTokens += tokensPerMessage; - for (const value of Object.values(message)) { - numTokens += tokenizer(value).length; - } - } - numTokens += 3; // every reply is primed with 
<|im_start|>assistant<|im_sep|> - return numTokens; - } - mapMessageType( messageType: MessageType, ): "user" | "assistant" | "system" | "function" | "tool" { @@ -414,10 +399,6 @@ export class LlamaDeuce extends BaseLLM { this.replicateSession = init?.replicateSession ?? new ReplicateSession(); } - tokens(messages: ChatMessage[]): number { - throw new Error("Method not implemented."); - } - get metadata() { return { model: this.model, @@ -667,10 +648,6 @@ export class Anthropic extends BaseLLM { }); } - tokens(messages: ChatMessage[]): number { - throw new Error("Method not implemented."); - } - get metadata() { return { model: this.model, @@ -797,10 +774,6 @@ export class Portkey extends BaseLLM { }); } - tokens(messages: ChatMessage[]): number { - throw new Error("Method not implemented."); - } - get metadata(): LLMMetadata { throw new Error("metadata not implemented for Portkey"); } diff --git a/packages/core/src/llm/base.ts b/packages/core/src/llm/base.ts index 36b62b6c7..6dce23cb8 100644 --- a/packages/core/src/llm/base.ts +++ b/packages/core/src/llm/base.ts @@ -1,5 +1,4 @@ import type { - ChatMessage, ChatResponse, ChatResponseChunk, CompletionResponse, @@ -48,6 +47,4 @@ export abstract class BaseLLM implements LLM { params: LLMChatParamsStreaming, ): Promise<AsyncIterable<ChatResponseChunk>>; abstract chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>; - - abstract tokens(messages: ChatMessage[]): number; } diff --git a/packages/core/src/llm/mistral.ts b/packages/core/src/llm/mistral.ts index 7ee798030..ce3fd5b0e 100644 --- a/packages/core/src/llm/mistral.ts +++ b/packages/core/src/llm/mistral.ts @@ -81,10 +81,6 @@ export class MistralAI extends BaseLLM { }; } - tokens(messages: ChatMessage[]): number { - throw new Error("Method not implemented."); - } - private buildParams(messages: ChatMessage[]): any { return { model: this.model, diff --git a/packages/core/src/llm/ollama.ts b/packages/core/src/llm/ollama.ts index 12fdbc928..5b7f39e7c 100644 --- a/packages/core/src/llm/ollama.ts +++ b/packages/core/src/llm/ollama.ts @@ -2,7 +2,6 @@ import { ok } from "@llamaindex/env"; import type { Event } from "../callbacks/CallbackManager.js"; import { BaseEmbedding } from "../embeddings/types.js"; import type { - ChatMessage, ChatResponse, ChatResponseChunk, CompletionResponse, @@ -182,10 +181,6 @@ export class Ollama extends BaseEmbedding implements LLM { } } - tokens(messages: ChatMessage[]): number { - throw new Error("Method not implemented."); - } - private async getEmbedding(prompt: string): Promise<number[]> { const payload = { model: this.model, diff --git a/packages/core/src/llm/types.ts b/packages/core/src/llm/types.ts index 7c52bcc8b..7bd2ebad5 100644 --- a/packages/core/src/llm/types.ts +++ b/packages/core/src/llm/types.ts @@ -59,11 +59,6 @@ export interface LLM extends LLMChat { complete( params: LLMCompletionParamsNonStreaming, ): Promise<CompletionResponse>; - - /** - * Calculates the number of tokens needed for the given chat messages - */ - tokens(messages: ChatMessage[]): number; } export type MessageType = diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 4b69ec290..1e5b389f8 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -130,6 +130,9 @@ importers: dotenv: specifier: ^16.4.1 version: 16.4.1 + js-tiktoken: + specifier: ^1.0.10 + version: 1.0.10 llamaindex: specifier: latest version: link:../packages/core -- GitLab
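
Migration sketch (not part of the patch above): with `llm.tokens` removed, token counting moves to the caller using `js-tiktoken`, as the updated `examples/recipes/cost-analysis.ts` hunk shows. The TypeScript below is a minimal, self-contained version of that approach; the `SimpleChatMessage` type and `countTokens` helper are hypothetical names used only for illustration, and it assumes message contents are plain strings and that the `gpt-4-0125-preview` encoding matches the model you are pricing.

// Minimal sketch: approximate prompt token counts with js-tiktoken
// after the removal of `llm.tokens`. Assumes string-only message
// contents and the gpt-4-0125-preview encoding (cl100k_base family).
import { encodingForModel } from "js-tiktoken";

// Hypothetical message shape for this sketch; the real ChatMessage type
// in llamaindex may allow non-string content.
interface SimpleChatMessage {
  role: "system" | "user" | "assistant";
  content: string;
}

const encoding = encodingForModel("gpt-4-0125-preview");

// Sum the encoded length of every message's content.
function countTokens(messages: SimpleChatMessage[]): number {
  return messages.reduce(
    (count, message) => count + encoding.encode(message.content).length,
    0,
  );
}

// Usage: estimate input cost at $10.00 / 1M tokens (https://openai.com/pricing).
const prompt: SimpleChatMessage[] = [
  { role: "system", content: "You are a helpful assistant." },
  { role: "user", content: "Hello, how are you?" },
];
const tokens = countTokens(prompt);
console.log(
  `Tokens: ${tokens}, estimated input price: $${(tokens / 1_000_000) * 10}`,
);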