From b8257866f5d7aaa56f5ed4491967c3bb8f395d59 Mon Sep 17 00:00:00 2001
From: Paulus Schoutsen <paulus@home-assistant.io>
Date: Mon, 17 Sep 2018 13:39:30 +0200
Subject: [PATCH] Clean up device update, add via-hub (#16659)

* Clean up device update, add via-hub

* Test loading/saving data

* Lint

* Add to Hue"

* Lint + tests
---
 .../components/config/device_registry.py      |   1 +
 homeassistant/components/deconz/__init__.py   |   2 +-
 homeassistant/components/hue/__init__.py      |   2 +-
 homeassistant/components/light/hue.py         |   1 +
 homeassistant/helpers/device_registry.py      | 109 +++++++++++----
 homeassistant/helpers/entity_platform.py      |  13 +-
 homeassistant/helpers/entity_registry.py      |  14 +-
 tests/common.py                               |   5 +
 .../components/config/test_device_registry.py |  16 ++-
 tests/components/hue/test_init.py             |   4 +-
 tests/helpers/test_device_registry.py         | 125 +++++++++++++++---
 tests/helpers/test_entity_platform.py         |  52 ++++++++
 tests/helpers/test_entity_registry.py         |   7 +-
 13 files changed, 271 insertions(+), 80 deletions(-)

diff --git a/homeassistant/components/config/device_registry.py b/homeassistant/components/config/device_registry.py
index 8383e0cdc7d..88aa5727a97 100644
--- a/homeassistant/components/config/device_registry.py
+++ b/homeassistant/components/config/device_registry.py
@@ -40,6 +40,7 @@ def websocket_list_devices(hass, connection, msg):
                 'name': entry.name,
                 'sw_version': entry.sw_version,
                 'id': entry.id,
+                'hub_device_id': entry.hub_device_id,
             } for entry in registry.devices.values()]
         ))
 
