From 5099c98865d0caec2c806c295831db00aa22cdd8 Mon Sep 17 00:00:00 2001 From: Andres Marafioti <andimarafioti@gmail.com> Date: Tue, 27 Aug 2024 11:52:30 +0200 Subject: [PATCH] add warning to install flash attn --- TTS/parler_handler.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/TTS/parler_handler.py b/TTS/parler_handler.py index efeb5a8..06dc19b 100644 --- a/TTS/parler_handler.py +++ b/TTS/parler_handler.py @@ -11,6 +11,9 @@ import librosa import logging from rich.console import Console from utils.utils import next_power_of_2 +from transformers.utils.import_utils import ( + is_flash_attn_2_available, +) torch._inductor.config.fx_graph_cache = True # mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS @@ -24,6 +27,13 @@ logger = logging.getLogger(__name__) console = Console() +if not is_flash_attn_2_available() and torch.cuda.is_available(): + logger.warn( + """Parler TTS works best with flash attention 2, but is not installed + Given that CUDA is available in this system, you can install flash attention 2 with `uv pip install flash-attn --no-build-isolation`""" + ) + + class ParlerTTSHandler(BaseHandler): def setup( self, -- GitLab