From d84cd01cbfe4543490eca9c37db02aa61f017143 Mon Sep 17 00:00:00 2001
From: emontnemery <erik@montnemery.com>
Date: Fri, 25 Jan 2019 14:40:52 +0800
Subject: [PATCH] Cleanup if discovered mqtt light can't be added (#19740)

* Cleanup if discovered mqtt light can't be added

* No bare except

* Clear ALREADY_DISCOVERED list with helper

* Use constant instead of string literal
---
 .../components/mqtt/light/__init__.py         | 23 ++++++----
 .../components/mqtt/light/schema_basic.py     |  2 +-
 .../components/mqtt/light/schema_json.py      |  4 +-
 .../components/mqtt/light/schema_template.py  |  2 +-
 tests/components/mqtt/test_light.py           | 35 +++++++++++++-
 tests/components/mqtt/test_light_json.py      | 36 ++++++++++++++-
 tests/components/mqtt/test_light_template.py  | 46 ++++++++++++++++++-
 7 files changed, 132 insertions(+), 16 deletions(-)

diff --git a/homeassistant/components/mqtt/light/__init__.py b/homeassistant/components/mqtt/light/__init__.py
index 93f32cd2791..77a1b1d3c10 100644
--- a/homeassistant/components/mqtt/light/__init__.py
+++ b/homeassistant/components/mqtt/light/__init__.py
@@ -10,7 +10,8 @@ import voluptuous as vol
 
 from homeassistant.components import light
 from homeassistant.components.mqtt import ATTR_DISCOVERY_HASH
-from homeassistant.components.mqtt.discovery import MQTT_DISCOVERY_NEW
+from homeassistant.components.mqtt.discovery import (
+    MQTT_DISCOVERY_NEW, clear_discovery_hash)
 from homeassistant.helpers.dispatcher import async_dispatcher_connect
 from homeassistant.helpers.typing import HomeAssistantType, ConfigType
 
