From 5694e4190cddcab365cb0eb10201e5e1e193a9f0 Mon Sep 17 00:00:00 2001 From: Erik Montnemery <erik@montnemery.com> Date: Mon, 9 Nov 2020 19:47:45 +0100 Subject: [PATCH] Extend WS API result when enabling an entity (#42667) * Extend WS API result when enabling an entity * Fix tests * Fix tests * Move updated registry entry to sub dict * Address review comments * Increase test coverage --- .../components/config/entity_registry.py | 14 +- homeassistant/config_entries.py | 6 +- .../components/config/test_entity_registry.py | 246 ++++++++++++++---- tests/components/tasmota/test_sensor.py | 6 +- tests/test_config_entries.py | 6 +- 5 files changed, 212 insertions(+), 66 deletions(-) diff --git a/homeassistant/components/config/entity_registry.py b/homeassistant/components/config/entity_registry.py index 6e4fa30e205..73327ecf23c 100644 --- a/homeassistant/components/config/entity_registry.py +++ b/homeassistant/components/config/entity_registry.py @@ -1,6 +1,7 @@ """HTTP views to interact with the entity registry.""" import voluptuous as vol +from homeassistant import config_entries from homeassistant.components import websocket_api from homeassistant.components.websocket_api.const import ERR_NOT_FOUND from homeassistant.components.websocket_api.decorators import ( @@ -113,10 +114,15 @@ async def websocket_update_entity(hass, connection, msg): connection.send_message( websocket_api.error_message(msg["id"], "invalid_info", str(err)) ) - else: - connection.send_message( - websocket_api.result_message(msg["id"], _entry_ext_dict(entry)) - ) + return + result = {"entity_entry": _entry_ext_dict(entry)} + if "disabled_by" in changes and changes["disabled_by"] is None: + config_entry = hass.config_entries.async_get_entry(entry.config_entry_id) + if config_entry and not config_entry.supports_unload: + result["require_restart"] = True + else: + result["reload_delay"] = config_entries.RELOAD_AFTER_UPDATE_DELAY + connection.send_result(msg["id"], result) @require_admin diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 7fa343146d2..af82db0ffbb 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -89,6 +89,8 @@ CONN_CLASS_LOCAL_POLL = "local_poll" CONN_CLASS_ASSUMED = "assumed" CONN_CLASS_UNKNOWN = "unknown" +RELOAD_AFTER_UPDATE_DELAY = 30 + class ConfigError(HomeAssistantError): """Error while configuring an account.""" @@ -1112,8 +1114,6 @@ class SystemOptions: class EntityRegistryDisabledHandler: """Handler to handle when entities related to config entries updating disabled_by.""" - RELOAD_AFTER_UPDATE_DELAY = 30 - def __init__(self, hass: HomeAssistant) -> None: """Initialize the handler.""" self.hass = hass @@ -1170,7 +1170,7 @@ class EntityRegistryDisabledHandler: self._remove_call_later() self._remove_call_later = self.hass.helpers.event.async_call_later( - self.RELOAD_AFTER_UPDATE_DELAY, self._handle_reload + RELOAD_AFTER_UPDATE_DELAY, self._handle_reload ) async def _handle_reload(self, _now: Any) -> None: diff --git a/tests/components/config/test_entity_registry.py b/tests/components/config/test_entity_registry.py index 84a646ed2ef..a506135c16d 100644 --- a/tests/components/config/test_entity_registry.py +++ b/tests/components/config/test_entity_registry.py @@ -7,7 +7,7 @@ from homeassistant.components.config import entity_registry from homeassistant.const import ATTR_ICON from homeassistant.helpers.entity_registry import RegistryEntry -from tests.common import MockEntity, MockEntityPlatform, mock_registry +from tests.common import MockConfigEntry, MockEntity, MockEntityPlatform, mock_registry @pytest.fixture @@ -162,18 +162,20 @@ async def test_update_entity(hass, client): msg = await client.receive_json() assert msg["result"] == { - "config_entry_id": None, - "device_id": None, - "area_id": "mock-area-id", - "disabled_by": None, - "platform": "test_platform", - "entity_id": "test_domain.world", - "name": "after update", - "icon": "icon:after update", - "original_name": None, - "original_icon": None, - "capabilities": None, - "unique_id": "1234", + "entity_entry": { + "config_entry_id": None, + "device_id": None, + "area_id": "mock-area-id", + "disabled_by": None, + "platform": "test_platform", + "entity_id": "test_domain.world", + "name": "after update", + "icon": "icon:after update", + "original_name": None, + "original_icon": None, + "capabilities": None, + "unique_id": "1234", + } } state = hass.states.get("test_domain.world") @@ -208,18 +210,75 @@ async def test_update_entity(hass, client): msg = await client.receive_json() assert msg["result"] == { - "config_entry_id": None, - "device_id": None, - "area_id": "mock-area-id", - "disabled_by": None, - "platform": "test_platform", - "entity_id": "test_domain.world", - "name": "after update", - "icon": "icon:after update", - "original_name": None, - "original_icon": None, - "capabilities": None, - "unique_id": "1234", + "entity_entry": { + "config_entry_id": None, + "device_id": None, + "area_id": "mock-area-id", + "disabled_by": None, + "platform": "test_platform", + "entity_id": "test_domain.world", + "name": "after update", + "icon": "icon:after update", + "original_name": None, + "original_icon": None, + "capabilities": None, + "unique_id": "1234", + }, + "reload_delay": 30, + } + + +async def test_update_entity_require_restart(hass, client): + """Test updating entity.""" + config_entry = MockConfigEntry(domain="test_platform") + config_entry.add_to_hass(hass) + mock_registry( + hass, + { + "test_domain.world": RegistryEntry( + config_entry_id=config_entry.entry_id, + entity_id="test_domain.world", + unique_id="1234", + # Using component.async_add_entities is equal to platform "domain" + platform="test_platform", + ) + }, + ) + platform = MockEntityPlatform(hass) + entity = MockEntity(unique_id="1234") + await platform.async_add_entities([entity]) + + state = hass.states.get("test_domain.world") + assert state is not None + + # UPDATE DISABLED_BY TO NONE + await client.send_json( + { + "id": 8, + "type": "config/entity_registry/update", + "entity_id": "test_domain.world", + "disabled_by": None, + } + ) + + msg = await client.receive_json() + + assert msg["result"] == { + "entity_entry": { + "config_entry_id": config_entry.entry_id, + "device_id": None, + "area_id": None, + "disabled_by": None, + "platform": "test_platform", + "entity_id": "test_domain.world", + "name": None, + "icon": None, + "original_name": None, + "original_icon": None, + "capabilities": None, + "unique_id": "1234", + }, + "require_restart": True, } @@ -257,18 +316,20 @@ async def test_update_entity_no_changes(hass, client): msg = await client.receive_json() assert msg["result"] == { - "config_entry_id": None, - "device_id": None, - "area_id": None, - "disabled_by": None, - "platform": "test_platform", - "entity_id": "test_domain.world", - "name": "name of entity", - "icon": None, - "original_name": None, - "original_icon": None, - "capabilities": None, - "unique_id": "1234", + "entity_entry": { + "config_entry_id": None, + "device_id": None, + "area_id": None, + "disabled_by": None, + "platform": "test_platform", + "entity_id": "test_domain.world", + "name": "name of entity", + "icon": None, + "original_name": None, + "original_icon": None, + "capabilities": None, + "unique_id": "1234", + } } state = hass.states.get("test_domain.world") @@ -335,24 +396,94 @@ async def test_update_entity_id(hass, client): msg = await client.receive_json() assert msg["result"] == { - "config_entry_id": None, - "device_id": None, - "area_id": None, - "disabled_by": None, - "platform": "test_platform", - "entity_id": "test_domain.planet", - "name": None, - "icon": None, - "original_name": None, - "original_icon": None, - "capabilities": None, - "unique_id": "1234", + "entity_entry": { + "config_entry_id": None, + "device_id": None, + "area_id": None, + "disabled_by": None, + "platform": "test_platform", + "entity_id": "test_domain.planet", + "name": None, + "icon": None, + "original_name": None, + "original_icon": None, + "capabilities": None, + "unique_id": "1234", + } } assert hass.states.get("test_domain.world") is None assert hass.states.get("test_domain.planet") is not None +async def test_update_existing_entity_id(hass, client): + """Test update entity id to an already registered entity id.""" + mock_registry( + hass, + { + "test_domain.world": RegistryEntry( + entity_id="test_domain.world", + unique_id="1234", + # Using component.async_add_entities is equal to platform "domain" + platform="test_platform", + ), + "test_domain.planet": RegistryEntry( + entity_id="test_domain.planet", + unique_id="2345", + # Using component.async_add_entities is equal to platform "domain" + platform="test_platform", + ), + }, + ) + platform = MockEntityPlatform(hass) + entities = [MockEntity(unique_id="1234"), MockEntity(unique_id="2345")] + await platform.async_add_entities(entities) + + await client.send_json( + { + "id": 6, + "type": "config/entity_registry/update", + "entity_id": "test_domain.world", + "new_entity_id": "test_domain.planet", + } + ) + + msg = await client.receive_json() + + assert not msg["success"] + + +async def test_update_invalid_entity_id(hass, client): + """Test update entity id to an invalid entity id.""" + mock_registry( + hass, + { + "test_domain.world": RegistryEntry( + entity_id="test_domain.world", + unique_id="1234", + # Using component.async_add_entities is equal to platform "domain" + platform="test_platform", + ) + }, + ) + platform = MockEntityPlatform(hass) + entities = [MockEntity(unique_id="1234"), MockEntity(unique_id="2345")] + await platform.async_add_entities(entities) + + await client.send_json( + { + "id": 6, + "type": "config/entity_registry/update", + "entity_id": "test_domain.world", + "new_entity_id": "another_domain.planet", + } + ) + + msg = await client.receive_json() + + assert not msg["success"] + + async def test_remove_entity(hass, client): """Test removing entity.""" registry = mock_registry( @@ -380,3 +511,20 @@ async def test_remove_entity(hass, client): assert msg["success"] assert len(registry.entities) == 0 + + +async def test_remove_non_existing_entity(hass, client): + """Test removing non existing entity.""" + mock_registry(hass, {}) + + await client.send_json( + { + "id": 6, + "type": "config/entity_registry/remove", + "entity_id": "test_domain.world", + } + ) + + msg = await client.receive_json() + + assert not msg["success"] diff --git a/tests/components/tasmota/test_sensor.py b/tests/components/tasmota/test_sensor.py index b8e5583579b..8da08b16376 100644 --- a/tests/components/tasmota/test_sensor.py +++ b/tests/components/tasmota/test_sensor.py @@ -412,11 +412,7 @@ async def test_enable_status_sensor(hass, mqtt_mock, setup_tasmota): async_fire_time_changed( hass, - dt.utcnow() - + timedelta( - seconds=config_entries.EntityRegistryDisabledHandler.RELOAD_AFTER_UPDATE_DELAY - + 1 - ), + dt.utcnow() + timedelta(seconds=config_entries.RELOAD_AFTER_UPDATE_DELAY + 1), ) await hass.async_block_till_done() diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 5d63964d6b1..59e1b0754c0 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -1189,11 +1189,7 @@ async def test_reload_entry_entity_registry_works(hass): async_fire_time_changed( hass, - dt.utcnow() - + timedelta( - seconds=config_entries.EntityRegistryDisabledHandler.RELOAD_AFTER_UPDATE_DELAY - + 1 - ), + dt.utcnow() + timedelta(seconds=config_entries.RELOAD_AFTER_UPDATE_DELAY + 1), ) await hass.async_block_till_done() -- GitLab