diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 04d9cc450ba0dab5aa83656cefe960bbf3478823..b222d78b57752d3712fc964162032654fa1d0d96 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -109,6 +109,12 @@ class EntityRegistry: """Get entity. Create if it doesn't exist.""" entity_id = self.async_get_entity_id(domain, platform, unique_id) if entity_id: + entry = self.entities[entity_id] + if entry.config_entry_id == config_entry_id: + return entry + + self._async_update_entity( + entity_id, config_entry_id=config_entry_id) return self.entities[entity_id] entity_id = self.async_generate_entity_id( @@ -129,6 +135,12 @@ class EntityRegistry: @callback def async_update_entity(self, entity_id, *, name=_UNDEF): """Update properties of an entity.""" + return self._async_update_entity(entity_id, name=name) + + @callback + def _async_update_entity(self, entity_id, *, name=_UNDEF, + config_entry_id=_UNDEF): + """Private facing update properties method.""" old = self.entities[entity_id] changes = {} @@ -136,6 +148,10 @@ class EntityRegistry: if name is not _UNDEF and name != old.name: changes['name'] = name + if (config_entry_id is not _UNDEF and + config_entry_id != old.config_entry_id): + changes['config_entry_id'] = config_entry_id + if not changes: return old diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index 6808206243f8d1c920a164164971f14e4d54dc7c..5a9efd5c041f7c320f0c23c5e260d822e6fe2fa1 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -107,7 +107,8 @@ def test_loading_saving_data(hass, registry): # Ensure same order assert list(registry.entities) == list(registry2.entities) new_entry1 = registry.async_get_or_create('light', 'hue', '1234') - new_entry2 = registry.async_get_or_create('light', 'hue', '5678') + new_entry2 = registry.async_get_or_create('light', 'hue', '5678', + config_entry_id='mock-id') assert orig_entry1 == new_entry1 assert orig_entry2 == new_entry2 @@ -191,3 +192,13 @@ def test_async_get_entity_id(registry): assert registry.async_get_entity_id( 'light', 'hue', '1234') == 'light.hue_1234' assert registry.async_get_entity_id('light', 'hue', '123') is None + + +async def test_updating_config_entry_id(registry): + """Test that we update config entry id in registry.""" + entry = registry.async_get_or_create( + 'light', 'hue', '5678', config_entry_id='mock-id-1') + entry2 = registry.async_get_or_create( + 'light', 'hue', '5678', config_entry_id='mock-id-2') + assert entry.entity_id == entry2.entity_id + assert entry2.config_entry_id == 'mock-id-2'