From aa41432bbb3d50b4ecbde8c6f06a81c2d1cc1c2f Mon Sep 17 00:00:00 2001
From: Alex Yang <himself65@outlook.com>
Date: Mon, 1 Apr 2024 14:12:17 -0500
Subject: [PATCH] refactor: remove `llm.tokens` api (#679)

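`LLM.tokens()` was only ever implemented for OpenAI; every other provider
(Anthropic, LlamaDeuce, MistralAI, Ollama, Portkey) just threw
"Method not implemented.". This patch removes the method from the `LLM`
interface, `BaseLLM`, and all providers. `SummaryChatHistory` now tokenizes
prompt content itself via `globalsHelper.defaultTokenizer`, and the
cost-analysis example counts tokens directly with `js-tiktoken`.

Callers that relied on `llm.tokens(messages)` can do the same. Below is a
minimal migration sketch, assuming js-tiktoken is installed; the
`countTokens` helper and the message shape are illustrative, and the flat
sum skips the fixed overhead the removed OpenAI implementation added
(3 tokens per message plus 3 for priming the assistant reply):

    import { encodingForModel } from "js-tiktoken";

    // Use the encoding that matches the model you actually call.
    const encoding = encodingForModel("gpt-4-0125-preview");

    // Before: const count = llm.tokens(messages);
    // After: encode each message's text content and sum the lengths.
    function countTokens(messages: { content: string }[]): number {
      return messages.reduce(
        (count, message) => count + encoding.encode(message.content).length,
        0,
      );
    }

    console.log(countTokens([{ content: "Hello, how are you?" }]));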
---
 examples/package.json              |  1 +
 examples/recipes/cost-analysis.ts  | 17 +++++++++++++++--
 packages/core/src/ChatHistory.ts   | 20 ++++++++++++++++----
 packages/core/src/GlobalsHelper.ts | 10 ++--------
 packages/core/src/llm/LLM.ts       | 29 +----------------------------
 packages/core/src/llm/base.ts      |  3 ---
 packages/core/src/llm/mistral.ts   |  4 ----
 packages/core/src/llm/ollama.ts    |  5 -----
 packages/core/src/llm/types.ts     |  5 -----
 pnpm-lock.yaml                     |  3 +++
 10 files changed, 38 insertions(+), 59 deletions(-)

diff --git a/examples/package.json b/examples/package.json
index 0a0570e74..dc7859351 100644
--- a/examples/package.json
+++ b/examples/package.json
@@ -11,6 +11,7 @@
     "chromadb": "^1.8.1",
     "commander": "^11.1.0",
     "dotenv": "^16.4.1",
+    "js-tiktoken": "^1.0.10",
     "llamaindex": "latest",
     "mongodb": "^6.2.0",
     "pathe": "^1.1.2"
diff --git a/examples/recipes/cost-analysis.ts b/examples/recipes/cost-analysis.ts
index 9e5e5f999..c27a8c6c8 100644
--- a/examples/recipes/cost-analysis.ts
+++ b/examples/recipes/cost-analysis.ts
@@ -1,6 +1,9 @@
+import { encodingForModel } from "js-tiktoken";
 import { OpenAI } from "llamaindex";
 import { Settings } from "llamaindex/Settings";
 
+const encoding = encodingForModel("gpt-4-0125-preview");
+
 const llm = new OpenAI({
   model: "gpt-4-0125-preview",
 });
@@ -9,11 +12,21 @@ let tokenCount = 0;
 
 Settings.callbackManager.on("llm-start", (event) => {
   const { messages } = event.detail.payload;
-  tokenCount += llm.tokens(messages);
+  tokenCount += messages.reduce((count, message) => {
+    return count + encoding.encode(message.content).length;
+  }, 0);
   console.log("Token count:", tokenCount);
   // https://openai.com/pricing
   // $10.00 / 1M tokens
-  console.log(`Price: $${(tokenCount / 1000000) * 10}`);
+  console.log(`Price: $${(tokenCount / 1_000_000) * 10}`);
+});
+Settings.callbackManager.on("llm-end", (event) => {
+  const { response } = event.detail.payload;
+  tokenCount += encoding.encode(response.message.content).length;
+  console.log("Token count:", tokenCount);
+  // https://openai.com/pricing
+  // $30.00 / 1M tokens
+  console.log(`Price: $${(tokenCount / 1_000_000) * 30}`);
 });
 
 const question = "Hello, how are you?";
diff --git a/packages/core/src/ChatHistory.ts b/packages/core/src/ChatHistory.ts
index ff8155d9c..d3f261515 100644
--- a/packages/core/src/ChatHistory.ts
+++ b/packages/core/src/ChatHistory.ts
@@ -1,7 +1,8 @@
-import { OpenAI } from "./llm/LLM.js";
-import type { ChatMessage, LLM, MessageType } from "./llm/types.js";
+import { globalsHelper } from "./GlobalsHelper.js";
 import type { SummaryPrompt } from "./Prompt.js";
 import { defaultSummaryPrompt, messagesToHistoryStr } from "./Prompt.js";
+import { OpenAI } from "./llm/LLM.js";
+import type { ChatMessage, LLM, MessageType } from "./llm/types.js";
 
 /**
  * A ChatHistory is used to keep the state of back and forth chat messages
@@ -62,6 +63,12 @@ export class SimpleChatHistory extends ChatHistory {
 }
 
 export class SummaryChatHistory extends ChatHistory {
+  /**
+   * Tokenizer function that converts text to tokens;
+   *  it is used to calculate the number of tokens in a message.
+   */
+  tokenizer: (text: string) => Uint32Array =
+    globalsHelper.defaultTokenizer.encode;
   tokensToSummarize: number;
   messages: ChatMessage[];
   summaryPrompt: SummaryPrompt;
@@ -104,7 +111,9 @@ export class SummaryChatHistory extends ChatHistory {
       ];
       // remove oldest message until the chat history is short enough for the context window
       messagesToSummarize.shift();
-    } while (this.llm.tokens(promptMessages) > this.tokensToSummarize);
+    } while (
+      this.tokenizer(promptMessages[0].content).length > this.tokensToSummarize
+    );
 
     const response = await this.llm.chat({ messages: promptMessages });
     return { content: response.message.content, role: "memory" };
@@ -178,7 +187,10 @@ export class SummaryChatHistory extends ChatHistory {
     const requestMessages = this.calcCurrentRequestMessages(transientMessages);
 
     // get tokens of current request messages and the transient messages
-    const tokens = this.llm.tokens(requestMessages);
+    const tokens = requestMessages.reduce(
+      (count, message) => count + this.tokenizer(message.content).length,
+      0,
+    );
     if (tokens > this.tokensToSummarize) {
       // if there are too many tokens for the next request, call summarize
       const memoryMessage = await this.summarize();
diff --git a/packages/core/src/GlobalsHelper.ts b/packages/core/src/GlobalsHelper.ts
index c41c68753..1051f9b07 100644
--- a/packages/core/src/GlobalsHelper.ts
+++ b/packages/core/src/GlobalsHelper.ts
@@ -18,9 +18,9 @@ class GlobalsHelper {
   defaultTokenizer: {
     encode: (text: string) => Uint32Array;
     decode: (tokens: Uint32Array) => string;
-  } | null = null;
+  };
 
-  private initDefaultTokenizer() {
+  constructor() {
     const encoding = encodingForModel("text-embedding-ada-002"); // cl100k_base
 
     this.defaultTokenizer = {
@@ -40,9 +40,6 @@ class GlobalsHelper {
     if (encoding && encoding !== Tokenizers.CL100K_BASE) {
       throw new Error(`Tokenizer encoding ${encoding} not yet supported`);
     }
-    if (!this.defaultTokenizer) {
-      this.initDefaultTokenizer();
-    }
 
     return this.defaultTokenizer!.encode.bind(this.defaultTokenizer);
   }
@@ -51,9 +48,6 @@ class GlobalsHelper {
     if (encoding && encoding !== Tokenizers.CL100K_BASE) {
       throw new Error(`Tokenizer encoding ${encoding} not yet supported`);
     }
-    if (!this.defaultTokenizer) {
-      this.initDefaultTokenizer();
-    }
 
     return this.defaultTokenizer!.decode.bind(this.defaultTokenizer);
   }
diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts
index 8d20169c3..2f18ecc62 100644
--- a/packages/core/src/llm/LLM.ts
+++ b/packages/core/src/llm/LLM.ts
@@ -9,7 +9,7 @@ import {
 
 import type { ChatCompletionMessageParam } from "openai/resources/index.js";
 import type { LLMOptions } from "portkey-ai";
-import { Tokenizers, globalsHelper } from "../GlobalsHelper.js";
+import { Tokenizers } from "../GlobalsHelper.js";
 import { getCallbackManager } from "../internal/settings/CallbackManager.js";
 import type { AnthropicSession } from "./anthropic.js";
 import { getAnthropicSession } from "./anthropic.js";
@@ -171,21 +171,6 @@ export class OpenAI extends BaseLLM {
     };
   }
 
-  tokens(messages: ChatMessage[]): number {
-    // for latest OpenAI models, see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
-    const tokenizer = globalsHelper.tokenizer(this.metadata.tokenizer);
-    const tokensPerMessage = 3;
-    let numTokens = 0;
-    for (const message of messages) {
-      numTokens += tokensPerMessage;
-      for (const value of Object.values(message)) {
-        numTokens += tokenizer(value).length;
-      }
-    }
-    numTokens += 3; // every reply is primed with <|im_start|>assistant<|im_sep|>
-    return numTokens;
-  }
-
   mapMessageType(
     messageType: MessageType,
   ): "user" | "assistant" | "system" | "function" | "tool" {
@@ -414,10 +399,6 @@ export class LlamaDeuce extends BaseLLM {
     this.replicateSession = init?.replicateSession ?? new ReplicateSession();
   }
 
-  tokens(messages: ChatMessage[]): number {
-    throw new Error("Method not implemented.");
-  }
-
   get metadata() {
     return {
       model: this.model,
@@ -667,10 +648,6 @@ export class Anthropic extends BaseLLM {
       });
   }
 
-  tokens(messages: ChatMessage[]): number {
-    throw new Error("Method not implemented.");
-  }
-
   get metadata() {
     return {
       model: this.model,
@@ -797,10 +774,6 @@ export class Portkey extends BaseLLM {
     });
   }
 
-  tokens(messages: ChatMessage[]): number {
-    throw new Error("Method not implemented.");
-  }
-
   get metadata(): LLMMetadata {
     throw new Error("metadata not implemented for Portkey");
   }
diff --git a/packages/core/src/llm/base.ts b/packages/core/src/llm/base.ts
index 36b62b6c7..6dce23cb8 100644
--- a/packages/core/src/llm/base.ts
+++ b/packages/core/src/llm/base.ts
@@ -1,5 +1,4 @@
 import type {
-  ChatMessage,
   ChatResponse,
   ChatResponseChunk,
   CompletionResponse,
@@ -48,6 +47,4 @@ export abstract class BaseLLM implements LLM {
     params: LLMChatParamsStreaming,
   ): Promise<AsyncIterable<ChatResponseChunk>>;
   abstract chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>;
-
-  abstract tokens(messages: ChatMessage[]): number;
 }
diff --git a/packages/core/src/llm/mistral.ts b/packages/core/src/llm/mistral.ts
index 7ee798030..ce3fd5b0e 100644
--- a/packages/core/src/llm/mistral.ts
+++ b/packages/core/src/llm/mistral.ts
@@ -81,10 +81,6 @@ export class MistralAI extends BaseLLM {
     };
   }
 
-  tokens(messages: ChatMessage[]): number {
-    throw new Error("Method not implemented.");
-  }
-
   private buildParams(messages: ChatMessage[]): any {
     return {
       model: this.model,
diff --git a/packages/core/src/llm/ollama.ts b/packages/core/src/llm/ollama.ts
index 12fdbc928..5b7f39e7c 100644
--- a/packages/core/src/llm/ollama.ts
+++ b/packages/core/src/llm/ollama.ts
@@ -2,7 +2,6 @@ import { ok } from "@llamaindex/env";
 import type { Event } from "../callbacks/CallbackManager.js";
 import { BaseEmbedding } from "../embeddings/types.js";
 import type {
-  ChatMessage,
   ChatResponse,
   ChatResponseChunk,
   CompletionResponse,
@@ -182,10 +181,6 @@ export class Ollama extends BaseEmbedding implements LLM {
     }
   }
 
-  tokens(messages: ChatMessage[]): number {
-    throw new Error("Method not implemented.");
-  }
-
   private async getEmbedding(prompt: string): Promise<number[]> {
     const payload = {
       model: this.model,
diff --git a/packages/core/src/llm/types.ts b/packages/core/src/llm/types.ts
index 7c52bcc8b..7bd2ebad5 100644
--- a/packages/core/src/llm/types.ts
+++ b/packages/core/src/llm/types.ts
@@ -59,11 +59,6 @@ export interface LLM extends LLMChat {
   complete(
     params: LLMCompletionParamsNonStreaming,
   ): Promise<CompletionResponse>;
-
-  /**
-   * Calculates the number of tokens needed for the given chat messages
-   */
-  tokens(messages: ChatMessage[]): number;
 }
 
 export type MessageType =
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 4b69ec290..1e5b389f8 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -130,6 +130,9 @@ importers:
       dotenv:
         specifier: ^16.4.1
         version: 16.4.1
+      js-tiktoken:
+        specifier: ^1.0.10
+        version: 1.0.10
       llamaindex:
         specifier: latest
         version: link:../packages/core
-- 
GitLab