Skip to content
Snippets Groups Projects
Commit 618b5c23 authored by Eustache Le Bihan
Browse files

change llm to lm

parent e29df843
No related branches found
No related tags found
No related merge requests found
......@@ -445,19 +445,19 @@ class WhisperSTTHandler(BaseHandler):
@dataclass
class LanguageModelHandlerArguments:
llm_model_name: str = field(
lm_model_name: str = field(
default="microsoft/Phi-3-mini-4k-instruct",
metadata={
"help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'."
}
)
llm_device: str = field(
lm_device: str = field(
default="cuda",
metadata={
"help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
}
)
llm_torch_dtype: str = field(
lm_torch_dtype: str = field(
default="float16",
metadata={
"help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
......@@ -481,15 +481,15 @@ class LanguageModelHandlerArguments:
"help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
}
)
llm_gen_max_new_tokens: int = field(
lm_gen_max_new_tokens: int = field(
default=128,
metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 128."}
)
llm_gen_temperature: float = field(
lm_gen_temperature: float = field(
default=0.0,
metadata={"help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."}
)
llm_gen_do_sample: bool = field(
lm_gen_do_sample: bool = field(
default=False,
metadata={"help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."}
)
......@@ -635,9 +635,9 @@ class ParlerTTSHandler(BaseHandler):
framerate = self.model.audio_encoder.config.frame_rate
self.play_steps = int(framerate * play_steps_s)
def process(self, llm_sentence):
console.print(f"[green]ASSISTANT: {llm_sentence}")
tokenized_prompt = self.prompt_tokenizer(llm_sentence, return_tensors="pt")
def process(self, lm_sentence):
console.print(f"[green]ASSISTANT: {lm_sentence}")
tokenized_prompt = self.prompt_tokenizer(lm_sentence, return_tensors="pt")
prompt_input_ids = tokenized_prompt.input_ids.to(self.device)
prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device)
......@@ -723,7 +723,7 @@ def main():
torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
prepare_args(whisper_stt_handler_kwargs, "stt")
prepare_args(language_model_handler_kwargs, "llm")
prepare_args(language_model_handler_kwargs, "lm")
prepare_args(parler_tts_handler_kwargs, "tts")
stop_event = Event()
......@@ -732,7 +732,7 @@ def main():
send_audio_chunks_queue = Queue()
spoken_prompt_queue = Queue()
text_prompt_queue = Queue()
llm_response_queue = Queue()
lm_response_queue = Queue()
vad = VADHandler(
stop_event,
......@@ -747,15 +747,15 @@ def main():
queue_out=text_prompt_queue,
setup_kwargs=vars(whisper_stt_handler_kwargs),
)
llm = LanguageModelHandler(
lm = LanguageModelHandler(
stop_event,
queue_in=text_prompt_queue,
queue_out=llm_response_queue,
queue_out=lm_response_queue,
setup_kwargs=vars(language_model_handler_kwargs),
)
tts = ParlerTTSHandler(
stop_event,
queue_in=llm_response_queue,
queue_in=lm_response_queue,
queue_out=send_audio_chunks_queue,
setup_args=(should_listen,),
setup_kwargs=vars(parler_tts_handler_kwargs),
......@@ -778,7 +778,7 @@ def main():
)
try:
pipeline_manager = ThreadManager([vad, tts, llm, stt, recv_handler, send_handler])
pipeline_manager = ThreadManager([vad, tts, lm, stt, recv_handler, send_handler])
pipeline_manager.start()
except KeyboardInterrupt:
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment