diff --git a/TTS/parler_handler.py b/TTS/parler_handler.py
index efeb5a84145b65a65bdb7826344a0370a680b111..06dc19b897c55bf5731ad83e444e1cdc26abd4e8 100644
--- a/TTS/parler_handler.py
+++ b/TTS/parler_handler.py
@@ -11,6 +11,9 @@ import librosa
 import logging
 from rich.console import Console
 from utils.utils import next_power_of_2
+from transformers.utils.import_utils import (
+    is_flash_attn_2_available,
+)
 
 torch._inductor.config.fx_graph_cache = True
 # mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS
@@ -24,6 +27,13 @@ logger = logging.getLogger(__name__)
 
 console = Console()
 
+if not is_flash_attn_2_available() and torch.cuda.is_available():
+    logger.warning(
+        """Parler TTS works best with flash attention 2, but it is not installed.
+        Since CUDA is available on this system, you can install flash attention 2 with `uv pip install flash-attn --no-build-isolation`."""
+    )
+
+
 class ParlerTTSHandler(BaseHandler):
     def setup(
         self,