Skip to content
Snippets Groups Projects
Commit a26395c5 authored by Andres Marafioti's avatar Andres Marafioti
Browse files

linting and few fixes for chattts

parent c40bf05d
No related branches found
No related tags found
No related merge requests found
...@@ -68,7 +68,9 @@ class WhisperSTTHandler(BaseHandler): ...@@ -68,7 +68,9 @@ class WhisperSTTHandler(BaseHandler):
# one should warmup with a number of generated tokens above max tokens targeted for subsequent generation # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
# hence, having min_new_tokens < max_new_tokens in the future doesn't make sense # hence, having min_new_tokens < max_new_tokens in the future doesn't make sense
warmup_gen_kwargs = { warmup_gen_kwargs = {
"min_new_tokens": self.gen_kwargs["max_new_tokens"], # Yes, assign max_new_tokens to min_new_tokens "min_new_tokens": self.gen_kwargs[
"max_new_tokens"
], # Yes, assign max_new_tokens to min_new_tokens
"max_new_tokens": self.gen_kwargs["max_new_tokens"], "max_new_tokens": self.gen_kwargs["max_new_tokens"],
**self.gen_kwargs, **self.gen_kwargs,
} }
......
...@@ -18,15 +18,15 @@ class ChatTTSHandler(BaseHandler): ...@@ -18,15 +18,15 @@ class ChatTTSHandler(BaseHandler):
def setup( def setup(
self, self,
should_listen, should_listen,
device="mps", device="cuda",
gen_kwargs={}, # Unused gen_kwargs={}, # Unused
stream=True, stream=True,
chunk_size=512, chunk_size=512,
): ):
self.should_listen = should_listen self.should_listen = should_listen
self.device = device self.device = device
self.model = ChatTTS.Chat() self.model = ChatTTS.Chat()
self.model.load(compile=True) # Set to True for better performance self.model.load(compile=False) # Doesn't work for me with True
self.chunk_size = chunk_size self.chunk_size = chunk_size
self.stream = stream self.stream = stream
rnd_spk_emb = self.model.sample_random_speaker() rnd_spk_emb = self.model.sample_random_speaker()
...@@ -37,8 +37,7 @@ class ChatTTSHandler(BaseHandler): ...@@ -37,8 +37,7 @@ class ChatTTSHandler(BaseHandler):
def warmup(self): def warmup(self):
logger.info(f"Warming up {self.__class__.__name__}") logger.info(f"Warming up {self.__class__.__name__}")
_= self.model.infer("text") _ = self.model.infer("text")
def process(self, llm_sentence): def process(self, llm_sentence):
console.print(f"[green]ASSISTANT: {llm_sentence}") console.print(f"[green]ASSISTANT: {llm_sentence}")
...@@ -52,36 +51,32 @@ class ChatTTSHandler(BaseHandler): ...@@ -52,36 +51,32 @@ class ChatTTSHandler(BaseHandler):
time.time() - start time.time() - start
) # Removing this line makes it fail more often. I'm looking into it. ) # Removing this line makes it fail more often. I'm looking into it.
wavs_gen = self.model.infer(llm_sentence,params_infer_code=self.params_infer_code, stream=self.stream) wavs_gen = self.model.infer(
llm_sentence, params_infer_code=self.params_infer_code, stream=self.stream
)
if self.stream: if self.stream:
wavs = [np.array([])] wavs = [np.array([])]
for gen in wavs_gen: for gen in wavs_gen:
print('new chunk gen', len(gen[0]))
if len(gen[0]) == 0: if len(gen[0]) == 0:
self.should_listen.set() self.should_listen.set()
return return
audio_chunk = librosa.resample(gen[0], orig_sr=24000, target_sr=16000) audio_chunk = librosa.resample(gen[0], orig_sr=24000, target_sr=16000)
audio_chunk = (audio_chunk * 32768).astype(np.int16) audio_chunk = (audio_chunk * 32768).astype(np.int16)[0]
print('audio_chunk:', audio_chunk.shape)
while len(audio_chunk) > self.chunk_size: while len(audio_chunk) > self.chunk_size:
yield audio_chunk[:self.chunk_size] # 返回前 chunk_size 字节的数据 yield audio_chunk[: self.chunk_size] # 返回前 chunk_size 字节的数据
audio_chunk = audio_chunk[self.chunk_size:] # 移除已返回的数据 audio_chunk = audio_chunk[self.chunk_size :] # 移除已返回的数据
yield np.pad(audio_chunk, (0,self.chunk_size-len(audio_chunk))) yield np.pad(audio_chunk, (0, self.chunk_size - len(audio_chunk)))
else: else:
print('check result', wavs_gen)
wavs = wavs_gen wavs = wavs_gen
if len(wavs[0]) == 0: if len(wavs[0]) == 0:
self.should_listen.set() self.should_listen.set()
return return
audio_chunk = librosa.resample(wavs[0], orig_sr=24000, target_sr=16000) audio_chunk = librosa.resample(wavs[0], orig_sr=24000, target_sr=16000)
audio_chunk = (audio_chunk * 32768).astype(np.int16) audio_chunk = (audio_chunk * 32768).astype(np.int16)
print('audio_chunk:', audio_chunk.shape)
for i in range(0, len(audio_chunk), self.chunk_size): for i in range(0, len(audio_chunk), self.chunk_size):
yield np.pad( yield np.pad(
audio_chunk[i : i + self.chunk_size], audio_chunk[i : i + self.chunk_size],
(0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])), (0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])),
) )
self.should_listen.set() self.should_listen.set()
...@@ -5,14 +5,12 @@ from dataclasses import dataclass, field ...@@ -5,14 +5,12 @@ from dataclasses import dataclass, field
class ChatTTSHandlerArguments: class ChatTTSHandlerArguments:
chat_tts_stream: bool = field( chat_tts_stream: bool = field(
default=True, default=True,
metadata={ metadata={"help": "The tts mode is stream Default is 'stream'."},
"help": "The tts mode is stream Default is 'stream'."
},
) )
chat_tts_device: str = field( chat_tts_device: str = field(
default="mps", default="cuda",
metadata={ metadata={
"help": "The device to be used for speech synthesis. Default is 'mps'." "help": "The device to be used for speech synthesis. Default is 'cuda'."
}, },
) )
chat_tts_chunk_size: int = field( chat_tts_chunk_size: int = field(
......
...@@ -35,7 +35,7 @@ class ModuleArguments: ...@@ -35,7 +35,7 @@ class ModuleArguments:
tts: Optional[str] = field( tts: Optional[str] = field(
default="parler", default="parler",
metadata={ metadata={
"help": "The TTS to use. Either 'parler' or 'melo'. Default is 'parler'" "help": "The TTS to use. Either 'parler', 'melo', or 'chatTTS'. Default is 'parler'"
}, },
) )
log_level: str = field( log_level: str = field(
......
...@@ -3,6 +3,6 @@ parler_tts @ git+https://github.com/huggingface/parler-tts.git ...@@ -3,6 +3,6 @@ parler_tts @ git+https://github.com/huggingface/parler-tts.git
melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS # made a copy of MeloTTS to have compatible versions of transformers melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS # made a copy of MeloTTS to have compatible versions of transformers
torch==2.4.0 torch==2.4.0
sounddevice==0.5.0 sounddevice==0.5.0
ChatTTS ChatTTS>=0.1.1
funasr funasr>=1.1.6
modelscope modelscope>=1.17.1
\ No newline at end of file \ No newline at end of file
...@@ -5,6 +5,6 @@ torch==2.4.0 ...@@ -5,6 +5,6 @@ torch==2.4.0
sounddevice==0.5.0 sounddevice==0.5.0
lightning-whisper-mlx>=0.0.10 lightning-whisper-mlx>=0.0.10
mlx-lm>=0.14.0 mlx-lm>=0.14.0
ChatTTS ChatTTS>=0.1.1
funasr>=1.1.6 funasr>=1.1.6
modelscope>=1.17.1 modelscope>=1.17.1
...@@ -80,7 +80,7 @@ def main(): ...@@ -80,7 +80,7 @@ def main():
MLXLanguageModelHandlerArguments, MLXLanguageModelHandlerArguments,
ParlerTTSHandlerArguments, ParlerTTSHandlerArguments,
MeloTTSHandlerArguments, MeloTTSHandlerArguments,
ChatTTSHandlerArguments ChatTTSHandlerArguments,
) )
) )
...@@ -190,7 +190,7 @@ def main(): ...@@ -190,7 +190,7 @@ def main():
prepare_args(mlx_language_model_handler_kwargs, "mlx_lm") prepare_args(mlx_language_model_handler_kwargs, "mlx_lm")
prepare_args(parler_tts_handler_kwargs, "tts") prepare_args(parler_tts_handler_kwargs, "tts")
prepare_args(melo_tts_handler_kwargs, "melo") prepare_args(melo_tts_handler_kwargs, "melo")
prepare_args(chat_tts_handler_kwargs,"chat_tts") prepare_args(chat_tts_handler_kwargs, "chat_tts")
# 3. Build the pipeline # 3. Build the pipeline
stop_event = Event() stop_event = Event()
...@@ -319,9 +319,7 @@ def main(): ...@@ -319,9 +319,7 @@ def main():
try: try:
from TTS.chatTTS_handler import ChatTTSHandler from TTS.chatTTS_handler import ChatTTSHandler
except RuntimeError as e: except RuntimeError as e:
logger.error( logger.error("Error importing ChatTTSHandler")
"Error importing ChatTTSHandler"
)
raise e raise e
tts = ChatTTSHandler( tts = ChatTTSHandler(
stop_event, stop_event,
...@@ -331,7 +329,7 @@ def main(): ...@@ -331,7 +329,7 @@ def main():
setup_kwargs=vars(chat_tts_handler_kwargs), setup_kwargs=vars(chat_tts_handler_kwargs),
) )
else: else:
raise ValueError("The TTS should be either parler or melo") raise ValueError("The TTS should be either parler, melo or chatTTS")
# 4. Run the pipeline # 4. Run the pipeline
try: try:
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment