Commit a4888572 authored by Andres Marafioti

Let users choose their llm

parent 0e1e0798
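The hunks below add an llm option to ModuleArguments and dispatch on it when constructing the language-model handler. The field()/metadata pattern matches transformers' HfArgumentParser, so the option should surface as an --llm command-line flag. A hedged usage sketch (the entry-point script name is not visible in this diff; s2s_pipeline.py is an assumption):

python s2s_pipeline.py --llm mlx-lm --tts melo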
@@ -10,7 +10,7 @@ from queue import Queue
 from threading import Event, Thread
 from time import perf_counter
 from typing import Optional
-
+from sys import platform
 from LLM.mlx_lm import MLXLanguageModelHandler
 from TTS.melotts import MeloTTSHandler
 from baseHandler import BaseHandler
@@ -62,6 +62,12 @@ class ModuleArguments:
             "help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'local'."
         },
     )
+    llm: Optional[str] = field(
+        default="transformers",
+        metadata={
+            "help": "The LLM to use. Either 'transformers' or 'mlx-lm'. Default is 'transformers'"
+        },
+    )
     tts: Optional[str] = field(
         default="parler",
         metadata={
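For context, this is how such a dataclass field becomes a command-line argument. A minimal, self-contained sketch, assuming HfArgumentParser is the parser in use (the parsing code itself sits outside the shown hunks):

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class ModuleArguments:
    # mirrors the new field added above
    llm: Optional[str] = field(
        default="transformers",
        metadata={"help": "The LLM to use. Either 'transformers' or 'mlx-lm'."},
    )

parser = HfArgumentParser(ModuleArguments)
# parse an explicit argv list instead of sys.argv for the demo
(module_kwargs,) = parser.parse_args_into_dataclasses(["--llm", "mlx-lm"])
print(module_kwargs.llm)  # -> mlx-lm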
@@ -950,6 +956,20 @@ def main():
     if module_kwargs.log_level == "debug":
         torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
+    if platform == "darwin":
+        if module_kwargs.device == "cuda":
+            raise ValueError(
+                "Cannot use CUDA on macOS. Please set the device to 'cpu' or 'mps'."
+            )
+
+        if module_kwargs.llm != "mlx-lm":
+            logger.warning(
+                "For macOS users, it is recommended to use mlx-lm."
+            )
+        if module_kwargs.tts != "melo":
+            logger.warning(
+                "If you experience issues generating the voice, consider setting the tts to melo."
+            )

     # 2. Prepare each part's arguments
     def overwrite_device_argument(common_device: Optional[str], *handler_kwargs):
         if common_device:
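The gate above relies on sys.platform, which is the string "darwin" on macOS ("linux" and "win32" elsewhere), so the checks only run for Mac users:

from sys import platform
print(platform)  # -> 'darwin' on macOS, 'linux' on Linux, 'win32' on Windows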
@@ -1020,12 +1040,22 @@
         queue_out=text_prompt_queue,
         setup_kwargs=vars(whisper_stt_handler_kwargs),
     )
-    lm = MLXLanguageModelHandler(
-        stop_event,
-        queue_in=text_prompt_queue,
-        queue_out=lm_response_queue,
-        setup_kwargs=vars(language_model_handler_kwargs),
-    )
+    if module_kwargs.llm == 'transformers':
+        lm = LanguageModelHandler(
+            stop_event,
+            queue_in=text_prompt_queue,
+            queue_out=lm_response_queue,
+            setup_kwargs=vars(language_model_handler_kwargs),
+        )
+    elif module_kwargs.llm == 'mlx-lm':
+        lm = MLXLanguageModelHandler(
+            stop_event,
+            queue_in=text_prompt_queue,
+            queue_out=lm_response_queue,
+            setup_kwargs=vars(language_model_handler_kwargs),
+        )
+    else:
+        raise ValueError("The LLM should be either transformers or mlx-lm")
     if module_kwargs.tts == 'parler':
         torch._inductor.config.fx_graph_cache = True
         # mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS
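A note on the design: the if/elif chain works fine for two backends, but the same dispatch can be expressed as a lookup table, which keeps the set of valid choices and the error message in one place as backends grow. A hedged sketch reusing the names from the diff (stop_event, the queues, and the handler classes are assumed to be in scope as in main()):

LLM_HANDLERS = {
    "transformers": LanguageModelHandler,
    "mlx-lm": MLXLanguageModelHandler,
}

try:
    handler_cls = LLM_HANDLERS[module_kwargs.llm]
except KeyError:
    raise ValueError(
        f"The LLM should be one of: {', '.join(sorted(LLM_HANDLERS))}"
    )
lm = handler_cls(
    stop_event,
    queue_in=text_prompt_queue,
    queue_out=lm_response_queue,
    setup_kwargs=vars(language_model_handler_kwargs),
)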