From 5099c98865d0caec2c806c295831db00aa22cdd8 Mon Sep 17 00:00:00 2001
From: Andres Marafioti <andimarafioti@gmail.com>
Date: Tue, 27 Aug 2024 11:52:30 +0200
Subject: [PATCH] add a warning to install flash-attn when it is missing

---
 TTS/parler_handler.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/TTS/parler_handler.py b/TTS/parler_handler.py
index efeb5a8..06dc19b 100644
--- a/TTS/parler_handler.py
+++ b/TTS/parler_handler.py
@@ -11,6 +11,9 @@ import librosa
 import logging
 from rich.console import Console
 from utils.utils import next_power_of_2
+from transformers.utils.import_utils import (
+    is_flash_attn_2_available,
+)
 
 torch._inductor.config.fx_graph_cache = True
 # mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS
@@ -24,6 +27,13 @@ logger = logging.getLogger(__name__)
 console = Console()
 
 
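+# Warn at import time if flash attention 2 is missing while CUDA is available.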
+if not is_flash_attn_2_available() and torch.cuda.is_available():
+    logger.warning(
+        """Parler TTS works best with flash attention 2, but it is not installed.
+        Since CUDA is available on this system, you can install flash attention 2 with `uv pip install flash-attn --no-build-isolation`."""
+    )
+
+
 class ParlerTTSHandler(BaseHandler):
     def setup(
         self,
-- 
GitLab