diff --git a/homeassistant/components/deconz/__init__.py b/homeassistant/components/deconz/__init__.py
index 6ed0a6e2c11..82f4233a7da 100644
--- a/homeassistant/components/deconz/__init__.py
+++ b/homeassistant/components/deconz/__init__.py
@@ -127,7 +127,7 @@ async def async_setup_entry(hass, config_entry):
     device_registry = await \
         hass.helpers.device_registry.async_get_registry()
     device_registry.async_get_or_create(
-        config_entry=config_entry.entry_id,
+        config_entry_id=config_entry.entry_id,
         connections={(CONNECTION_NETWORK_MAC, deconz.config.mac)},
         identifiers={(DOMAIN, deconz.config.bridgeid)},
         manufacturer='Dresden Elektronik', model=deconz.config.modelid,
diff --git a/homeassistant/components/hue/__init__.py b/homeassistant/components/hue/__init__.py
index 38b521078f4..7a781c99f53 100644
--- a/homeassistant/components/hue/__init__.py
+++ b/homeassistant/components/hue/__init__.py
@@ -140,7 +140,7 @@ async def async_setup_entry(hass, entry):
     config = bridge.api.config
     device_registry = await dr.async_get_registry(hass)
     device_registry.async_get_or_create(
-        config_entry=entry.entry_id,
+        config_entry_id=entry.entry_id,
         connections={
             (dr.CONNECTION_NETWORK_MAC, config.mac)
         },
diff --git a/homeassistant/components/light/hue.py b/homeassistant/components/light/hue.py
index 6f6e0ed617e..958abaca033 100644
--- a/homeassistant/components/light/hue.py
+++ b/homeassistant/components/light/hue.py
@@ -302,6 +302,7 @@ class HueLight(Light):
             'model': self.light.productname or self.light.modelid,
             # Not yet exposed as properties in aiohue
             'sw_version': self.light.raw['swversion'],
+            'via_hub': (hue.DOMAIN, self.bridge.api.config.bridgeid),
         }
 
     async def async_turn_on(self, **kwargs):
diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py
index e6ff45af2fe..8d4cd0a5bbf 100644
--- a/homeassistant/helpers/device_registry.py
+++ b/homeassistant/helpers/device_registry.py
@@ -10,6 +10,7 @@ from homeassistant.core import callback
 from homeassistant.loader import bind_hass
 
 _LOGGER = logging.getLogger(__name__)
+_UNDEF = object()
 
 DATA_REGISTRY = 'device_registry'
 
@@ -32,6 +33,7 @@ class DeviceEntry:
     model = attr.ib(type=str)
     name = attr.ib(type=str, default=None)
     sw_version = attr.ib(type=str, default=None)
+    hub_device_id = attr.ib(type=str, default=None)
     id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
 
 
@@ -54,28 +56,36 @@ class DeviceRegistry:
         return None
 
     @callback
-    def async_get_or_create(self, *, config_entry, connections, identifiers,
-                            manufacturer, model, name=None, sw_version=None):
+    def async_get_or_create(self, *, config_entry_id, connections, identifiers,
+                            manufacturer, model, name=None, sw_version=None,
+                            via_hub=None):
         """Get device. Create if it doesn't exist."""
         if not identifiers and not connections:
             return None
 
         device = self.async_get_device(identifiers, connections)
 
+        if via_hub is not None:
+            hub_device = self.async_get_device({via_hub}, set())
+            hub_device_id = hub_device.id if hub_device else None
+        else:
+            hub_device_id = None
+
         if device is not None:
-            if config_entry not in device.config_entries:
-                device.config_entries.add(config_entry)
-                self.async_schedule_save()
-            return device
+            return self._async_update_device(
+                device.id, config_entry_id=config_entry_id,
+                hub_device_id=hub_device_id
+            )
 
         device = DeviceEntry(
-            config_entries=[config_entry],
+            config_entries={config_entry_id},
             connections=connections,
             identifiers=identifiers,
             manufacturer=manufacturer,
             model=model,
             name=name,
-            sw_version=sw_version
+            sw_version=sw_version,
+            hub_device_id=hub_device_id
         )
         self.devices[device.id] = device
 
@@ -83,24 +93,64 @@ class DeviceRegistry:
 
         return device
 
+    @callback
+    def _async_update_device(self, device_id, *, config_entry_id=_UNDEF,
+                             remove_config_entry_id=_UNDEF,
+                             hub_device_id=_UNDEF):
+        """Update device attributes."""
+        old = self.devices[device_id]
+
+        changes = {}
+
+        config_entries = old.config_entries
+
+        if (config_entry_id is not _UNDEF and
+                config_entry_id not in old.config_entries):
+            config_entries = old.config_entries | {config_entry_id}
+
+        if (remove_config_entry_id is not _UNDEF and
+                remove_config_entry_id in config_entries):
+            config_entries = set(config_entries)
+            config_entries.remove(remove_config_entry_id)
+
+        if config_entries is not old.config_entries:
+            changes['config_entries'] = config_entries
+
+        if (hub_device_id is not _UNDEF and
+                hub_device_id != old.hub_device_id):
+            changes['hub_device_id'] = hub_device_id
+
+        if not changes:
+            return old
+
+        new = self.devices[device_id] = attr.evolve(old, **changes)
+        self.async_schedule_save()
+        return new
+
     async def async_load(self):
         """Load the device registry."""
-        devices = await self._store.async_load()
-
-        if devices is None:
-            self.devices = OrderedDict()
-            return
-
-        self.devices = {device['id']: DeviceEntry(
-            config_entries=device['config_entries'],
-            connections={tuple(conn) for conn in device['connections']},
-            identifiers={tuple(iden) for iden in device['identifiers']},
-            manufacturer=device['manufacturer'],
-            model=device['model'],
-            name=device['name'],
-            sw_version=device['sw_version'],
-            id=device['id'],
-        ) for device in devices['devices']}
+        data = await self._store.async_load()
+
+        devices = OrderedDict()
+
+        if data is not None:
+            for device in data['devices']:
+                devices[device['id']] = DeviceEntry(
+                    config_entries=set(device['config_entries']),
+                    connections={tuple(conn) for conn
+                                 in device['connections']},
+                    identifiers={tuple(iden) for iden
+                                 in device['identifiers']},
+                    manufacturer=device['manufacturer'],
+                    model=device['model'],
+                    name=device['name'],
+                    sw_version=device['sw_version'],
+                    id=device['id'],
+                    # Introduced in 0.79
+                    hub_device_id=device.get('hub_device_id'),
+                )
+
+        self.devices = devices
 
     @callback
     def async_schedule_save(self):
@@ -122,18 +172,19 @@ class DeviceRegistry:
                 'name': entry.name,
                 'sw_version': entry.sw_version,
                 'id': entry.id,
+                'hub_device_id': entry.hub_device_id,
             } for entry in self.devices.values()
         ]
 
         return data
 
     @callback
