From 8800b83283304eb657225a529d8b5b5ef8f58d3c Mon Sep 17 00:00:00 2001
From: Erik Montnemery <erik@montnemery.com>
Date: Tue, 27 Oct 2020 23:24:54 +0100
Subject: [PATCH] Fix race in Tasmota discovery (#42492)

---
 homeassistant/components/tasmota/__init__.py  | 29 ++++++++-----------
 homeassistant/components/tasmota/discovery.py |  7 ++---
 .../components/tasmota/test_device_trigger.py |  2 +-
 3 files changed, 15 insertions(+), 23 deletions(-)

diff --git a/homeassistant/components/tasmota/__init__.py b/homeassistant/components/tasmota/__init__.py
index 5754b98b71e..a82d95474cc 100644
--- a/homeassistant/components/tasmota/__init__.py
+++ b/homeassistant/components/tasmota/__init__.py
@@ -19,12 +19,10 @@ from homeassistant.components.mqtt.subscription import (
 )
 from homeassistant.core import callback
 from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC
-from homeassistant.helpers.dispatcher import async_dispatcher_connect
 from homeassistant.helpers.typing import HomeAssistantType
 
 from . import device_automation, discovery
 from .const import CONF_DISCOVERY_PREFIX, DATA_REMOVE_DISCOVER_COMPONENT, PLATFORMS
-from .discovery import TASMOTA_DISCOVERY_DEVICE
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -55,13 +53,11 @@ async def async_setup_entry(hass, entry):
 
     tasmota_mqtt = TasmotaMQTTClient(_publish, _subscribe_topics, _unsubscribe_topics)
 
-    async def async_discover_device(config, mac):
-        """Discover and add a Tasmota device."""
-        await async_setup_device(hass, mac, config, entry, tasmota_mqtt)
+    device_registry = await hass.helpers.device_registry.async_get_registry()
 
-    hass.data[
-        DATA_REMOVE_DISCOVER_COMPONENT.format("device")
-    ] = async_dispatcher_connect(hass, TASMOTA_DISCOVERY_DEVICE, async_discover_device)
+    def async_discover_device(config, mac):
+        """Discover and add a Tasmota device."""
+        async_setup_device(hass, mac, config, entry, tasmota_mqtt, device_registry)
 
     async def start_platforms():
         await device_automation.async_setup_entry(hass, entry)
@@ -73,7 +69,9 @@ async def async_setup_entry(hass, entry):
         )
 
         discovery_prefix = entry.data[CONF_DISCOVERY_PREFIX]
-        await discovery.async_start(hass, discovery_prefix, entry, tasmota_mqtt)
+        await discovery.async_start(
+            hass, discovery_prefix, entry, tasmota_mqtt, async_discover_device
+        )
 
     hass.async_create_task(start_platforms())
     return True
@@ -97,7 +95,6 @@ async def async_unload_entry(hass, entry):
     # disable discovery
     await discovery.async_stop(hass)
     hass.data.pop(DEVICE_MACS)
-    hass.data[DATA_REMOVE_DISCOVER_COMPONENT.format("device")]()
     hass.data.pop(DATA_REMOVE_DISCOVER_COMPONENT.format("device_automation"))()
     for component in PLATFORMS:
         hass.data.pop(DATA_REMOVE_DISCOVER_COMPONENT.format(component))()
@@ -105,9 +102,8 @@ async def async_unload_entry(hass, entry):
     return True
 
 
-async def _remove_device(hass, config_entry, mac, tasmota_mqtt):
+def _remove_device(hass, config_entry, mac, tasmota_mqtt, device_registry):
     """Remove device from device registry."""
-    device_registry = await hass.helpers.device_registry.async_get_registry()
     device = device_registry.async_get_device(set(), {(CONNECTION_NETWORK_MAC, mac)})
 
     if device is None:
@@ -118,9 +114,8 @@ async def _remove_device(hass, config_entry, mac, tasmota_mqtt):
     clear_discovery_topic(mac, config_entry.data[CONF_DISCOVERY_PREFIX], tasmota_mqtt)
 
 
-async def _update_device(hass, config_entry, config):
+def _update_device(hass, config_entry, config, device_registry):
     """Add or update device registry."""
-    device_registry = await hass.helpers.device_registry.async_get_registry()
     config_entry_id = config_entry.entry_id
     device_info = {
         "connections": {(CONNECTION_NETWORK_MAC, config[CONF_MAC])},
@@ -135,9 +130,9 @@ async def _update_device(hass, config_entry, config):
     hass.data[DEVICE_MACS][device.id] = config[CONF_MAC]
 
 
-async def async_setup_device(hass, mac, config, config_entry, tasmota_mqtt):
+def async_setup_device(hass, mac, config, config_entry, tasmota_mqtt, device_registry):
     """Set up the Tasmota device."""
     if not config:
-        await _remove_device(hass, config_entry, mac, tasmota_mqtt)
+        _remove_device(hass, config_entry, mac, tasmota_mqtt, device_registry)
     else:
-        await _update_device(hass, config_entry, config)
+        _update_device(hass, config_entry, config, device_registry)
diff --git a/homeassistant/components/tasmota/discovery.py b/homeassistant/components/tasmota/discovery.py
index 06a88333230..2313a8327c5 100644
--- a/homeassistant/components/tasmota/discovery.py
+++ b/homeassistant/components/tasmota/discovery.py
@@ -21,7 +21,6 @@ from .const import DOMAIN, PLATFORMS
 _LOGGER = logging.getLogger(__name__)
 
 ALREADY_DISCOVERED = "tasmota_discovered_components"
-TASMOTA_DISCOVERY_DEVICE = "tasmota_discovery_device"
 TASMOTA_DISCOVERY_ENTITY_NEW = "tasmota_discovery_entity_new_{}"
 TASMOTA_DISCOVERY_ENTITY_UPDATED = "tasmota_discovery_entity_updated_{}_{}_{}_{}"
 TASMOTA_DISCOVERY_INSTANCE = "tasmota_discovery_instance"
@@ -41,7 +40,7 @@ def set_discovery_hash(hass, discovery_hash):
 
 
 async def async_start(
-    hass: HomeAssistantType, discovery_topic, config_entry, tasmota_mqtt
+    hass: HomeAssistantType, discovery_topic, config_entry, tasmota_mqtt, setup_device
 ) -> bool:
     """Start Tasmota device discovery."""
 
@@ -95,9 +94,7 @@ async def async_start(
 
         _LOGGER.debug("Received discovery data for tasmota device: %s", mac)
         tasmota_device_config = tasmota_get_device_config(payload)
-        async_dispatcher_send(
-            hass, TASMOTA_DISCOVERY_DEVICE, tasmota_device_config, mac
-        )
+        setup_device(tasmota_device_config, mac)
 
         if not payload:
             return
diff --git a/tests/components/tasmota/test_device_trigger.py b/tests/components/tasmota/test_device_trigger.py
index 35f3b4be9d8..b027b2c095d 100644
--- a/tests/components/tasmota/test_device_trigger.py
+++ b/tests/components/tasmota/test_device_trigger.py
@@ -832,7 +832,7 @@ async def test_attach_unknown_remove_device_from_registry(
     await hass.async_block_till_done()
 
 
-async def test_attach_remove_config_entry(hass, mqtt_mock, setup_tasmota, device_reg):
+async def test_attach_remove_config_entry(hass, device_reg, mqtt_mock, setup_tasmota):
     """Test trigger cleanup when removing a Tasmota config entry."""
     # Discover a device with device trigger
     config = copy.deepcopy(DEFAULT_CONFIG)
-- 
GitLab