Unverified commit a6f27a28, authored by Andrés Marafioti, committed by GitHub

Merge pull request #51 from huggingface/flash-attn-warning

add warning to install flash attn
@@ -11,6 +11,9 @@ import librosa
 import logging
 from rich.console import Console
 from utils.utils import next_power_of_2
+from transformers.utils.import_utils import (
+    is_flash_attn_2_available,
+)
 torch._inductor.config.fx_graph_cache = True
 # mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS
@@ -24,6 +27,13 @@ logger = logging.getLogger(__name__)
 console = Console()
+if not is_flash_attn_2_available() and torch.cuda.is_available():
+    logger.warn(
+        """Parler TTS works best with flash attention 2, but is not installed
+        Given that CUDA is available in this system, you can install flash attention 2 with `uv pip install flash-attn --no-build-isolation`"""
+    )
 class ParlerTTSHandler(BaseHandler):
     def setup(
         self,
...
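For context (not part of the pull request): a minimal sketch of how the `is_flash_attn_2_available()` check can also be used to choose an attention implementation when loading a model with transformers. The `AutoModel` usage and the checkpoint id below are illustrative assumptions, not code from this repository, which loads its own Parler TTS model class in the handler's `setup`.

import torch
from transformers import AutoModel
from transformers.utils.import_utils import is_flash_attn_2_available

# Prefer flash attention 2 when the package is installed; otherwise fall
# back to the default "eager" implementation.
attn_implementation = (
    "flash_attention_2" if is_flash_attn_2_available() else "eager"
)

# "your-tts-checkpoint" is a placeholder, not a checkpoint from this repo.
model = AutoModel.from_pretrained(
    "your-tts-checkpoint",
    torch_dtype=torch.float16,
    attn_implementation=attn_implementation,
)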