Commit 5099c988 authored by Andres Marafioti's avatar Andres Marafioti

add warning to install flash attn

parent 235e628a
@@ -11,6 +11,9 @@ import librosa
 import logging
 from rich.console import Console
 from utils.utils import next_power_of_2
+from transformers.utils.import_utils import (
+    is_flash_attn_2_available,
+)

 torch._inductor.config.fx_graph_cache = True
 # mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS
@@ -24,6 +27,13 @@ logger = logging.getLogger(__name__)

 console = Console()

+if not is_flash_attn_2_available() and torch.cuda.is_available():
+    logger.warn(
+        """Parler TTS works best with flash attention 2, but is not installed
+        Given that CUDA is available in this system, you can install flash attention 2 with `uv pip install flash-attn --no-build-isolation`"""
+    )
+
+
 class ParlerTTSHandler(BaseHandler):
     def setup(
         self,
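For reference, a minimal standalone sketch of the check this commit introduces, under the assumption that torch and transformers are installed; it uses the same is_flash_attn_2_available helper as the diff, and logging.warning instead of the deprecated Logger.warn:

import logging

import torch
from transformers.utils.import_utils import is_flash_attn_2_available

logger = logging.getLogger(__name__)

# Warn only when flash attention 2 is missing but a CUDA device is present,
# i.e. when installing it would actually help.
if not is_flash_attn_2_available() and torch.cuda.is_available():
    logger.warning(
        "Parler TTS works best with flash attention 2, but it is not installed. "
        "Since CUDA is available on this system, you can install it with "
        "`uv pip install flash-attn --no-build-isolation`."
    )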