From 45507cd9d1a8f97d20a41b5c20ad0edc0b713f01 Mon Sep 17 00:00:00 2001
From: Robbie Trencheny <me@robbiet.us>
Date: Tue, 7 Feb 2017 03:07:11 -0800
Subject: [PATCH] TTS ID3 support (#5773)

* Add support for writing ID3 tags to the file for improved display in media players

* Lint and async fixes

* Use mutagen instead of taglib

* Fix tests

* Add fallback for album

* Requested changes

* move import

* Fix album name

* Change default options handling

* Move to member function / minor fix

* fix style

* fix lint

* change mutagen handling

* fix lint / add name to bytesio

* Update __init__.py

* Fix test, some cleanups

* Add mutagen exeption handling, fix tests

* fix mutagen taging
---
 homeassistant/components/tts/__init__.py     | 43 +++++++++++++++++++-
 homeassistant/components/tts/amazon_polly.py |  1 +
 homeassistant/components/tts/demo.py         |  1 +
 homeassistant/components/tts/google.py       |  1 +
 homeassistant/components/tts/picotts.py      |  1 +
 homeassistant/components/tts/voicerss.py     |  1 +
 homeassistant/components/tts/yandextts.py    |  1 +
 requirements_all.txt                         |  3 ++
 tests/components/tts/test_init.py            |  9 ++++
 9 files changed, 60 insertions(+), 1 deletion(-)

diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py
index 9b4df2749c0..fb99219416f 100644
--- a/homeassistant/components/tts/__init__.py
+++ b/homeassistant/components/tts/__init__.py
@@ -12,6 +12,7 @@ import logging
 import mimetypes
 import os
 import re
+import io
 
 from aiohttp import web
 import voluptuous as vol
@@ -30,6 +31,7 @@ import homeassistant.helpers.config_validation as cv
 
 DOMAIN = 'tts'
 DEPENDENCIES = ['http']
+REQUIREMENTS = ["mutagen==1.36.2"]
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -255,6 +257,8 @@ class SpeechManager(object):
     def async_register_engine(self, engine, provider, config):
         """Register a TTS provider."""
         provider.hass = self.hass
+        if provider.name is None:
+            provider.name = engine
         self.providers[engine] = provider
 
     @asyncio.coroutine
@@ -276,6 +280,8 @@ class SpeechManager(object):
                 language))
 
         # options
+        if provider.default_options and options:
+            options = provider.default_options.copy().update(options)
         options = options or provider.default_options
         if options is not None:
             invalid_opts = [opt_name for opt_name in options.keys()
@@ -296,7 +302,7 @@ class SpeechManager(object):
         # is file store in file cache
         elif use_cache and key in self.file_cache:
             filename = self.file_cache[key]
-            self.hass.async_add_job(self.async_file_to_mem(key))
+            yield from self.async_file_to_mem(key)
         # load speech from provider into memory
         else:
             filename = yield from self.async_get_tts_audio(
@@ -323,6 +329,9 @@ class SpeechManager(object):
         # create file infos
         filename = ("{}.{}".format(key, extension)).lower()
 
+        data = self.write_tags(
+            filename, data, provider, message, language, options)
+
         # save to memory
         self._async_store_to_memcache(key, filename, data)
 
@@ -412,11 +421,43 @@ class SpeechManager(object):
         content, _ = mimetypes.guess_type(filename)
         return (content, self.mem_cache[key][MEM_CACHE_VOICE])
 
+    @staticmethod
+    def write_tags(filename, data, provider, message, language, options):
+        """Write ID3 tags to file.
+
+        Async friendly.
+        """
+        import mutagen
+
+        data_bytes = io.BytesIO(data)
+        data_bytes.name = filename
+        data_bytes.seek(0)
+
+        album = provider.name
+        artist = language
+
+        if options is not None:
+            if options.get('voice') is not None:
+                artist = options.get('voice')
+
+        try:
+            tts_file = mutagen.File(data_bytes, easy=True)
+            if tts_file is not None:
+                tts_file['artist'] = artist
+                tts_file['album'] = album
+                tts_file['title'] = message
+                tts_file.save(data_bytes)
+        except mutagen.MutagenError as err:
+            _LOGGER.error("ID3 tag error: %s", err)
+
+        return data_bytes.getvalue()
+
 
 class Provider(object):
     """Represent a single provider."""
 
     hass = None
+    name = None
 
     @property
     def default_language(self):
diff --git a/homeassistant/components/tts/amazon_polly.py b/homeassistant/components/tts/amazon_polly.py
index e40c10f5e14..7dab49482ed 100644
--- a/homeassistant/components/tts/amazon_polly.py
+++ b/homeassistant/components/tts/amazon_polly.py
@@ -138,6 +138,7 @@ class AmazonPollyProvider(Provider):
         self.supported_langs = supported_languages
         self.all_voices = all_voices
         self.default_voice = self.config.get(CONF_VOICE)
+        self.name = 'Amazon Polly'
 
     @property
     def supported_languages(self):
diff --git a/homeassistant/components/tts/demo.py b/homeassistant/components/tts/demo.py
index 95362b49db9..d9d1eccec8d 100644
--- a/homeassistant/components/tts/demo.py
+++ b/homeassistant/components/tts/demo.py
@@ -32,6 +32,7 @@ class DemoProvider(Provider):
     def __init__(self, lang):
         """Initialize demo provider."""
         self._lang = lang
+        self.name = 'Demo'
 
     @property
     def default_language(self):
diff --git a/homeassistant/components/tts/google.py b/homeassistant/components/tts/google.py
index 32c9663eedc..be84e0e029b 100644
--- a/homeassistant/components/tts/google.py
+++ b/homeassistant/components/tts/google.py
@@ -58,6 +58,7 @@ class GoogleProvider(Provider):
                            "AppleWebKit/537.36 (KHTML, like Gecko) "
                            "Chrome/47.0.2526.106 Safari/537.36")
         }
