Commit 89737d6e authored by yisding, committed by GitHub

Merge pull request #140 from run-llama/feat/use-tokenizer-for-summarizer

Feat: Use tokenizer for chat history summarizer
parents 602d27c7 6a81d54e
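For orientation, here is a minimal sketch of how the summarizing chat history introduced here might be wired into the chat engine. The class names come from the diffs below; the import path, constructor shapes, and model settings are assumptions, not part of this commit.

    import { HistoryChatEngine, OpenAI, SummaryChatHistory } from "llamaindex"; // import path assumed

    async function main() {
      // maxTokens must be set on the LLM; the new SummaryChatHistory constructor throws otherwise
      const llm = new OpenAI({ model: "gpt-3.5-turbo", maxTokens: 256 });
      const chatHistory = new SummaryChatHistory({ llm });
      const chatEngine = new HistoryChatEngine({ llm, chatHistory });

      // once the accumulated history would exceed the token budget, older turns are
      // replaced by an LLM-generated summary ("memory" message) before the next request
      const response = await chatEngine.chat("What have we talked about so far?");
      console.log(response.toString());
    }

    main();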
@@ -38,6 +38,7 @@
   "scripts": {
     "lint": "eslint .",
     "test": "jest",
-    "build": "tsup src/index.ts --format esm,cjs --dts"
+    "build": "tsup src/index.ts --format esm,cjs --dts",
+    "dev": "tsup src/index.ts --format esm,cjs --watch"
   }
 }
\ No newline at end of file
@@ -314,11 +314,11 @@ export class HistoryChatEngine implements ChatEngine {
   ): Promise<R> {
     //Streaming option
     if (streaming) {
-      return this.streamChat(message, chatHistory) as R;
+      return this.streamChat(message) as R;
     }
-    this.chatHistory.addMessage({ content: message, role: "user" });
+    await this.chatHistory.addMessage({ content: message, role: "user" });
     const response = await this.llm.chat(this.chatHistory.requestMessages);
-    this.chatHistory.addMessage(response.message);
+    await this.chatHistory.addMessage(response.message);
     return new Response(response.message.content) as R;
   }
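The `addMessage` calls above are now awaited because adding a message may itself call the LLM (the summarizer in ChatHistory.ts below). A rough sketch of the contract the engine relies on, using local stand-in types rather than the real interface:

    type MessageType = "user" | "assistant" | "system" | "memory"; // subset, for illustration
    interface ChatMessage {
      content: string;
      role: MessageType;
    }

    // hypothetical minimal shape; the real ChatHistory interface lives in ChatHistory.ts
    interface ChatHistoryLike {
      // may trigger an LLM call (summarization), hence the Promise
      addMessage(message: ChatMessage): Promise<void>;
      // the messages to send with the next request (system prompt + summary + recent turns)
      readonly requestMessages: ChatMessage[];
    }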
@@ -326,7 +326,7 @@ export class HistoryChatEngine implements ChatEngine {
     message: string,
     chatHistory?: ChatMessage[] | undefined,
   ): AsyncGenerator<string, void, unknown> {
-    this.chatHistory.addMessage({ content: message, role: "user" });
+    await this.chatHistory.addMessage({ content: message, role: "user" });
     const response_stream = await this.llm.chat(
       this.chatHistory.requestMessages,
       undefined,
@@ -338,7 +338,10 @@ export class HistoryChatEngine implements ChatEngine {
       accumulator += part;
       yield part;
     }
-    this.chatHistory.addMessage({ content: accumulator, role: "user" });
+    await this.chatHistory.addMessage({
+      content: accumulator,
+      role: "assistant",
+    });
     return;
   }
...
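Two things change in the streaming path: `addMessage` is awaited, and the accumulated output is now stored with role "assistant" instead of the incorrect "user". A sketch of driving the streaming branch through `chat()`; the parameter order, the cast, and the import path are assumptions:

    import { HistoryChatEngine } from "llamaindex"; // import path assumed

    async function streamAnswer(chatEngine: HistoryChatEngine): Promise<string> {
      // passing `true` as the third argument selects the streaming branch shown above;
      // the cast mirrors the generator type declared in the diff
      const stream = (await chatEngine.chat(
        "Tell me a joke",
        undefined,
        true,
      )) as AsyncGenerator<string, void, unknown>;
      let answer = "";
      for await (const token of stream) {
        answer += token; // tokens are yielded as they arrive
      }
      // the engine itself appends `answer` to the history as an assistant message
      return answer;
    }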
-import { ChatMessage, LLM, OpenAI } from "./llm/LLM";
+import tiktoken from "tiktoken";
+import {
+  ALL_AVAILABLE_OPENAI_MODELS,
+  ChatMessage,
+  MessageType,
+  OpenAI,
+} from "./llm/LLM";
 import {
   defaultSummaryPrompt,
   messagesToHistoryStr,
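The new `tiktoken` dependency is the tokenizer OpenAI's chat models use, so message lengths can be measured in model tokens rather than characters. A minimal standalone usage sketch:

    import tiktoken from "tiktoken";

    // encoding_for_model picks the tokenizer that matches a given OpenAI model name
    const encoding = tiktoken.encoding_for_model("gpt-3.5-turbo");
    const numTokens = encoding.encode("Hello, how are you?").length; // token count, not character count
    encoding.free(); // the encoder is WASM-backed, so free it when done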
@@ -47,59 +53,104 @@ export class SimpleChatHistory implements ChatHistory {
 }
 export class SummaryChatHistory implements ChatHistory {
-  messagesToSummarize: number;
+  tokensToSummarize: number;
   messages: ChatMessage[];
   summaryPrompt: SummaryPrompt;
-  llm: LLM;
+  llm: OpenAI;
   constructor(init?: Partial<SummaryChatHistory>) {
-    this.messagesToSummarize = init?.messagesToSummarize ?? 5;
     this.messages = init?.messages ?? [];
     this.summaryPrompt = init?.summaryPrompt ?? defaultSummaryPrompt;
     this.llm = init?.llm ?? new OpenAI();
+    if (!this.llm.maxTokens) {
+      throw new Error(
+        "LLM maxTokens is not set. Needed so the summarizer ensures the context window size of the LLM.",
+      );
+    }
+    // TODO: currently, this only works with OpenAI
+    // to support more LLMs, we have to move the tokenizer and the context window size to the LLM interface
+    this.tokensToSummarize =
+      ALL_AVAILABLE_OPENAI_MODELS[this.llm.model].contextWindow -
+      this.llm.maxTokens;
   }
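The budget is the model's context window minus the tokens reserved for the reply. Illustrative arithmetic only; the real context-window value comes from ALL_AVAILABLE_OPENAI_MODELS:

    const contextWindow = 4096; // e.g. a 4k-context chat model (illustrative value)
    const maxTokens = 256; // reply budget configured on the LLM
    const tokensToSummarize = contextWindow - maxTokens; // 3840 tokens left for the request messages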
-  private async summarize() {
-    // get all messages after the last summary message (including)
-    const chatHistoryStr = messagesToHistoryStr(
-      this.messages.slice(this.getLastSummaryIndex()),
-    );
-    const response = await this.llm.complete(
-      this.summaryPrompt({ context: chatHistoryStr }),
-    );
-    this.messages.push({ content: response.message.content, role: "memory" });
+  private tokens(messages: ChatMessage[]): number {
+    // for latest OpenAI models, see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+    const encoding = tiktoken.encoding_for_model(this.llm.model);
+    const tokensPerMessage = 3;
+    let numTokens = 0;
+    for (const message of messages) {
+      numTokens += tokensPerMessage;
+      for (const value of Object.values(message)) {
+        numTokens += encoding.encode(value).length;
+      }
+    }
+    numTokens += 3; // every reply is primed with <|im_start|>assistant<|im_sep|>
+    return numTokens;
+  }
+  private async summarize(): Promise<ChatMessage> {
+    // get all messages after the last summary message (including)
+    // if there's no summary message, get all messages (without system messages)
+    const lastSummaryIndex = this.getLastSummaryIndex();
+    const messagesToSummarize = !lastSummaryIndex
+      ? this.nonSystemMessages
+      : this.messages.slice(lastSummaryIndex);
+    let promptMessages;
+    do {
+      promptMessages = [
+        {
+          content: this.summaryPrompt({
+            context: messagesToHistoryStr(messagesToSummarize),
+          }),
+          role: "user" as MessageType,
+        },
+      ];
+      // remove oldest message until the chat history is short enough for the context window
+      messagesToSummarize.shift();
+    } while (this.tokens(promptMessages) > this.tokensToSummarize);
+    const response = await this.llm.chat(promptMessages);
+    return { content: response.message.content, role: "memory" };
   }
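The counting rule in `tokens()` follows the linked OpenAI cookbook recipe: a fixed per-message overhead, the encoded length of every message field, and a fixed priming cost for the reply. The same rule as a standalone sketch (hypothetical helper, not part of this commit):

    import tiktoken from "tiktoken";

    function countChatTokens(messages: { role: string; content: string }[]): number {
      const encoding = tiktoken.encoding_for_model("gpt-3.5-turbo");
      let numTokens = 3; // every reply is primed with <|im_start|>assistant<|im_sep|>
      for (const message of messages) {
        numTokens += 3; // per-message overhead
        numTokens += encoding.encode(message.role).length;
        numTokens += encoding.encode(message.content).length;
      }
      encoding.free();
      return numTokens;
    }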
   async addMessage(message: ChatMessage) {
-    const lastSummaryIndex = this.getLastSummaryIndex();
-    // if there are more than or equal `messagesToSummarize` messages since the last summary, call summarize
-    if (
-      lastSummaryIndex !== -1 &&
-      this.messages.length - lastSummaryIndex - 1 >= this.messagesToSummarize
-    ) {
-      // TODO: define what are better conditions, e.g. depending on the context length of the LLM?
-      // for now we just summarize each `messagesToSummarize` messages
-      await this.summarize();
+    // get tokens of current request messages and the new message
+    const tokens = this.tokens([...this.requestMessages, message]);
+    // if there are too many tokens for the next request, call summarize
+    if (tokens > this.tokensToSummarize) {
+      const memoryMessage = await this.summarize();
+      this.messages.push(memoryMessage);
     }
     this.messages.push(message);
   }
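Summarization is no longer triggered every N messages but only when the next request would overflow the token budget, and the summary is inserted as a "memory" message just before the new message. A hedged sketch of observing that behaviour; it is a hypothetical demo that makes real LLM calls, and the import path is assumed:

    import { SummaryChatHistory } from "llamaindex"; // import path assumed

    async function demo(history: SummaryChatHistory) {
      for (let i = 0; i < 50; i++) {
        // long messages so the token budget is eventually exceeded
        await history.addMessage({ role: "user", content: `turn ${i}: ${"x".repeat(2000)}` });
      }
      // once the budget was crossed at least once, a summary ("memory") message appears
      console.log(history.messages.some((m) => m.role === "memory"));
    }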
   // Find last summary message
-  private getLastSummaryIndex() {
-    return this.messages
-      .slice()
-      .reverse()
-      .findIndex((message) => message.role === "memory");
+  private getLastSummaryIndex(): number | null {
+    const reversedMessages = this.messages.slice().reverse();
+    const index = reversedMessages.findIndex(
+      (message) => message.role === "memory",
+    );
+    if (index === -1) {
+      return null;
+    }
+    return this.messages.length - 1 - index;
+  }
+  private get systemMessages() {
+    // get array of all system messages
+    return this.messages.filter((message) => message.role === "system");
+  }
+  private get nonSystemMessages() {
+    // get array of all non-system messages
+    return this.messages.filter((message) => message.role !== "system");
   }
   get requestMessages() {
     const lastSummaryIndex = this.getLastSummaryIndex();
-    // get array of all system messages
-    const systemMessages = this.messages.filter(
-      (message) => message.role === "system",
-    );
+    if (!lastSummaryIndex) return this.messages;
     // convert summary message so it can be send to the LLM
     const summaryMessage: ChatMessage = {
       content: `This is a summary of conversation so far: ${this.messages[lastSummaryIndex].content}`,
@@ -107,7 +158,7 @@ export class SummaryChatHistory implements ChatHistory {
     };
     // return system messages, last summary and all messages after the last summary message
     return [
-      ...systemMessages,
+      ...this.systemMessages,
       summaryMessage,
       ...this.messages.slice(lastSummaryIndex + 1),
     ];
...
export * from "./callbacks/CallbackManager"; export * from "./callbacks/CallbackManager";
export * from "./ChatEngine"; export * from "./ChatEngine";
export * from "./ChatHistory";
export * from "./constants"; export * from "./constants";
export * from "./Embedding"; export * from "./Embedding";
export * from "./GlobalsHelper"; export * from "./GlobalsHelper";
......
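Because ChatHistory.ts is now re-exported from the package entry point, the history classes should be importable from the package root (package name assumed):

    import { SimpleChatHistory, SummaryChatHistory } from "llamaindex"; // package name assumed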