From 7fc3f8e47330cea60bec605d7ad8d334ca184d40 Mon Sep 17 00:00:00 2001 From: Erik Montnemery <erik@montnemery.com> Date: Sun, 14 Jan 2024 11:06:35 +0100 Subject: [PATCH] Improve calls to async_show_progress in octoprint (#107792) --- .../components/octoprint/config_flow.py | 45 ++++++++----------- .../components/octoprint/test_config_flow.py | 20 +++++---- 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/homeassistant/components/octoprint/config_flow.py b/homeassistant/components/octoprint/config_flow.py index 696898400bf..01a3e9518c0 100644 --- a/homeassistant/components/octoprint/config_flow.py +++ b/homeassistant/components/octoprint/config_flow.py @@ -53,12 +53,12 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): VERSION = 1 api_key_task: asyncio.Task[None] | None = None + discovery_schema: vol.Schema | None = None _reauth_data: dict[str, Any] | None = None + _user_input: dict[str, Any] | None = None def __init__(self) -> None: """Handle a config flow for OctoPrint.""" - self.discovery_schema = None - self._user_input = None self._sessions: list[aiohttp.ClientSession] = [] async def async_step_user(self, user_input=None): @@ -97,17 +97,18 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): ), ) - self.api_key_task = None - return await self.async_step_get_api_key(user_input) + self._user_input = user_input + return await self.async_step_get_api_key() - async def async_step_get_api_key(self, user_input): + async def async_step_get_api_key(self, user_input=None): """Get an Application Api Key.""" if not self.api_key_task: - self.api_key_task = self.hass.async_create_task( - self._async_get_auth_key(user_input) - ) + self.api_key_task = self.hass.async_create_task(self._async_get_auth_key()) + if not self.api_key_task.done(): return self.async_show_progress( - step_id="get_api_key", progress_action="get_api_key" + step_id="get_api_key", + progress_action="get_api_key", + progress_task=self.api_key_task, ) try: @@ -118,9 +119,9 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): except Exception as err: # pylint: disable=broad-except _LOGGER.exception("Failed to get an application key : %s", err) return self.async_show_progress_done(next_step_id="auth_failed") + finally: + self.api_key_task = None - # store this off here to pick back up in the user step - self._user_input = user_input return self.async_show_progress_done(next_step_id="user") async def _finish_config(self, user_input: dict): @@ -238,26 +239,18 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): ), ) - self.api_key_task = None self._reauth_data[CONF_USERNAME] = user_input[CONF_USERNAME] - return await self.async_step_get_api_key(self._reauth_data) + self._user_input = self._reauth_data + return await self.async_step_get_api_key() - async def _async_get_auth_key(self, user_input: dict): + async def _async_get_auth_key(self): """Get application api key.""" - octoprint = self._get_octoprint_client(user_input) + octoprint = self._get_octoprint_client(self._user_input) - try: - user_input[CONF_API_KEY] = await octoprint.request_app_key( - "Home Assistant", user_input[CONF_USERNAME], 300 - ) - finally: - # Continue the flow after show progress when the task is done. - self.hass.async_create_task( - self.hass.config_entries.flow.async_configure( - flow_id=self.flow_id, user_input=user_input - ) - ) + self._user_input[CONF_API_KEY] = await octoprint.request_app_key( + "Home Assistant", self._user_input[CONF_USERNAME], 300 + ) def _get_octoprint_client(self, user_input: dict) -> OctoprintClient: """Build an octoprint client from the user_input.""" diff --git a/tests/components/octoprint/test_config_flow.py b/tests/components/octoprint/test_config_flow.py index e3cf45708fa..8e20983a791 100644 --- a/tests/components/octoprint/test_config_flow.py +++ b/tests/components/octoprint/test_config_flow.py @@ -95,8 +95,9 @@ async def test_form_cannot_connect(hass: HomeAssistant) -> None: result = await hass.config_entries.flow.async_configure( result["flow_id"], ) + await hass.async_block_till_done() + assert result["type"] == "progress" - assert result["type"] == "progress_done" with patch( "pyoctoprintapi.OctoprintClient.get_discovery_info", side_effect=ApiError, @@ -144,8 +145,9 @@ async def test_form_unknown_exception(hass: HomeAssistant) -> None: result = await hass.config_entries.flow.async_configure( result["flow_id"], ) + await hass.async_block_till_done() + assert result["type"] == "progress" - assert result["type"] == "progress_done" with patch( "pyoctoprintapi.OctoprintClient.get_discovery_info", side_effect=Exception, @@ -203,7 +205,7 @@ async def test_show_zerconf_form(hass: HomeAssistant) -> None: ) await hass.async_block_till_done() - assert result["type"] == "progress_done" + assert result["type"] == "progress" with patch( "pyoctoprintapi.OctoprintClient.get_server_info", @@ -269,7 +271,7 @@ async def test_show_ssdp_form(hass: HomeAssistant) -> None: ) await hass.async_block_till_done() - assert result["type"] == "progress_done" + assert result["type"] == "progress" with patch( "pyoctoprintapi.OctoprintClient.get_server_info", @@ -390,10 +392,11 @@ async def test_failed_auth(hass: HomeAssistant) -> None: result = await hass.config_entries.flow.async_configure( result["flow_id"], ) + await hass.async_block_till_done() - assert result["type"] == "progress_done" - result = await hass.config_entries.flow.async_configure(result["flow_id"]) + assert result["type"] == "progress" + result = await hass.config_entries.flow.async_configure(result["flow_id"]) assert result["type"] == "abort" assert result["reason"] == "auth_failed" @@ -421,10 +424,11 @@ async def test_failed_auth_unexpected_error(hass: HomeAssistant) -> None: result = await hass.config_entries.flow.async_configure( result["flow_id"], ) + await hass.async_block_till_done() - assert result["type"] == "progress_done" - result = await hass.config_entries.flow.async_configure(result["flow_id"]) + assert result["type"] == "progress" + result = await hass.config_entries.flow.async_configure(result["flow_id"]) assert result["type"] == "abort" assert result["reason"] == "auth_failed" -- GitLab