-    def async_clear_config_entry(self, config_entry):
+    def async_clear_config_entry(self, config_entry_id):
         """Clear config entry from registry entries."""
-        for device in self.devices.values():
-            if config_entry in device.config_entries:
-                device.config_entries.remove(config_entry)
-                self.async_schedule_save()
+        for dev_id, device in self.devices.items():
+            if config_entry_id in device.config_entries:
+                self._async_update_device(
+                    dev_id, remove_config_entry_id=config_entry_id)
 
 
 @bind_hass
diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py
index 083a2946122..f2913e37339 100644
--- a/homeassistant/helpers/entity_platform.py
+++ b/homeassistant/helpers/entity_platform.py
@@ -273,16 +273,19 @@ class EntityPlatform:
                 config_entry_id = None
 
             device_info = entity.device_info
+
             if config_entry_id is not None and device_info is not None:
                 device = device_registry.async_get_or_create(
-                    config_entry=config_entry_id,
-                    connections=device_info.get('connections', []),
-                    identifiers=device_info.get('identifiers', []),
+                    config_entry_id=config_entry_id,
+                    connections=device_info.get('connections') or set(),
+                    identifiers=device_info.get('identifiers') or set(),
                     manufacturer=device_info.get('manufacturer'),
                     model=device_info.get('model'),
                     name=device_info.get('name'),
-                    sw_version=device_info.get('sw_version'))
-                device_id = device.id
+                    sw_version=device_info.get('sw_version'),
+                    via_hub=device_info.get('via_hub'))
+                if device:
+                    device_id = device.id
             else:
                 device_id = None
 
diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py
index da3645a96fe..01c8419dc04 100644
--- a/homeassistant/helpers/entity_registry.py
+++ b/homeassistant/helpers/entity_registry.py
@@ -31,7 +31,7 @@ STORAGE_VERSION = 1
 STORAGE_KEY = 'core.entity_registry'
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, frozen=True)
 class RegistryEntry:
     """Entity Registry Entry."""
 
@@ -113,14 +113,9 @@ class EntityRegistry:
         """Get entity. Create if it doesn't exist."""
         entity_id = self.async_get_entity_id(domain, platform, unique_id)
         if entity_id:
-            entry = self.entities[entity_id]
-            if entry.config_entry_id == config_entry_id:
-                return entry
-
-            self._async_update_entity(
+            return self._async_update_entity(
                 entity_id, config_entry_id=config_entry_id,
                 device_id=device_id)
-            return self.entities[entity_id]
 
         entity_id = self.async_generate_entity_id(
             domain, suggested_object_id or '{}_{}'.format(platform, unique_id))
@@ -253,10 +248,9 @@ class EntityRegistry:
     @callback
     def async_clear_config_entry(self, config_entry):
         """Clear config entry from registry entries."""
-        for entry in self.entities.values():
+        for entity_id, entry in self.entities.items():
             if config_entry == entry.config_entry_id:
-                entry.config_entry_id = None
-                self.async_schedule_save()
+                self._async_update_entity(entity_id, config_entry_id=None)
 
 
 @bind_hass
diff --git a/tests/common.py b/tests/common.py
index 6629207b288..56e86a4cd5c 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -763,6 +763,11 @@ class MockEntity(entity.Entity):
         """Return True if entity is available."""
         return self._handle('available')
 
+    @property
+    def device_info(self):
+        """Info how it links to a device."""
+        return self._handle('device_info')
+
     def _handle(self, attr):
         """Return attribute value."""
         if attr in self._values:
diff --git a/tests/components/config/test_device_registry.py b/tests/components/config/test_device_registry.py
index 491319bf927..f8ea51cfdc8 100644
--- a/tests/components/config/test_device_registry.py
+++ b/tests/components/config/test_device_registry.py
@@ -21,15 +21,16 @@ def registry(hass):
 async def test_list_devices(hass, client, registry):
     """Test list entries."""
     registry.async_get_or_create(
-        config_entry='1234',
+        config_entry_id='1234',
         connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
         identifiers={('bridgeid', '0123')},
         manufacturer='manufacturer', model='model')
     registry.async_get_or_create(
-        config_entry='1234',
+        config_entry_id='1234',
         connections={},
         identifiers={('bridgeid', '1234')},
-        manufacturer='manufacturer', model='model')
+        manufacturer='manufacturer', model='model',
+        via_hub=('bridgeid', '0123'))
 
     await client.send_json({
         'id': 5,
@@ -37,8 +38,7 @@ async def test_list_devices(hass, client, registry):
     })
     msg = await client.receive_json()
 