@@ -44,23 +45,29 @@ PLATFORM_SCHEMA = vol.All(vol.Schema({
 async def async_setup_platform(hass: HomeAssistantType, config: ConfigType,
                                async_add_entities, discovery_info=None):
     """Set up MQTT light through configuration.yaml."""
-    await _async_setup_entity(hass, config, async_add_entities)
+    await _async_setup_entity(config, async_add_entities)
 
 
 async def async_setup_entry(hass, config_entry, async_add_entities):
     """Set up MQTT light dynamically through MQTT discovery."""
     async def async_discover(discovery_payload):
         """Discover and add a MQTT light."""
-        config = PLATFORM_SCHEMA(discovery_payload)
-        await _async_setup_entity(hass, config, async_add_entities,
-                                  discovery_payload[ATTR_DISCOVERY_HASH])
+        try:
+            discovery_hash = discovery_payload[ATTR_DISCOVERY_HASH]
+            config = PLATFORM_SCHEMA(discovery_payload)
+            await _async_setup_entity(config, async_add_entities,
+                                      discovery_hash)
+        except Exception:
+            if discovery_hash:
+                clear_discovery_hash(hass, discovery_hash)
+            raise
 
     async_dispatcher_connect(
         hass, MQTT_DISCOVERY_NEW.format(light.DOMAIN, 'mqtt'),
         async_discover)
 
 
-async def _async_setup_entity(hass, config, async_add_entities,
+async def _async_setup_entity(config, async_add_entities,
                               discovery_hash=None):
     """Set up a MQTT Light."""
     setup_entity = {
@@ -68,5 +75,5 @@ async def _async_setup_entity(hass, config, async_add_entities,
         'json': schema_json.async_setup_entity_json,
         'template': schema_template.async_setup_entity_template,
     }
-    await setup_entity[config['schema']](
-        hass, config, async_add_entities, discovery_hash)
+    await setup_entity[config[CONF_SCHEMA]](
+        config, async_add_entities, discovery_hash)
diff --git a/homeassistant/components/mqtt/light/schema_basic.py b/homeassistant/components/mqtt/light/schema_basic.py
index 3be8de5c722..d9f676c8b38 100644
--- a/homeassistant/components/mqtt/light/schema_basic.py
+++ b/homeassistant/components/mqtt/light/schema_basic.py
@@ -112,7 +112,7 @@ PLATFORM_SCHEMA_BASIC = mqtt.MQTT_RW_PLATFORM_SCHEMA.extend({
     mqtt.MQTT_JSON_ATTRS_SCHEMA.schema)
 
 
-async def async_setup_entity_basic(hass, config, async_add_entities,
+async def async_setup_entity_basic(config, async_add_entities,
                                    discovery_hash=None):
     """Set up a MQTT Light."""
     config.setdefault(
diff --git a/homeassistant/components/mqtt/light/schema_json.py b/homeassistant/components/mqtt/light/schema_json.py
index 1c32b0c5783..fcf31f097cc 100644
--- a/homeassistant/components/mqtt/light/schema_json.py
+++ b/homeassistant/components/mqtt/light/schema_json.py
@@ -25,7 +25,7 @@ from homeassistant.const import (
 from homeassistant.core import callback
 import homeassistant.helpers.config_validation as cv
 from homeassistant.helpers.restore_state import RestoreEntity
-from homeassistant.helpers.typing import ConfigType, HomeAssistantType
+from homeassistant.helpers.typing import ConfigType
 import homeassistant.util.color as color_util
 
 from .schema_basic import CONF_BRIGHTNESS_SCALE
@@ -85,7 +85,7 @@ PLATFORM_SCHEMA_JSON = mqtt.MQTT_RW_PLATFORM_SCHEMA.extend({
     mqtt.MQTT_JSON_ATTRS_SCHEMA.schema)
 
 
-async def async_setup_entity_json(hass: HomeAssistantType, config: ConfigType,
+async def async_setup_entity_json(config: ConfigType,
                                   async_add_entities, discovery_hash):
     """Set up a MQTT JSON Light."""
     async_add_entities([MqttLightJson(config, discovery_hash)])
diff --git a/homeassistant/components/mqtt/light/schema_template.py b/homeassistant/components/mqtt/light/schema_template.py
index 7020550710b..09aaa359058 100644
--- a/homeassistant/components/mqtt/light/schema_template.py
+++ b/homeassistant/components/mqtt/light/schema_template.py
@@ -71,7 +71,7 @@ PLATFORM_SCHEMA_TEMPLATE = mqtt.MQTT_RW_PLATFORM_SCHEMA.extend({
     mqtt.MQTT_JSON_ATTRS_SCHEMA.schema)
 
 
-async def async_setup_entity_template(hass, config, async_add_entities,
+async def async_setup_entity_template(config, async_add_entities,
                                       discovery_hash):
     """Set up a MQTT Template light."""
     async_add_entities([MqttTemplate(config, discovery_hash)])
diff --git a/tests/components/mqtt/test_light.py b/tests/components/mqtt/test_light.py
index a424263af8c..1b1ba3862e9 100644
--- a/tests/components/mqtt/test_light.py
+++ b/tests/components/mqtt/test_light.py
@@ -1239,7 +1239,7 @@ async def test_discovery_deprecated(hass, mqtt_mock, caplog):
 
 
 async def test_discovery_update_light(hass, mqtt_mock, caplog):
-    """Test removal of discovered light."""
+    """Test update of discovered light."""
     entry = MockConfigEntry(domain=mqtt.DOMAIN)
     await async_start(hass, 'homeassistant', {}, entry)
 
@@ -1274,6 +1274,39 @@ async def test_discovery_update_light(hass, mqtt_mock, caplog):
     assert state is None
 
 
+async def test_discovery_broken(hass, mqtt_mock, caplog):
+    """Test handling of bad discovery message."""
+    entry = MockConfigEntry(domain=mqtt.DOMAIN)
+    await async_start(hass, 'homeassistant', {}, entry)
+
+    data1 = (
+        '{ "name": "Beer" }'
+    )
+    data2 = (
+        '{ "name": "Milk",'
+        '  "status_topic": "test_topic",'
+        '  "command_topic": "test_topic" }'
+    )
+
+    async_fire_mqtt_message(hass, 'homeassistant/light/bla/config',
+                            data1)
+    await hass.async_block_till_done()
+
+    state = hass.states.get('light.beer')
+    assert state is None
+
+    async_fire_mqtt_message(hass, 'homeassistant/light/bla/config',
+                            data2)
+    await hass.async_block_till_done()
+    await hass.async_block_till_done()
+
+    state = hass.states.get('light.milk')
+    assert state is not None
+    assert state.name == 'Milk'
+    state = hass.states.get('light.beer')
+    assert state is None
+
+
 async def test_entity_device_info_with_identifier(hass, mqtt_mock):
     """Test MQTT light device registry integration."""
     entry = MockConfigEntry(domain=mqtt.DOMAIN)
diff --git a/tests/components/mqtt/test_light_json.py b/tests/components/mqtt/test_light_json.py
index 7621da724c9..c8d7c1d3e54 100644
--- a/tests/components/mqtt/test_light_json.py
+++ b/tests/components/mqtt/test_light_json.py
@@ -707,7 +707,7 @@ async def test_discovery_deprecated(hass, mqtt_mock, caplog):
 
 
 async def test_discovery_update_light(hass, mqtt_mock, caplog):
-    """Test removal of discovered light."""
+    """Test update of discovered light."""
     entry = MockConfigEntry(domain=mqtt.DOMAIN)
     await async_start(hass, 'homeassistant', {}, entry)
 
@@ -744,6 +744,40 @@ async def test_discovery_update_light(hass, mqtt_mock, caplog):
     assert state is None
 
 
+async def test_discovery_broken(hass, mqtt_mock, caplog):
+    """Test handling of bad discovery message."""
+    entry = MockConfigEntry(domain=mqtt.DOMAIN)
+    await async_start(hass, 'homeassistant', {}, entry)
+
+    data1 = (
+        '{ "name": "Beer" }'
+    )
+    data2 = (
+        '{ "name": "Milk",'
+        '  "schema": "json",'
+        '  "status_topic": "test_topic",'
+        '  "command_topic": "test_topic" }'
+    )
+
+    async_fire_mqtt_message(hass, 'homeassistant/light/bla/config',
+                            data1)
+    await hass.async_block_till_done()
+
+    state = hass.states.get('light.beer')
+    assert state is None
+
+    async_fire_mqtt_message(hass, 'homeassistant/light/bla/config',
+                            data2)
+    await hass.async_block_till_done()
+    await hass.async_block_till_done()
+
+    state = hass.states.get('light.milk')
+    assert state is not None
+    assert state.name == 'Milk'
+    state = hass.states.get('light.beer')
+    assert state is None
+
+
 async def test_entity_device_info_with_identifier(hass, mqtt_mock):
     """Test MQTT light device registry integration."""
     entry = MockConfigEntry(domain=mqtt.DOMAIN)
diff --git a/tests/components/mqtt/test_light_template.py b/tests/components/mqtt/test_light_template.py
index 509f2ee5d36..13fe086684c 100644
--- a/tests/components/mqtt/test_light_template.py
+++ b/tests/components/mqtt/test_light_template.py
@@ -627,7 +627,7 @@ async def test_unique_id(hass):
     assert len(hass.states.async_entity_ids(light.DOMAIN)) == 1
 
 
-async def test_discovery(hass, mqtt_mock, caplog):
+async def test_discovery_removal(hass, mqtt_mock, caplog):
     """Test removal of discovered mqtt_json lights."""
     entry = MockConfigEntry(domain=mqtt.DOMAIN)
     await async_start(hass, 'homeassistant', {'mqtt': {}}, entry)
@@ -644,6 +644,12 @@ async def test_discovery(hass, mqtt_mock, caplog):
     state = hass.states.get('light.beer')
     assert state is not None
     assert state.name == 'Beer'
+    async_fire_mqtt_message(hass, 'homeassistant/light/bla/config',
+                            '')
+    await hass.async_block_till_done()
+    await hass.async_block_till_done()
+    state = hass.states.get('light.beer')
+    assert state is None
 
 
 async def test_discovery_deprecated(hass, mqtt_mock, caplog):
@@ -666,7 +672,7 @@ async def test_discovery_deprecated(hass, mqtt_mock, caplog):
 
 
 async def test_discovery_update_light(hass, mqtt_mock, caplog):
-    """Test removal of discovered light."""
+    """Test update of discovered light."""
     entry = MockConfigEntry(domain=mqtt.DOMAIN)
     await async_start(hass, 'homeassistant', {}, entry)
 
@@ -707,6 +713,42 @@ async def test_discovery_update_light(hass, mqtt_mock, caplog):
     assert state is None
 
 
+async def test_discovery_broken(hass, mqtt_mock, caplog):
+    """Test handling of bad discovery message."""
+    entry = MockConfigEntry(domain=mqtt.DOMAIN)
+    await async_start(hass, 'homeassistant', {}, entry)
+
+    data1 = (
+        '{ "name": "Beer" }'
+    )
+    data2 = (
+        '{ "name": "Milk",'
+        '  "schema": "template",'
+        '  "status_topic": "test_topic",'
+        '  "command_topic": "test_topic",'
+        '  "command_on_template": "on",'
+        '  "command_off_template": "off"}'
+    )
+
+    async_fire_mqtt_message(hass, 'homeassistant/light/bla/config',
+                            data1)
+    await hass.async_block_till_done()
+
+    state = hass.states.get('light.beer')
+    assert state is None
+
+    async_fire_mqtt_message(hass, 'homeassistant/light/bla/config',
+                            data2)
+    await hass.async_block_till_done()
+    await hass.async_block_till_done()
+
+    state = hass.states.get('light.milk')
+    assert state is not None
+    assert state.name == 'Milk'
+    state = hass.states.get('light.beer')
+    assert state is None
+
+
 async def test_entity_device_info_with_identifier(hass, mqtt_mock):
     """Test MQTT light device registry integration."""
     entry = MockConfigEntry(domain=mqtt.DOMAIN)
-- 
GitLab