Commit c0062746 authored by Marcus Schiesser
feat: use tokenizer to ensure we're not running over the context window

parent 809a904b
@@ -38,6 +38,7 @@
   "scripts": {
     "lint": "eslint .",
     "test": "jest",
-    "build": "tsup src/index.ts --format esm,cjs --dts"
+    "build": "tsup src/index.ts --format esm,cjs --dts",
+    "dev": "tsup src/index.ts --format esm,cjs --watch"
   }
 }
\ No newline at end of file
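The new "dev" script rebuilds the package on every source change via tsup's --watch flag; unlike "build" it omits --dts, so type declarations are only emitted by the build script. As a rough sketch only (not part of this commit), the "build" flags could also be expressed as a tsup.config.ts, assuming tsup's config-file API:

// tsup.config.ts — hypothetical equivalent of the CLI flags above
import { defineConfig } from "tsup";

export default defineConfig({
  entry: ["src/index.ts"],
  format: ["esm", "cjs"], // emit both ESM and CommonJS bundles
  dts: true, // generate .d.ts type declarations
  // the new "dev" script adds --watch on top of this to rebuild on change
});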
@@ -338,7 +338,10 @@ export class HistoryChatEngine implements ChatEngine {
       accumulator += part;
       yield part;
     }
-    await this.chatHistory.addMessage({ content: accumulator, role: "user" });
+    await this.chatHistory.addMessage({
+      content: accumulator,
+      role: "assistant",
+    });
     return;
   }
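The change above fixes the role used when persisting a streamed response: the accumulator collects the model's reply as it streams, so it must be stored as an "assistant" message rather than a "user" message. A minimal sketch of that streaming path (the interfaces and the function name are assumptions for illustration, not the library's exact types):

// Sketch: stream an LLM reply to the caller while also saving it to chat history.
interface ChatMessage {
  content: string;
  role: "user" | "assistant" | "system" | "memory";
}
interface ChatHistory {
  addMessage(message: ChatMessage): Promise<void>;
}

async function* streamAndRecord(
  stream: AsyncIterable<string>,
  chatHistory: ChatHistory,
): AsyncGenerator<string> {
  let accumulator = "";
  for await (const part of stream) {
    accumulator += part; // collect the full reply while yielding chunks to the caller
    yield part;
  }
  // the accumulated text is the model's answer, hence role "assistant"
  await chatHistory.addMessage({ content: accumulator, role: "assistant" });
}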
-import { ChatMessage, LLM, OpenAI } from "./llm/LLM";
+import tiktoken from "tiktoken";
+import {
+  ALL_AVAILABLE_OPENAI_MODELS,
+  ChatMessage,
+  MessageType,
+  OpenAI,
+} from "./llm/LLM";
 import {
   defaultSummaryPrompt,
   messagesToHistoryStr,
@@ -47,66 +53,104 @@ export class SimpleChatHistory implements ChatHistory {
 }
 export class SummaryChatHistory implements ChatHistory {
-  messagesToSummarize: number;
+  tokensToSummarize: number;
   messages: ChatMessage[];
   summaryPrompt: SummaryPrompt;
-  llm: LLM;
+  llm: OpenAI;
   constructor(init?: Partial<SummaryChatHistory>) {
-    this.messagesToSummarize = init?.messagesToSummarize ?? 5;
     this.messages = init?.messages ?? [];
     this.summaryPrompt = init?.summaryPrompt ?? defaultSummaryPrompt;
     this.llm = init?.llm ?? new OpenAI();
+    if (!this.llm.maxTokens) {
+      throw new Error(
+        "LLM maxTokens is not set. Needed so the summarizer ensures the context window size of the LLM.",
+      );
+    }
+    // TODO: currently, this only works with OpenAI
+    // to support more LLMs, we have to move the tokenizer and the context window size to the LLM interface
+    this.tokensToSummarize =
+      ALL_AVAILABLE_OPENAI_MODELS[this.llm.model].contextWindow -
+      this.llm.maxTokens;
   }
+  private tokens(messages: ChatMessage[]): number {
+    // for latest OpenAI models, see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+    const encoding = tiktoken.encoding_for_model(this.llm.model);
+    const tokensPerMessage = 3;
+    let numTokens = 0;
+    for (const message of messages) {
+      numTokens += tokensPerMessage;
+      for (const value of Object.values(message)) {
+        numTokens += encoding.encode(value).length;
+      }
+    }
+    numTokens += 3; // every reply is primed with <|im_start|>assistant<|im_sep|>
+    return numTokens;
+  }
-  private async summarize() {
+  private async summarize(): Promise<ChatMessage> {
     // get all messages after the last summary message (including)
-    // if there's no summary message, get all messages
+    // if there's no summary message, get all messages (without system messages)
     const lastSummaryIndex = this.getLastSummaryIndex();
-    const chatHistoryStr = messagesToHistoryStr(
-      lastSummaryIndex === -1
-        ? this.messages
-        : this.messages.slice(lastSummaryIndex),
-    );
-    const response = await this.llm.complete(
-      this.summaryPrompt({ context: chatHistoryStr }),
-    );
-    this.messages.push({ content: response.message.content, role: "memory" });
+    const messagesToSummarize = !lastSummaryIndex
+      ? this.nonSystemMessages
+      : this.messages.slice(lastSummaryIndex);
+    let promptMessages;
+    do {
+      promptMessages = [
+        {
+          content: this.summaryPrompt({
+            context: messagesToHistoryStr(messagesToSummarize),
+          }),
+          role: "user" as MessageType,
+        },
+      ];
+      // remove oldest message until the chat history is short enough for the context window
+      messagesToSummarize.shift();
+    } while (this.tokens(promptMessages) > this.tokensToSummarize);
+    const response = await this.llm.chat(promptMessages);
+    return { content: response.message.content, role: "memory" };
   }
   async addMessage(message: ChatMessage) {
-    const messagesSinceLastSummary =
-      this.messages.length - this.getLastSummaryIndex() - 1;
-    // if there are too many messages since the last summary, call summarize
-    if (messagesSinceLastSummary >= this.messagesToSummarize) {
-      // TODO: define what are better conditions, e.g. depending on the context length of the LLM?
-      // for now we just summarize each `messagesToSummarize` messages
-      await this.summarize();
+    // get tokens of current request messages and the new message
+    const tokens = this.tokens([...this.requestMessages, message]);
+    // if there are too many tokens for the next request, call summarize
+    if (tokens > this.tokensToSummarize) {
+      const memoryMessage = await this.summarize();
+      this.messages.push(memoryMessage);
     }
     this.messages.push(message);
   }
   // Find last summary message
-  private getLastSummaryIndex() {
-    return (
-      this.messages.length -
-      1 -
-      this.messages
-        .slice()
-        .reverse()
-        .findIndex((message) => message.role === "memory")
+  private getLastSummaryIndex(): number | null {
+    const reversedMessages = this.messages.slice().reverse();
+    const index = reversedMessages.findIndex(
+      (message) => message.role === "memory",
     );
+    if (index === -1) {
+      return null;
+    }
+    return this.messages.length - 1 - index;
   }
+  private get systemMessages() {
+    // get array of all system messages
+    return this.messages.filter((message) => message.role === "system");
+  }
+  private get nonSystemMessages() {
+    // get array of all non-system messages
+    return this.messages.filter((message) => message.role !== "system");
+  }
   get requestMessages() {
     const lastSummaryIndex = this.getLastSummaryIndex();
-    if (lastSummaryIndex === -1) return this.messages;
-    // get array of all system messages
-    const systemMessages = this.messages.filter(
-      (message) => message.role === "system",
-    );
+    if (!lastSummaryIndex) return this.messages;
     // convert summary message so it can be send to the LLM
     const summaryMessage: ChatMessage = {
       content: `This is a summary of conversation so far: ${this.messages[lastSummaryIndex].content}`,
@@ -114,7 +158,7 @@ export class SummaryChatHistory implements ChatHistory {
     };
     // return system messages, last summary and all messages after the last summary message
     return [
-      ...systemMessages,
+      ...this.systemMessages,
       summaryMessage,
       ...this.messages.slice(lastSummaryIndex + 1),
     ];
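Taken together, the SummaryChatHistory changes replace the old "summarize every N messages" heuristic with a token budget: tokensToSummarize = contextWindow - maxTokens, tokens are counted with tiktoken, and a "memory" summary message is inserted as soon as the next request would exceed that budget. summarize() itself drops the oldest messages from its prompt until the prompt fits under the same budget. A rough standalone sketch of the budget check follows (the message shape and the 3-token overheads mirror the diff, but the function names, parameters, and model literal are assumptions, not the library's API):

// Sketch of the token-budget check, assuming the WASM "tiktoken" npm package.
import { encoding_for_model } from "tiktoken";

type Message = { content: string; role: string };

// Count tokens the way the diff does: ~3 tokens of overhead per message,
// plus the encoded length of each field, plus 3 tokens priming the reply.
function countTokens(
  messages: Message[],
  model: "gpt-3.5-turbo" | "gpt-4" = "gpt-3.5-turbo",
): number {
  const encoding = encoding_for_model(model);
  let numTokens = 3; // every reply is primed with <|im_start|>assistant<|im_sep|>
  for (const message of messages) {
    numTokens += 3; // per-message overhead
    for (const value of Object.values(message)) {
      numTokens += encoding.encode(value).length;
    }
  }
  encoding.free(); // the WASM encoder should be freed explicitly
  return numTokens;
}

// Summarize once the next request would no longer leave room for the completion.
function shouldSummarize(
  requestMessages: Message[],
  newMessage: Message,
  contextWindow: number, // e.g. 4096 for gpt-3.5-turbo
  maxTokens: number, // completion budget reserved for the model's answer
): boolean {
  const tokensToSummarize = contextWindow - maxTokens;
  return countTokens([...requestMessages, newMessage]) > tokensToSummarize;
}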