-    for entry in msg['result']:
-        entry.pop('id')
+    dev1, dev2 = [entry.pop('id') for entry in msg['result']]
 
     assert msg['result'] == [
         {
@@ -47,7 +47,8 @@ async def test_list_devices(hass, client, registry):
             'manufacturer': 'manufacturer',
             'model': 'model',
             'name': None,
-            'sw_version': None
+            'sw_version': None,
+            'hub_device_id': None,
         },
         {
             'config_entries': ['1234'],
@@ -55,6 +56,7 @@ async def test_list_devices(hass, client, registry):
             'manufacturer': 'manufacturer',
             'model': 'model',
             'name': None,
-            'sw_version': None
+            'sw_version': None,
+            'hub_device_id': dev1,
         }
     ]
diff --git a/tests/components/hue/test_init.py b/tests/components/hue/test_init.py
index 1c4768746d5..5da6d5b709a 100644
--- a/tests/components/hue/test_init.py
+++ b/tests/components/hue/test_init.py
@@ -182,7 +182,7 @@ async def test_config_passed_to_config_entry(hass):
 
     assert len(mock_registry.mock_calls) == 1
     assert mock_registry.mock_calls[0][2] == {
-        'config_entry': entry.entry_id,
+        'config_entry_id': entry.entry_id,
         'connections': {
             ('mac', 'mock-mac')
         },
@@ -192,7 +192,7 @@ async def test_config_passed_to_config_entry(hass):
         'manufacturer': 'Signify',
         'name': 'mock-name',
         'model': 'mock-modelid',
-        'sw_version': 'mock-swversion'
+        'sw_version': 'mock-swversion',
     }
 
 
diff --git a/tests/helpers/test_device_registry.py b/tests/helpers/test_device_registry.py
index 5ae6b4df651..b251846c491 100644
--- a/tests/helpers/test_device_registry.py
+++ b/tests/helpers/test_device_registry.py
@@ -2,7 +2,7 @@
 import pytest
 
 from homeassistant.helpers import device_registry
-from tests.common import mock_device_registry
+from tests.common import mock_device_registry, flush_store
 
 
 @pytest.fixture
@@ -14,41 +14,41 @@ def registry(hass):
 async def test_get_or_create_returns_same_entry(registry):
     """Make sure we do not duplicate entries."""
     entry = registry.async_get_or_create(
-        config_entry='1234',
+        config_entry_id='1234',
         connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
         identifiers={('bridgeid', '0123')},
         manufacturer='manufacturer', model='model')
     entry2 = registry.async_get_or_create(
-        config_entry='1234',
+        config_entry_id='1234',
         connections={('ethernet', '11:22:33:44:55:66:77:88')},
         identifiers={('bridgeid', '0123')},
         manufacturer='manufacturer', model='model')
     entry3 = registry.async_get_or_create(
-        config_entry='1234',
+        config_entry_id='1234',
         connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
         identifiers={('bridgeid', '1234')},
         manufacturer='manufacturer', model='model')
 
     assert len(registry.devices) == 1
-    assert entry is entry2
-    assert entry is entry3
+    assert entry.id == entry2.id
+    assert entry.id == entry3.id
     assert entry.identifiers == {('bridgeid', '0123')}
 
 
 async def test_requirement_for_identifier_or_connection(registry):
     """Make sure we do require some descriptor of device."""
     entry = registry.async_get_or_create(
-        config_entry='1234',
+        config_entry_id='1234',
         connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
         identifiers=set(),
         manufacturer='manufacturer', model='model')
     entry2 = registry.async_get_or_create(
-        config_entry='1234',
+        config_entry_id='1234',
         connections=set(),
         identifiers={('bridgeid', '0123')},
         manufacturer='manufacturer', model='model')
     entry3 = registry.async_get_or_create(
-        config_entry='1234',
+        config_entry_id='1234',
         connections=set(),
         identifiers=set(),
         manufacturer='manufacturer', model='model')
