From 8b04d48ffdf226e345957ddb509da9c1b5cf67dd Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen <paulus@paulusschoutsen.nl> Date: Thu, 19 Jul 2018 08:37:13 +0200 Subject: [PATCH] Update config entry id in entity registry (#15531) --- homeassistant/helpers/entity_registry.py | 16 ++++++++++++++++ tests/helpers/test_entity_registry.py | 13 ++++++++++++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 04d9cc450ba..b222d78b577 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -109,6 +109,12 @@ class EntityRegistry: """Get entity. Create if it doesn't exist.""" entity_id = self.async_get_entity_id(domain, platform, unique_id) if entity_id: + entry = self.entities[entity_id] + if entry.config_entry_id == config_entry_id: + return entry + + self._async_update_entity( + entity_id, config_entry_id=config_entry_id) return self.entities[entity_id] entity_id = self.async_generate_entity_id( @@ -129,6 +135,12 @@ class EntityRegistry: @callback def async_update_entity(self, entity_id, *, name=_UNDEF): """Update properties of an entity.""" + return self._async_update_entity(entity_id, name=name) + + @callback + def _async_update_entity(self, entity_id, *, name=_UNDEF, + config_entry_id=_UNDEF): + """Private facing update properties method.""" old = self.entities[entity_id] changes = {} @@ -136,6 +148,10 @@ class EntityRegistry: if name is not _UNDEF and name != old.name: changes['name'] = name + if (config_entry_id is not _UNDEF and + config_entry_id != old.config_entry_id): + changes['config_entry_id'] = config_entry_id + if not changes: return old diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index 6808206243f..5a9efd5c041 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -107,7 +107,8 @@ def test_loading_saving_data(hass, registry): # Ensure same order assert list(registry.entities) == list(registry2.entities) new_entry1 = registry.async_get_or_create('light', 'hue', '1234') - new_entry2 = registry.async_get_or_create('light', 'hue', '5678') + new_entry2 = registry.async_get_or_create('light', 'hue', '5678', + config_entry_id='mock-id') assert orig_entry1 == new_entry1 assert orig_entry2 == new_entry2 @@ -191,3 +192,13 @@ def test_async_get_entity_id(registry): assert registry.async_get_entity_id( 'light', 'hue', '1234') == 'light.hue_1234' assert registry.async_get_entity_id('light', 'hue', '123') is None + + +async def test_updating_config_entry_id(registry): + """Test that we update config entry id in registry.""" + entry = registry.async_get_or_create( + 'light', 'hue', '5678', config_entry_id='mock-id-1') + entry2 = registry.async_get_or_create( + 'light', 'hue', '5678', config_entry_id='mock-id-2') + assert entry.entity_id == entry2.entity_id + assert entry2.config_entry_id == 'mock-id-2' -- GitLab