+        self.name = 'Google'
 
     @property
     def default_language(self):
diff --git a/homeassistant/components/tts/picotts.py b/homeassistant/components/tts/picotts.py
index 49addd9b177..a22196cfbe0 100644
--- a/homeassistant/components/tts/picotts.py
+++ b/homeassistant/components/tts/picotts.py
@@ -38,6 +38,7 @@ class PicoProvider(Provider):
     def __init__(self, lang):
         """Initialize Pico TTS provider."""
         self._lang = lang
+        self.name = 'PicoTTS'
 
     @property
     def default_language(self):
diff --git a/homeassistant/components/tts/voicerss.py b/homeassistant/components/tts/voicerss.py
index b0c74d1de30..ee50cc30cca 100644
--- a/homeassistant/components/tts/voicerss.py
+++ b/homeassistant/components/tts/voicerss.py
@@ -95,6 +95,7 @@ class VoiceRSSProvider(Provider):
         self.hass = hass
         self._extension = conf[CONF_CODEC]
         self._lang = conf[CONF_LANG]
+        self.name = 'VoiceRSS'
 
         self._form_data = {
             'key': conf[CONF_API_KEY],
diff --git a/homeassistant/components/tts/yandextts.py b/homeassistant/components/tts/yandextts.py
index 824ca6ca38f..b60f9cab61e 100644
--- a/homeassistant/components/tts/yandextts.py
+++ b/homeassistant/components/tts/yandextts.py
@@ -82,6 +82,7 @@ class YandexSpeechKitProvider(Provider):
         self._language = conf.get(CONF_LANG)
         self._emotion = conf.get(CONF_EMOTION)
         self._speed = str(conf.get(CONF_SPEED))
+        self.name = 'YandexTTS'
 
     @property
     def default_language(self):
diff --git a/requirements_all.txt b/requirements_all.txt
index e3e06dcebe8..d8d598dc17f 100755
--- a/requirements_all.txt
+++ b/requirements_all.txt
@@ -328,6 +328,9 @@ mficlient==0.3.0
 # homeassistant.components.sensor.miflora
 miflora==0.1.15
 
+# homeassistant.components.tts
+mutagen==1.36.2
+
 # homeassistant.components.sensor.usps
 myusps==1.0.2
 
diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py
index f7985b8af74..0db7c1a5bef 100644
--- a/tests/components/tts/test_init.py
+++ b/tests/components/tts/test_init.py
@@ -341,6 +341,10 @@ class TestTTS(object):
         assert len(calls) == 1
         req = requests.get(calls[0].data[ATTR_MEDIA_CONTENT_ID])
         _, demo_data = self.demo_provider.get_tts_audio("bla", 'en')
+        demo_data = tts.SpeechManager.write_tags(
+            "265944c108cbb00b2a621be5930513e03a0bb2cd_en_-_demo.mp3",
+            demo_data, self.demo_provider,
+            "I person is on front of your door.", 'en', None)
         assert req.status_code == 200
         assert req.content == demo_data
 
@@ -351,6 +355,7 @@ class TestTTS(object):
         config = {
             tts.DOMAIN: {
                 'platform': 'demo',
+                'language': 'de',
             }
         }
 
@@ -367,6 +372,10 @@ class TestTTS(object):
         assert len(calls) == 1
         req = requests.get(calls[0].data[ATTR_MEDIA_CONTENT_ID])
         _, demo_data = self.demo_provider.get_tts_audio("bla", "de")
+        demo_data = tts.SpeechManager.write_tags(
+            "265944c108cbb00b2a621be5930513e03a0bb2cd_de_-_demo.mp3",
+            demo_data, self.demo_provider,
+            "I person is on front of your door.", 'de', None)
         assert req.status_code == 200
         assert req.content == demo_data
 
-- 
GitLab