diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 4ee0bf288f50376e2900e7d705088f4c1aad09d6..ec72d6e485bb4e754cd2041c016db3a293527740 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -445,19 +445,19 @@ class WhisperSTTHandler(BaseHandler):
 
 @dataclass
 class LanguageModelHandlerArguments:
-    llm_model_name: str = field(
+    lm_model_name: str = field(
         default="microsoft/Phi-3-mini-4k-instruct",
         metadata={
             "help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'."
         }
     )
-    llm_device: str = field(
+    lm_device: str = field(
         default="cuda",
         metadata={
             "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
         }
     )
-    llm_torch_dtype: str = field(
+    lm_torch_dtype: str = field(
         default="float16",
         metadata={
             "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
@@ -481,15 +481,15 @@ class LanguageModelHandlerArguments:
             "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
         }
     )
-    llm_gen_max_new_tokens: int = field(
+    lm_gen_max_new_tokens: int = field(
         default=128,
         metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 128."}
     )
-    llm_gen_temperature: float = field(
+    lm_gen_temperature: float = field(
         default=0.0,
         metadata={"help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."}
     )
-    llm_gen_do_sample: bool = field(
+    lm_gen_do_sample: bool = field(
         default=False,
         metadata={"help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."}
     )
@@ -635,9 +635,9 @@ class ParlerTTSHandler(BaseHandler):
         framerate = self.model.audio_encoder.config.frame_rate
         self.play_steps = int(framerate * play_steps_s)
 
-    def process(self, llm_sentence):
-        console.print(f"[green]ASSISTANT: {llm_sentence}")
-        tokenized_prompt = self.prompt_tokenizer(llm_sentence, return_tensors="pt")
+    def process(self, lm_sentence):
+        console.print(f"[green]ASSISTANT: {lm_sentence}")
+        tokenized_prompt = self.prompt_tokenizer(lm_sentence, return_tensors="pt")
         prompt_input_ids = tokenized_prompt.input_ids.to(self.device)
         prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device)
 
@@ -723,7 +723,7 @@ def main():
         torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
 
     prepare_args(whisper_stt_handler_kwargs, "stt")
-    prepare_args(language_model_handler_kwargs, "llm")
+    prepare_args(language_model_handler_kwargs, "lm")
     prepare_args(parler_tts_handler_kwargs, "tts")
 
     stop_event = Event()
@@ -732,7 +732,7 @@ def main():
     send_audio_chunks_queue = Queue()
     spoken_prompt_queue = Queue()
     text_prompt_queue = Queue()
-    llm_response_queue = Queue()
+    lm_response_queue = Queue()
 
     vad = VADHandler(
         stop_event,
@@ -747,15 +747,15 @@ def main():
         queue_out=text_prompt_queue,
         setup_kwargs=vars(whisper_stt_handler_kwargs),
     )
-    llm = LanguageModelHandler(
+    lm = LanguageModelHandler(
         stop_event,
         queue_in=text_prompt_queue,
-        queue_out=llm_response_queue,
+        queue_out=lm_response_queue,
         setup_kwargs=vars(language_model_handler_kwargs),
     )
     tts = ParlerTTSHandler(
         stop_event,
-        queue_in=llm_response_queue,
+        queue_in=lm_response_queue,
         queue_out=send_audio_chunks_queue,
         setup_args=(should_listen,),
         setup_kwargs=vars(parler_tts_handler_kwargs),
@@ -778,7 +778,7 @@ def main():
     )
 
     try:
-        pipeline_manager = ThreadManager([vad, tts, llm, stt, recv_handler, send_handler])
+        pipeline_manager = ThreadManager([vad, tts, lm, stt, recv_handler, send_handler])
         pipeline_manager.start()
 
     except KeyboardInterrupt:
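
Note for reviewers: the rename is mechanical, but as context, below is a minimal sketch of the queue-chained handler pattern this patch touches. Each handler consumes from a queue_in, produces to a queue_out, and runs on its own thread, so STT -> LM -> TTS dataflow is just a chain of queues. EchoHandler is a hypothetical stand-in for LanguageModelHandler (the real class also takes setup_kwargs and streams sentences to the TTS handler), and the None sentinel shutdown is an assumption for this sketch, not the repo's actual stop mechanism.

    from queue import Queue
    from threading import Event, Thread


    class EchoHandler:
        """Hypothetical stand-in for LanguageModelHandler: reads text
        prompts from queue_in and pushes responses onto queue_out."""

        def __init__(self, stop_event, queue_in, queue_out):
            self.stop_event = stop_event
            self.queue_in = queue_in
            self.queue_out = queue_out

        def run(self):
            while not self.stop_event.is_set():
                prompt = self.queue_in.get()
                if prompt is None:  # sentinel: upstream is done (sketch-only convention)
                    break
                self.queue_out.put(f"echo: {prompt}")


    stop_event = Event()
    text_prompt_queue = Queue()   # filled by the STT handler in the real pipeline
    lm_response_queue = Queue()   # drained by the TTS handler in the real pipeline

    lm = EchoHandler(stop_event, queue_in=text_prompt_queue, queue_out=lm_response_queue)
    worker = Thread(target=lm.run)
    worker.start()

    text_prompt_queue.put("hello")
    print(lm_response_queue.get())  # -> "echo: hello"
    text_prompt_queue.put(None)     # shut the worker down
    worker.join()

With this wiring, renaming llm_response_queue to lm_response_queue (and llm to lm) only changes identifiers; no queue topology or handler behavior changes anywhere in the patch.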