diff --git a/homeassistant/bootstrap.py b/homeassistant/bootstrap.py index eef36b026e1e75f7a578186c0322831dd82095f5..d37e3babac846fe2fe1696adaed6960ebc0f8f7e 100644 --- a/homeassistant/bootstrap.py +++ b/homeassistant/bootstrap.py @@ -117,7 +117,7 @@ async def async_from_config_dict(config: Dict[str, Any], hass, config, core_config.get(conf_util.CONF_PACKAGES, {})) hass.config_entries = config_entries.ConfigEntries(hass, config) - await hass.config_entries.async_load() + await hass.config_entries.async_initialize() # Filter out the repeating and common config section [homeassistant] components = set(key.split(' ')[0] for key in config.keys() diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 7b22c2e197c04024a8f75b6c8a1be7ddb0f18c43..696965613039309ae638c0b9076ef71073fca369 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -119,6 +119,7 @@ should follow the same return values as a normal step. If the result of the step is to show a form, the user will be able to continue the flow from the config panel. """ +import asyncio import logging import functools import uuid @@ -205,6 +206,11 @@ ENTRY_STATE_NOT_LOADED = 'not_loaded' # An error occurred when trying to unload the entry ENTRY_STATE_FAILED_UNLOAD = 'failed_unload' +UNRECOVERABLE_STATES = ( + ENTRY_STATE_MIGRATION_ERROR, + ENTRY_STATE_FAILED_UNLOAD, +) + DISCOVERY_NOTIFICATION_ID = 'config_entry_discovery' DISCOVERY_SOURCES = ( SOURCE_DISCOVERY, @@ -221,6 +227,18 @@ CONN_CLASS_ASSUMED = 'assumed' CONN_CLASS_UNKNOWN = 'unknown' +class ConfigError(HomeAssistantError): + """Error while configuring an account.""" + + +class UnknownEntry(ConfigError): + """Unknown entry specified.""" + + +class OperationNotAllowed(ConfigError): + """Raised when a config entry operation is not allowed.""" + + class ConfigEntry: """Hold a configuration entry.""" @@ -228,7 +246,7 @@ class ConfigEntry: 'source', 'connection_class', 'state', '_setup_lock', 'update_listeners', '_async_cancel_retry_setup') - def __init__(self, version: str, domain: str, title: str, data: dict, + def __init__(self, version: int, domain: str, title: str, data: dict, source: str, connection_class: str, options: Optional[dict] = None, entry_id: Optional[str] = None, @@ -283,7 +301,7 @@ class ConfigEntry: result = await component.async_setup_entry(hass, self) if not isinstance(result, bool): - _LOGGER.error('%s.async_config_entry did not return boolean', + _LOGGER.error('%s.async_setup_entry did not return boolean', component.DOMAIN) result = False except ConfigEntryNotReady: @@ -316,7 +334,7 @@ class ConfigEntry: else: self.state = ENTRY_STATE_SETUP_ERROR - async def async_unload(self, hass, *, component=None): + async def async_unload(self, hass, *, component=None) -> bool: """Unload an entry. Returns if unload is possible and was successful. @@ -325,17 +343,22 @@ class ConfigEntry: component = getattr(hass.components, self.domain) if component.DOMAIN == self.domain: - if self._async_cancel_retry_setup is not None: - self._async_cancel_retry_setup() - self.state = ENTRY_STATE_NOT_LOADED - return True + if self.state in UNRECOVERABLE_STATES: + return False if self.state != ENTRY_STATE_LOADED: + if self._async_cancel_retry_setup is not None: + self._async_cancel_retry_setup() + self._async_cancel_retry_setup = None + + self.state = ENTRY_STATE_NOT_LOADED return True supports_unload = hasattr(component, 'async_unload_entry') if not supports_unload: + if component.DOMAIN == self.domain: + self.state = ENTRY_STATE_FAILED_UNLOAD return False try: @@ -420,14 +443,6 @@ class ConfigEntry: } -class ConfigError(HomeAssistantError): - """Error while configuring an account.""" - - -class UnknownEntry(ConfigError): - """Unknown entry specified.""" - - class ConfigEntries: """Manage the configuration entries. @@ -474,34 +489,33 @@ class ConfigEntries: async def async_remove(self, entry_id): """Remove an entry.""" - found = None - for index, entry in enumerate(self._entries): - if entry.entry_id == entry_id: - found = index - break + entry = self.async_get_entry(entry_id) - if found is None: + if entry is None: raise UnknownEntry - entry = self._entries.pop(found) - self._async_schedule_save() + if entry.state in UNRECOVERABLE_STATES: + unload_success = entry.state != ENTRY_STATE_FAILED_UNLOAD + else: + unload_success = await self.async_unload(entry_id) - unloaded = await entry.async_unload(self.hass) + self._entries.remove(entry) + self._async_schedule_save() - device_registry = await \ - self.hass.helpers.device_registry.async_get_registry() - device_registry.async_clear_config_entry(entry_id) + dev_reg, ent_reg = await asyncio.gather( + self.hass.helpers.device_registry.async_get_registry(), + self.hass.helpers.entity_registry.async_get_registry(), + ) - entity_registry = await \ - self.hass.helpers.entity_registry.async_get_registry() - entity_registry.async_clear_config_entry(entry_id) + dev_reg.async_clear_config_entry(entry_id) + ent_reg.async_clear_config_entry(entry_id) return { - 'require_restart': not unloaded + 'require_restart': not unload_success } - async def async_load(self) -> None: - """Handle loading the config.""" + async def async_initialize(self) -> None: + """Initialize config entry config.""" # Migrating for config entries stored before 0.73 config = await self.hass.helpers.storage.async_migrator( self.hass.config.path(PATH_CONFIG), self._store, @@ -527,6 +541,56 @@ class ConfigEntries: options=entry.get('options')) for entry in config['entries']] + async def async_setup(self, entry_id: str) -> bool: + """Set up a config entry. + + Return True if entry has been successfully loaded. + """ + entry = self.async_get_entry(entry_id) + + if entry is None: + raise UnknownEntry + + if entry.state != ENTRY_STATE_NOT_LOADED: + raise OperationNotAllowed + + # Setup Component if not set up yet + if entry.domain in self.hass.config.components: + await entry.async_setup(self.hass) + else: + # Setting up the component will set up all its config entries + result = await async_setup_component( + self.hass, entry.domain, self._hass_config) + + if not result: + return result + + return entry.state == ENTRY_STATE_LOADED + + async def async_unload(self, entry_id: str) -> bool: + """Unload a config entry.""" + entry = self.async_get_entry(entry_id) + + if entry is None: + raise UnknownEntry + + if entry.state in UNRECOVERABLE_STATES: + raise OperationNotAllowed + + return await entry.async_unload(self.hass) + + async def async_reload(self, entry_id: str) -> bool: + """Reload an entry. + + If an entry was not loaded, will just load. + """ + unload_result = await self.async_unload(entry_id) + + if not unload_result: + return unload_result + + return await self.async_setup(entry_id) + @callback def async_update_entry(self, entry, *, data=_UNDEF, options=_UNDEF): """Update a config entry.""" @@ -597,14 +661,7 @@ class ConfigEntries: self._entries.append(entry) self._async_schedule_save() - # Setup entry - if entry.domain in self.hass.config.components: - # Component already set up, just need to call setup_entry - await entry.async_setup(self.hass) - else: - # Setting up component will also load the entries - await async_setup_component( - self.hass, entry.domain, self._hass_config) + await self.async_setup(entry.entry_id) result['result'] = entry return result diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 8991035cc225f49bca404aadc11e88dbdcda42ca..e7a5b7637968ec645c8fafc2c6b504ffdadcaeb9 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -407,7 +407,7 @@ async def test_saving_and_loading(hass): # Now load written data in new config manager manager = config_entries.ConfigEntries(hass, {}) - await manager.async_load() + await manager.async_initialize() # Ensure same order for orig, loaded in zip(hass.config_entries.async_entries(), @@ -518,7 +518,7 @@ async def test_loading_default_config(hass): manager = config_entries.ConfigEntries(hass, {}) with patch('homeassistant.util.json.open', side_effect=FileNotFoundError): - await manager.async_load() + await manager.async_initialize() assert len(manager.async_entries()) == 0 @@ -650,3 +650,219 @@ async def test_entry_options(hass, manager): assert entry.options == { 'second': True } + + +async def test_entry_setup_succeed(hass, manager): + """Test that we can setup an entry.""" + entry = MockConfigEntry( + domain='comp', + state=config_entries.ENTRY_STATE_NOT_LOADED + ) + entry.add_to_hass(hass) + + mock_setup = MagicMock(return_value=mock_coro(True)) + mock_setup_entry = MagicMock(return_value=mock_coro(True)) + + loader.set_component(hass, 'comp', MockModule( + 'comp', + async_setup=mock_setup, + async_setup_entry=mock_setup_entry + )) + + assert await manager.async_setup(entry.entry_id) + assert len(mock_setup.mock_calls) == 1 + assert len(mock_setup_entry.mock_calls) == 1 + assert entry.state == config_entries.ENTRY_STATE_LOADED + + +@pytest.mark.parametrize('state', ( + config_entries.ENTRY_STATE_LOADED, + config_entries.ENTRY_STATE_SETUP_ERROR, + config_entries.ENTRY_STATE_MIGRATION_ERROR, + config_entries.ENTRY_STATE_SETUP_RETRY, + config_entries.ENTRY_STATE_FAILED_UNLOAD, +)) +async def test_entry_setup_invalid_state(hass, manager, state): + """Test that we cannot setup an entry with invalid state.""" + entry = MockConfigEntry( + domain='comp', + state=state + ) + entry.add_to_hass(hass) + + mock_setup = MagicMock(return_value=mock_coro(True)) + mock_setup_entry = MagicMock(return_value=mock_coro(True)) + + loader.set_component(hass, 'comp', MockModule( + 'comp', + async_setup=mock_setup, + async_setup_entry=mock_setup_entry + )) + + with pytest.raises(config_entries.OperationNotAllowed): + assert await manager.async_setup(entry.entry_id) + + assert len(mock_setup.mock_calls) == 0 + assert len(mock_setup_entry.mock_calls) == 0 + assert entry.state == state + + +async def test_entry_unload_succeed(hass, manager): + """Test that we can unload an entry.""" + entry = MockConfigEntry( + domain='comp', + state=config_entries.ENTRY_STATE_LOADED + ) + entry.add_to_hass(hass) + + async_unload_entry = MagicMock(return_value=mock_coro(True)) + + loader.set_component(hass, 'comp', MockModule( + 'comp', + async_unload_entry=async_unload_entry + )) + + assert await manager.async_unload(entry.entry_id) + assert len(async_unload_entry.mock_calls) == 1 + assert entry.state == config_entries.ENTRY_STATE_NOT_LOADED + + +@pytest.mark.parametrize('state', ( + config_entries.ENTRY_STATE_NOT_LOADED, + config_entries.ENTRY_STATE_SETUP_ERROR, + config_entries.ENTRY_STATE_SETUP_RETRY, +)) +async def test_entry_unload_failed_to_load(hass, manager, state): + """Test that we can unload an entry.""" + entry = MockConfigEntry( + domain='comp', + state=state, + ) + entry.add_to_hass(hass) + + async_unload_entry = MagicMock(return_value=mock_coro(True)) + + loader.set_component(hass, 'comp', MockModule( + 'comp', + async_unload_entry=async_unload_entry + )) + + assert await manager.async_unload(entry.entry_id) + assert len(async_unload_entry.mock_calls) == 0 + assert entry.state == config_entries.ENTRY_STATE_NOT_LOADED + + +@pytest.mark.parametrize('state', ( + config_entries.ENTRY_STATE_MIGRATION_ERROR, + config_entries.ENTRY_STATE_FAILED_UNLOAD, +)) +async def test_entry_unload_invalid_state(hass, manager, state): + """Test that we cannot unload an entry with invalid state.""" + entry = MockConfigEntry( + domain='comp', + state=state + ) + entry.add_to_hass(hass) + + async_unload_entry = MagicMock(return_value=mock_coro(True)) + + loader.set_component(hass, 'comp', MockModule( + 'comp', + async_unload_entry=async_unload_entry + )) + + with pytest.raises(config_entries.OperationNotAllowed): + assert await manager.async_unload(entry.entry_id) + + assert len(async_unload_entry.mock_calls) == 0 + assert entry.state == state + + +async def test_entry_reload_succeed(hass, manager): + """Test that we can reload an entry.""" + entry = MockConfigEntry( + domain='comp', + state=config_entries.ENTRY_STATE_LOADED + ) + entry.add_to_hass(hass) + + async_setup = MagicMock(return_value=mock_coro(True)) + async_setup_entry = MagicMock(return_value=mock_coro(True)) + async_unload_entry = MagicMock(return_value=mock_coro(True)) + + loader.set_component(hass, 'comp', MockModule( + 'comp', + async_setup=async_setup, + async_setup_entry=async_setup_entry, + async_unload_entry=async_unload_entry + )) + + assert await manager.async_reload(entry.entry_id) + assert len(async_unload_entry.mock_calls) == 1 + assert len(async_setup.mock_calls) == 1 + assert len(async_setup_entry.mock_calls) == 1 + assert entry.state == config_entries.ENTRY_STATE_LOADED + + +@pytest.mark.parametrize('state', ( + config_entries.ENTRY_STATE_NOT_LOADED, + config_entries.ENTRY_STATE_SETUP_ERROR, + config_entries.ENTRY_STATE_SETUP_RETRY, +)) +async def test_entry_reload_not_loaded(hass, manager, state): + """Test that we can reload an entry.""" + entry = MockConfigEntry( + domain='comp', + state=state + ) + entry.add_to_hass(hass) + + async_setup = MagicMock(return_value=mock_coro(True)) + async_setup_entry = MagicMock(return_value=mock_coro(True)) + async_unload_entry = MagicMock(return_value=mock_coro(True)) + + loader.set_component(hass, 'comp', MockModule( + 'comp', + async_setup=async_setup, + async_setup_entry=async_setup_entry, + async_unload_entry=async_unload_entry + )) + + assert await manager.async_reload(entry.entry_id) + assert len(async_unload_entry.mock_calls) == 0 + assert len(async_setup.mock_calls) == 1 + assert len(async_setup_entry.mock_calls) == 1 + assert entry.state == config_entries.ENTRY_STATE_LOADED + + +@pytest.mark.parametrize('state', ( + config_entries.ENTRY_STATE_MIGRATION_ERROR, + config_entries.ENTRY_STATE_FAILED_UNLOAD, +)) +async def test_entry_reload_error(hass, manager, state): + """Test that we can reload an entry.""" + entry = MockConfigEntry( + domain='comp', + state=state + ) + entry.add_to_hass(hass) + + async_setup = MagicMock(return_value=mock_coro(True)) + async_setup_entry = MagicMock(return_value=mock_coro(True)) + async_unload_entry = MagicMock(return_value=mock_coro(True)) + + loader.set_component(hass, 'comp', MockModule( + 'comp', + async_setup=async_setup, + async_setup_entry=async_setup_entry, + async_unload_entry=async_unload_entry + )) + + with pytest.raises(config_entries.OperationNotAllowed): + assert await manager.async_reload(entry.entry_id) + + assert len(async_unload_entry.mock_calls) == 0 + assert len(async_setup.mock_calls) == 0 + assert len(async_setup_entry.mock_calls) == 0 + + assert entry.state == state