Skip to content
Snippets Groups Projects
Commit 45507cd9 authored by Robbie Trencheny's avatar Robbie Trencheny Committed by Pascal Vizeli
Browse files

TTS ID3 support (#5773)

* Add support for writing ID3 tags to the file for improved display in media players

* Lint and async fixes

* Use mutagen instead of taglib

* Fix tests

* Add fallback for album

* Requested changes

* move import

* Fix album name

* Change default options handling

* Move to member function / minor fix

* fix style

* fix lint

* change mutagen handling

* fix lint / add name to bytesio

* Update __init__.py

* Fix test, some cleanups

* Add mutagen exeption handling, fix tests

* fix mutagen taging
parent 063c0e8f
Branches
Tags
No related merge requests found
...@@ -12,6 +12,7 @@ import logging ...@@ -12,6 +12,7 @@ import logging
import mimetypes import mimetypes
import os import os
import re import re
import io
from aiohttp import web from aiohttp import web
import voluptuous as vol import voluptuous as vol
...@@ -30,6 +31,7 @@ import homeassistant.helpers.config_validation as cv ...@@ -30,6 +31,7 @@ import homeassistant.helpers.config_validation as cv
DOMAIN = 'tts' DOMAIN = 'tts'
DEPENDENCIES = ['http'] DEPENDENCIES = ['http']
REQUIREMENTS = ["mutagen==1.36.2"]
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
...@@ -255,6 +257,8 @@ class SpeechManager(object): ...@@ -255,6 +257,8 @@ class SpeechManager(object):
def async_register_engine(self, engine, provider, config): def async_register_engine(self, engine, provider, config):
"""Register a TTS provider.""" """Register a TTS provider."""
provider.hass = self.hass provider.hass = self.hass
if provider.name is None:
provider.name = engine
self.providers[engine] = provider self.providers[engine] = provider
@asyncio.coroutine @asyncio.coroutine
...@@ -276,6 +280,8 @@ class SpeechManager(object): ...@@ -276,6 +280,8 @@ class SpeechManager(object):
language)) language))
# options # options
if provider.default_options and options:
options = provider.default_options.copy().update(options)
options = options or provider.default_options options = options or provider.default_options
if options is not None: if options is not None:
invalid_opts = [opt_name for opt_name in options.keys() invalid_opts = [opt_name for opt_name in options.keys()
...@@ -296,7 +302,7 @@ class SpeechManager(object): ...@@ -296,7 +302,7 @@ class SpeechManager(object):
# is file store in file cache # is file store in file cache
elif use_cache and key in self.file_cache: elif use_cache and key in self.file_cache:
filename = self.file_cache[key] filename = self.file_cache[key]
self.hass.async_add_job(self.async_file_to_mem(key)) yield from self.async_file_to_mem(key)
# load speech from provider into memory # load speech from provider into memory
else: else:
filename = yield from self.async_get_tts_audio( filename = yield from self.async_get_tts_audio(
...@@ -323,6 +329,9 @@ class SpeechManager(object): ...@@ -323,6 +329,9 @@ class SpeechManager(object):
# create file infos # create file infos
filename = ("{}.{}".format(key, extension)).lower() filename = ("{}.{}".format(key, extension)).lower()
data = self.write_tags(
filename, data, provider, message, language, options)
# save to memory # save to memory
self._async_store_to_memcache(key, filename, data) self._async_store_to_memcache(key, filename, data)
...@@ -412,11 +421,43 @@ class SpeechManager(object): ...@@ -412,11 +421,43 @@ class SpeechManager(object):
content, _ = mimetypes.guess_type(filename) content, _ = mimetypes.guess_type(filename)
return (content, self.mem_cache[key][MEM_CACHE_VOICE]) return (content, self.mem_cache[key][MEM_CACHE_VOICE])
@staticmethod
def write_tags(filename, data, provider, message, language, options):
"""Write ID3 tags to file.
Async friendly.
"""
import mutagen
data_bytes = io.BytesIO(data)
data_bytes.name = filename
data_bytes.seek(0)
album = provider.name
artist = language
if options is not None:
if options.get('voice') is not None:
artist = options.get('voice')
try:
tts_file = mutagen.File(data_bytes, easy=True)
if tts_file is not None:
tts_file['artist'] = artist
tts_file['album'] = album
tts_file['title'] = message
tts_file.save(data_bytes)
except mutagen.MutagenError as err:
_LOGGER.error("ID3 tag error: %s", err)
return data_bytes.getvalue()
class Provider(object): class Provider(object):
"""Represent a single provider.""" """Represent a single provider."""
hass = None hass = None
name = None
@property @property
def default_language(self): def default_language(self):
......
...@@ -138,6 +138,7 @@ class AmazonPollyProvider(Provider): ...@@ -138,6 +138,7 @@ class AmazonPollyProvider(Provider):
self.supported_langs = supported_languages self.supported_langs = supported_languages
self.all_voices = all_voices self.all_voices = all_voices
self.default_voice = self.config.get(CONF_VOICE) self.default_voice = self.config.get(CONF_VOICE)
self.name = 'Amazon Polly'
@property @property
def supported_languages(self): def supported_languages(self):
......
...@@ -32,6 +32,7 @@ class DemoProvider(Provider): ...@@ -32,6 +32,7 @@ class DemoProvider(Provider):
def __init__(self, lang): def __init__(self, lang):
"""Initialize demo provider.""" """Initialize demo provider."""
self._lang = lang self._lang = lang
self.name = 'Demo'
@property @property
def default_language(self): def default_language(self):
......
...@@ -58,6 +58,7 @@ class GoogleProvider(Provider): ...@@ -58,6 +58,7 @@ class GoogleProvider(Provider):
"AppleWebKit/537.36 (KHTML, like Gecko) " "AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/47.0.2526.106 Safari/537.36") "Chrome/47.0.2526.106 Safari/537.36")
} }
self.name = 'Google'
@property @property
def default_language(self): def default_language(self):
......
...@@ -38,6 +38,7 @@ class PicoProvider(Provider): ...@@ -38,6 +38,7 @@ class PicoProvider(Provider):
def __init__(self, lang): def __init__(self, lang):
"""Initialize Pico TTS provider.""" """Initialize Pico TTS provider."""
self._lang = lang self._lang = lang
self.name = 'PicoTTS'
@property @property
def default_language(self): def default_language(self):
......
...@@ -95,6 +95,7 @@ class VoiceRSSProvider(Provider): ...@@ -95,6 +95,7 @@ class VoiceRSSProvider(Provider):
self.hass = hass self.hass = hass
self._extension = conf[CONF_CODEC] self._extension = conf[CONF_CODEC]
self._lang = conf[CONF_LANG] self._lang = conf[CONF_LANG]
self.name = 'VoiceRSS'
self._form_data = { self._form_data = {
'key': conf[CONF_API_KEY], 'key': conf[CONF_API_KEY],
......
...@@ -82,6 +82,7 @@ class YandexSpeechKitProvider(Provider): ...@@ -82,6 +82,7 @@ class YandexSpeechKitProvider(Provider):
self._language = conf.get(CONF_LANG) self._language = conf.get(CONF_LANG)
self._emotion = conf.get(CONF_EMOTION) self._emotion = conf.get(CONF_EMOTION)
self._speed = str(conf.get(CONF_SPEED)) self._speed = str(conf.get(CONF_SPEED))
self.name = 'YandexTTS'
@property @property
def default_language(self): def default_language(self):
......
...@@ -328,6 +328,9 @@ mficlient==0.3.0 ...@@ -328,6 +328,9 @@ mficlient==0.3.0
# homeassistant.components.sensor.miflora # homeassistant.components.sensor.miflora
miflora==0.1.15 miflora==0.1.15
# homeassistant.components.tts
mutagen==1.36.2
# homeassistant.components.sensor.usps # homeassistant.components.sensor.usps
myusps==1.0.2 myusps==1.0.2
......
...@@ -341,6 +341,10 @@ class TestTTS(object): ...@@ -341,6 +341,10 @@ class TestTTS(object):
assert len(calls) == 1 assert len(calls) == 1
req = requests.get(calls[0].data[ATTR_MEDIA_CONTENT_ID]) req = requests.get(calls[0].data[ATTR_MEDIA_CONTENT_ID])
_, demo_data = self.demo_provider.get_tts_audio("bla", 'en') _, demo_data = self.demo_provider.get_tts_audio("bla", 'en')
demo_data = tts.SpeechManager.write_tags(
"265944c108cbb00b2a621be5930513e03a0bb2cd_en_-_demo.mp3",
demo_data, self.demo_provider,
"I person is on front of your door.", 'en', None)
assert req.status_code == 200 assert req.status_code == 200
assert req.content == demo_data assert req.content == demo_data
...@@ -351,6 +355,7 @@ class TestTTS(object): ...@@ -351,6 +355,7 @@ class TestTTS(object):
config = { config = {
tts.DOMAIN: { tts.DOMAIN: {
'platform': 'demo', 'platform': 'demo',
'language': 'de',
} }
} }
...@@ -367,6 +372,10 @@ class TestTTS(object): ...@@ -367,6 +372,10 @@ class TestTTS(object):
assert len(calls) == 1 assert len(calls) == 1
req = requests.get(calls[0].data[ATTR_MEDIA_CONTENT_ID]) req = requests.get(calls[0].data[ATTR_MEDIA_CONTENT_ID])
_, demo_data = self.demo_provider.get_tts_audio("bla", "de") _, demo_data = self.demo_provider.get_tts_audio("bla", "de")
demo_data = tts.SpeechManager.write_tags(
"265944c108cbb00b2a621be5930513e03a0bb2cd_de_-_demo.mp3",
demo_data, self.demo_provider,
"I person is on front of your door.", 'de', None)
assert req.status_code == 200 assert req.status_code == 200
assert req.content == demo_data assert req.content == demo_data
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment