Commit 809a904b authored by Marcus Schiesser

fix: summarizer issues

parent 602d27c7
@@ -314,11 +314,11 @@ export class HistoryChatEngine implements ChatEngine {
   ): Promise<R> {
     //Streaming option
     if (streaming) {
-      return this.streamChat(message, chatHistory) as R;
+      return this.streamChat(message) as R;
     }
-    this.chatHistory.addMessage({ content: message, role: "user" });
+    await this.chatHistory.addMessage({ content: message, role: "user" });
     const response = await this.llm.chat(this.chatHistory.requestMessages);
-    this.chatHistory.addMessage(response.message);
+    await this.chatHistory.addMessage(response.message);
     return new Response(response.message.content) as R;
   }
@@ -326,7 +326,7 @@ export class HistoryChatEngine implements ChatEngine {
     message: string,
     chatHistory?: ChatMessage[] | undefined,
   ): AsyncGenerator<string, void, unknown> {
-    this.chatHistory.addMessage({ content: message, role: "user" });
+    await this.chatHistory.addMessage({ content: message, role: "user" });
     const response_stream = await this.llm.chat(
       this.chatHistory.requestMessages,
       undefined,
@@ -338,7 +338,7 @@ export class HistoryChatEngine implements ChatEngine {
       accumulator += part;
       yield part;
     }
-    this.chatHistory.addMessage({ content: accumulator, role: "user" });
+    await this.chatHistory.addMessage({ content: accumulator, role: "user" });
     return;
   }
...
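A minimal sketch (not part of the commit) of why the added awaits matter: when, as in this change, the history's addMessage is async because it may kick off a summarization, reading requestMessages before the returned promise settles can miss the just-added message or the freshly generated summary. The interface below is simplified and hypothetical:

// Hypothetical, simplified history interface for illustration only.
interface AsyncChatHistory {
  addMessage(message: { content: string; role: string }): Promise<void>;
  readonly requestMessages: { content: string; role: string }[];
}

async function respond(history: AsyncChatHistory, message: string) {
  // Without the await, requestMessages could be read before the user message
  // (or a new summary produced inside addMessage) has been appended.
  await history.addMessage({ content: message, role: "user" });
  return history.requestMessages;
}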
@@ -61,8 +61,12 @@ export class SummaryChatHistory implements ChatHistory {
   private async summarize() {
     // get all messages after the last summary message (including)
+    // if there's no summary message, get all messages
+    const lastSummaryIndex = this.getLastSummaryIndex();
     const chatHistoryStr = messagesToHistoryStr(
-      this.messages.slice(this.getLastSummaryIndex()),
+      lastSummaryIndex === -1
+        ? this.messages
+        : this.messages.slice(lastSummaryIndex),
     );
     const response = await this.llm.complete(
@@ -73,12 +77,10 @@ export class SummaryChatHistory implements ChatHistory {
   }
   async addMessage(message: ChatMessage) {
-    const lastSummaryIndex = this.getLastSummaryIndex();
-    // if there are more than or equal `messagesToSummarize` messages since the last summary, call summarize
-    if (
-      lastSummaryIndex !== -1 &&
-      this.messages.length - lastSummaryIndex - 1 >= this.messagesToSummarize
-    ) {
+    const messagesSinceLastSummary =
+      this.messages.length - this.getLastSummaryIndex() - 1;
+    // if there are too many messages since the last summary, call summarize
+    if (messagesSinceLastSummary >= this.messagesToSummarize) {
       // TODO: define what are better conditions, e.g. depending on the context length of the LLM?
       // for now we just summarize each `messagesToSummarize` messages
       await this.summarize();
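For illustration only (the values below are made up, not from the commit), the reworked condition boils down to a single count:

// Suppose the last "memory" message sits at index 1 and there are 10 messages total.
const messagesLength = 10;
const lastSummaryIndex = 1;
const messagesToSummarize = 5;

const messagesSinceLastSummary = messagesLength - lastSummaryIndex - 1; // 8
const shouldSummarize = messagesSinceLastSummary >= messagesToSummarize; // true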
@@ -88,14 +90,19 @@ export class SummaryChatHistory implements ChatHistory {
   // Find last summary message
   private getLastSummaryIndex() {
-    return this.messages
-      .slice()
-      .reverse()
-      .findIndex((message) => message.role === "memory");
+    return (
+      this.messages.length -
+      1 -
+      this.messages
+        .slice()
+        .reverse()
+        .findIndex((message) => message.role === "memory")
+    );
   }
   get requestMessages() {
     const lastSummaryIndex = this.getLastSummaryIndex();
+    if (lastSummaryIndex === -1) return this.messages;
     // get array of all system messages
     const systemMessages = this.messages.filter(
       (message) => message.role === "system",
...
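As a reference for the index arithmetic above (made-up messages, illustration only): findIndex on a reversed copy returns an offset from the end of the array, and length - 1 - offset maps it back to a forward index.

// Made-up message list purely to illustrate the mapping.
const messages = [
  { role: "system", content: "system prompt" },
  { role: "memory", content: "summary of earlier turns" },
  { role: "user", content: "latest question" },
];

const offsetFromEnd = messages
  .slice()
  .reverse()
  .findIndex((m) => m.role === "memory"); // 1

const lastSummaryIndex = messages.length - 1 - offsetFromEnd; // 3 - 1 - 1 = 1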
 export * from "./callbacks/CallbackManager";
 export * from "./ChatEngine";
+export * from "./ChatHistory";
 export * from "./constants";
 export * from "./Embedding";
 export * from "./GlobalsHelper";
...
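With ./ChatHistory re-exported from the entry point, the history classes can presumably be imported from the package root; a usage sketch, assuming the package is published under the name llamaindex:

// Assumes the published package name is "llamaindex"; SummaryChatHistory is the
// class touched by this commit's ChatHistory changes.
import { SummaryChatHistory } from "llamaindex";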