@@ -62,25 +62,25 @@ async def test_requirement_for_identifier_or_connection(registry):
 async def test_multiple_config_entries(registry):
     """Make sure we do not get duplicate entries."""
     entry = registry.async_get_or_create(
-        config_entry='123',
+        config_entry_id='123',
         connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
         identifiers={('bridgeid', '0123')},
         manufacturer='manufacturer', model='model')
     entry2 = registry.async_get_or_create(
-        config_entry='456',
+        config_entry_id='456',
         connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
         identifiers={('bridgeid', '0123')},
         manufacturer='manufacturer', model='model')
     entry3 = registry.async_get_or_create(
-        config_entry='123',
+        config_entry_id='123',
         connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
         identifiers={('bridgeid', '0123')},
         manufacturer='manufacturer', model='model')
 
     assert len(registry.devices) == 1
-    assert entry is entry2
-    assert entry is entry3
-    assert entry.config_entries == {'123', '456'}
+    assert entry.id == entry2.id
+    assert entry.id == entry3.id
+    assert entry2.config_entries == {'123', '456'}
 
 
 async def test_loading_from_storage(hass, hass_storage):
@@ -118,7 +118,7 @@ async def test_loading_from_storage(hass, hass_storage):
     registry = await device_registry.async_get_registry(hass)
 
     entry = registry.async_get_or_create(
-        config_entry='1234',
+        config_entry_id='1234',
         connections={('Zigbee', '01.23.45.67.89')},
         identifiers={('serial', '12:34:56:78:90:AB:CD:EF')},
         manufacturer='manufacturer', model='model')
@@ -129,25 +129,106 @@ async def test_loading_from_storage(hass, hass_storage):
 async def test_removing_config_entries(registry):
     """Make sure we do not get duplicate entries."""
     entry = registry.async_get_or_create(
-        config_entry='123',
+        config_entry_id='123',
         connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
         identifiers={('bridgeid', '0123')},
         manufacturer='manufacturer', model='model')
     entry2 = registry.async_get_or_create(
-        config_entry='456',
+        config_entry_id='456',
         connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
         identifiers={('bridgeid', '0123')},
         manufacturer='manufacturer', model='model')
     entry3 = registry.async_get_or_create(
-        config_entry='123',
+        config_entry_id='123',
         connections={('ethernet', '34:56:78:90:AB:CD:EF:12')},
         identifiers={('bridgeid', '4567')},
         manufacturer='manufacturer', model='model')
 
     assert len(registry.devices) == 2
-    assert entry is entry2
-    assert entry is not entry3
-    assert entry.config_entries == {'123', '456'}
+    assert entry.id == entry2.id
+    assert entry.id != entry3.id
+    assert entry2.config_entries == {'123', '456'}
+
     registry.async_clear_config_entry('123')
+    entry = registry.async_get_device({('bridgeid', '0123')}, set())
+    entry3 = registry.async_get_device({('bridgeid', '4567')}, set())
+
     assert entry.config_entries == {'456'}
     assert entry3.config_entries == set()
+
+
+async def test_specifying_hub_device_create(registry):
+    """Test specifying a hub and updating."""
+    hub = registry.async_get_or_create(
+        config_entry_id='123',
+        connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
+        identifiers={('hue', '0123')},
+        manufacturer='manufacturer', model='hub')
+
+    light = registry.async_get_or_create(
+        config_entry_id='456',
+        connections=set(),
+        identifiers={('hue', '456')},
+        manufacturer='manufacturer', model='light',
+        via_hub=('hue', '0123'))
+
+    assert light.hub_device_id == hub.id
+
+
+async def test_specifying_hub_device_update(registry):
+    """Test specifying a hub and updating."""
+    light = registry.async_get_or_create(
+        config_entry_id='456',
+        connections=set(),
+        identifiers={('hue', '456')},
+        manufacturer='manufacturer', model='light',
+        via_hub=('hue', '0123'))
+
+    assert light.hub_device_id is None
+
+    hub = registry.async_get_or_create(
+        config_entry_id='123',
+        connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
+        identifiers={('hue', '0123')},
+        manufacturer='manufacturer', model='hub')
+
+    light = registry.async_get_or_create(
+        config_entry_id='456',
+        connections=set(),
+        identifiers={('hue', '456')},
+        manufacturer='manufacturer', model='light',
+        via_hub=('hue', '0123'))
+
+    assert light.hub_device_id == hub.id
+
+
+async def test_loading_saving_data(hass, registry):
+    """Test that we load/save data correctly."""
+    orig_hub = registry.async_get_or_create(
+        config_entry_id='123',
+        connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
+        identifiers={('hue', '0123')},
+        manufacturer='manufacturer', model='hub')
+
+    orig_light = registry.async_get_or_create(
+        config_entry_id='456',
+        connections=set(),
+        identifiers={('hue', '456')},
+        manufacturer='manufacturer', model='light',
+        via_hub=('hue', '0123'))
+
+    assert len(registry.devices) == 2
+
+    # Now load written data in new registry
+    registry2 = device_registry.DeviceRegistry(hass)
+    await flush_store(registry._store)
+    await registry2.async_load()
+
+    # Ensure same order
+    assert list(registry.devices) == list(registry2.devices)
+
+    new_hub = registry2.async_get_device({('hue', '0123')}, set())
+    new_light = registry2.async_get_device({('hue', '456')}, set())
+
+    assert orig_hub == new_hub
+    assert orig_light == new_light
diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py
index b51219ddbed..631d446d186 100644
--- a/tests/helpers/test_entity_platform.py
+++ b/tests/helpers/test_entity_platform.py
@@ -676,3 +676,55 @@ async def test_entity_registry_updates_invalid_entity_id(hass):
     assert hass.states.get('test_domain.world') is not None
     assert hass.states.get('invalid_entity_id') is None
     assert hass.states.get('diff_domain.world') is None
