From 3196799ec483096fbf08f30a05279f7f79ffe982 Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan <eulebihan@gmail.com>
Date: Tue, 13 Aug 2024 19:25:52 +0000
Subject: [PATCH] Limit chat history size with a chat_size argument

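Cap the language model's chat history instead of letting it grow
without bound across turns. A new Chat container backed by
collections.deque keeps the last chat_size messages, while an optional
initial message (e.g. the system prompt) is pinned separately and is
always prepended by to_list(). Also lower lm_gen_max_new_tokens to 64
and drop the forced minimum on tts_gen_min_new_tokens.

For illustration only (this snippet is not part of the patch), the
intended rolling behavior is:

    chat = Chat(size=3)
    chat.init_chat({"role": "system", "content": "You are helpful."})
    for i in range(5):
        chat.append({"role": "user", "content": f"message {i}"})
    # The deque has dropped the two oldest messages; to_list() yields
    # the pinned system message plus the 3 most recent ones.
    assert len(chat.to_list()) == 4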
---
 s2s_pipeline.py | 48 +++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 39 insertions(+), 9 deletions(-)

diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index eebb2c6..1b744ff 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -9,6 +9,7 @@ import os
 from pathlib import Path
 from dataclasses import dataclass, field
 from copy import copy
+from collections import deque
 
 import numpy as np
 import torch
@@ -490,7 +491,7 @@ class LanguageModelHandlerArguments:
         }
     )
     lm_gen_max_new_tokens: int = field(
-        default=128,
-        metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 128."}
+        default=64,
+        metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 64."}
     )
     lm_gen_temperature: float = field(
@@ -501,6 +502,33 @@
         default=False,
         metadata={"help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."}
     )
+    chat_size: int = field(
+        default=3,
+        metadata={"help": "Number of messages of the messages to keep for the chat. None for no limitations."}
+    )
+
+
+class Chat:
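+    """Rolling chat history: keeps at most `size` recent messages, plus
+    an optional pinned initial message (e.g. the system prompt)."""
+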
+    def __init__(self, size):
+        self.init_chat_message = None
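+        # deque(maxlen=size) drops the oldest message once size is exceeded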
+        self.buffer = deque(maxlen=size)
+
+    def append(self, item):
+        self.buffer.append(item)
+
+    def init_chat(self, init_chat_message):
+        self.init_chat_message = init_chat_message
+
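+    # Full prompt for the pipeline: pinned init message (if any) + recent buffer.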
+    def to_list(self):
+        if self.init_chat_message:
+            return [self.init_chat_message] + list(self.buffer)
+        else:
+            return list(self.buffer)
 
 
 class LanguageModelHandler(BaseHandler):
@@ -509,6 +537,7 @@ class LanguageModelHandler(BaseHandler):
             model_name="microsoft/Phi-3-mini-4k-instruct",
             device="cuda", 
             torch_dtype="float16",
+            chat_size=3,
             gen_kwargs={},
             user_role="user",
             init_chat_role=None, 
@@ -532,19 +561,20 @@
             skip_prompt=True,
             skip_special_tokens=True,
         )
-        self.chat = []
+        self.chat = Chat(chat_size)
         if init_chat_role:
             if not init_chat_prompt:
                 raise ValueError(f"An initial promt needs to be specified when setting init_chat_role.")
-            self.chat.append(
+            self.chat.init_chat(
                 {"role": init_chat_role, "content": init_chat_prompt}
             )
+
         self.gen_kwargs = {
             "streamer": self.streamer,
             "return_full_text": False,
             **gen_kwargs
         }
         self.user_role = user_role
         self.warmup()
 
     def warmup(self):
@@ -578,7 +608,7 @@
         self.chat.append(
             {"role": self.user_role, "content": prompt}
         )
-        thread = Thread(target=self.pipe, args=(self.chat,), kwargs=self.gen_kwargs)
+        thread = Thread(target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs)
         thread.start()
         generated_text, printable_text = "", ""
         logger.debug("infering language model...")
@@ -623,7 +653,7 @@ class ParlerTTSHandlerArguments:
         }
     )
     tts_gen_min_new_tokens: int = field(
-        default=10,
-        metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 10, which corresponds to ~0.1 secs"}
+        default=None,
+        metadata={"help": "Minimum number of new tokens to generate in a single completion. Default is None, which applies no minimum."}
     )
     tts_gen_max_new_tokens: int = field(
-- 
GitLab