+
+
+async def test_device_info_called(hass):
+    """Test device info is forwarded correctly."""
+    registry = await hass.helpers.device_registry.async_get_registry()
+    hub = registry.async_get_or_create(
+        config_entry_id='123',
+        connections=set(),
+        identifiers={('hue', 'hub-id')},
+        manufacturer='manufacturer', model='hub'
+    )
+
+    async def async_setup_entry(hass, config_entry, async_add_entities):
+        """Mock setup entry method."""
+        async_add_entities([
+            # Invalid device info
+            MockEntity(unique_id='abcd', device_info={}),
+            # Valid device info
+            MockEntity(unique_id='qwer', device_info={
+                'identifiers': {('hue', '1234')},
+                'connections': {('mac', 'abcd')},
+                'manufacturer': 'test-manuf',
+                'model': 'test-model',
+                'name': 'test-name',
+                'sw_version': 'test-sw',
+                'via_hub': ('hue', 'hub-id'),
+            }),
+        ])
+        return True
+
+    platform = MockPlatform(
+        async_setup_entry=async_setup_entry
+    )
+    config_entry = MockConfigEntry(entry_id='super-mock-id')
+    entity_platform = MockEntityPlatform(
+        hass,
+        platform_name=config_entry.domain,
+        platform=platform
+    )
+
+    assert await entity_platform.async_setup_entry(config_entry)
+    await hass.async_block_till_done()
+
+    device = registry.async_get_device({('hue', '1234')}, set())
+    assert device is not None
+    assert device.identifiers == {('hue', '1234')}
+    assert device.connections == {('mac', 'abcd')}
+    assert device.manufacturer == 'test-manuf'
+    assert device.model == 'test-model'
+    assert device.name == 'test-name'
+    assert device.sw_version == 'test-sw'
+    assert device.hub_device_id == hub.id
diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py
index bb28287ddd8..a8c9086b2d2 100644
--- a/tests/helpers/test_entity_registry.py
+++ b/tests/helpers/test_entity_registry.py
@@ -6,7 +6,7 @@ import pytest
 
 from homeassistant.helpers import entity_registry
 
-from tests.common import mock_registry
+from tests.common import mock_registry, flush_store
 
 
 YAML__OPEN_PATH = 'homeassistant.util.yaml.open'
@@ -77,8 +77,7 @@ async def test_loading_saving_data(hass, registry):
 
     # Now load written data in new registry
     registry2 = entity_registry.EntityRegistry(hass)
-    registry2._store = registry._store
-
+    await flush_store(registry._store)
     await registry2.async_load()
 
     # Ensure same order
@@ -192,6 +191,8 @@ async def test_removing_config_entry_id(registry):
         'light', 'hue', '5678', config_entry_id='mock-id-1')
     assert entry.config_entry_id == 'mock-id-1'
     registry.async_clear_config_entry('mock-id-1')
+
+    entry = registry.entities[entry.entity_id]
     assert entry.config_entry_id is None
 
 
-- 
GitLab