From 7fc3f8e47330cea60bec605d7ad8d334ca184d40 Mon Sep 17 00:00:00 2001
From: Erik Montnemery <erik@montnemery.com>
Date: Sun, 14 Jan 2024 11:06:35 +0100
Subject: [PATCH] Improve calls to async_show_progress in octoprint (#107792)

---
 .../components/octoprint/config_flow.py       | 45 ++++++++-----------
 .../components/octoprint/test_config_flow.py  | 20 +++++----
 2 files changed, 31 insertions(+), 34 deletions(-)

diff --git a/homeassistant/components/octoprint/config_flow.py b/homeassistant/components/octoprint/config_flow.py
index 696898400bf..01a3e9518c0 100644
--- a/homeassistant/components/octoprint/config_flow.py
+++ b/homeassistant/components/octoprint/config_flow.py
@@ -53,12 +53,12 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
     VERSION = 1
 
     api_key_task: asyncio.Task[None] | None = None
+    discovery_schema: vol.Schema | None = None
     _reauth_data: dict[str, Any] | None = None
+    _user_input: dict[str, Any] | None = None
 
     def __init__(self) -> None:
         """Handle a config flow for OctoPrint."""
-        self.discovery_schema = None
-        self._user_input = None
         self._sessions: list[aiohttp.ClientSession] = []
 
     async def async_step_user(self, user_input=None):
@@ -97,17 +97,18 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
                     ),
                 )
 
-        self.api_key_task = None
-        return await self.async_step_get_api_key(user_input)
+        self._user_input = user_input
+        return await self.async_step_get_api_key()
 
-    async def async_step_get_api_key(self, user_input):
+    async def async_step_get_api_key(self, user_input=None):
         """Get an Application Api Key."""
         if not self.api_key_task:
-            self.api_key_task = self.hass.async_create_task(
-                self._async_get_auth_key(user_input)
-            )
+            self.api_key_task = self.hass.async_create_task(self._async_get_auth_key())
+        if not self.api_key_task.done():
             return self.async_show_progress(
-                step_id="get_api_key", progress_action="get_api_key"
+                step_id="get_api_key",
+                progress_action="get_api_key",
+                progress_task=self.api_key_task,
             )
 
         try:
@@ -118,9 +119,9 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
         except Exception as err:  # pylint: disable=broad-except
             _LOGGER.exception("Failed to get an application key : %s", err)
             return self.async_show_progress_done(next_step_id="auth_failed")
+        finally:
+            self.api_key_task = None
 
-        # store this off here to pick back up in the user step
-        self._user_input = user_input
         return self.async_show_progress_done(next_step_id="user")
 
     async def _finish_config(self, user_input: dict):
@@ -238,26 +239,18 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
                 ),
             )
 
-        self.api_key_task = None
         self._reauth_data[CONF_USERNAME] = user_input[CONF_USERNAME]
 
-        return await self.async_step_get_api_key(self._reauth_data)
+        self._user_input = self._reauth_data
+        return await self.async_step_get_api_key()
 
-    async def _async_get_auth_key(self, user_input: dict):
+    async def _async_get_auth_key(self):
         """Get application api key."""
-        octoprint = self._get_octoprint_client(user_input)
+        octoprint = self._get_octoprint_client(self._user_input)
 
-        try:
-            user_input[CONF_API_KEY] = await octoprint.request_app_key(
-                "Home Assistant", user_input[CONF_USERNAME], 300
-            )
-        finally:
-            # Continue the flow after show progress when the task is done.
-            self.hass.async_create_task(
-                self.hass.config_entries.flow.async_configure(
-                    flow_id=self.flow_id, user_input=user_input
-                )
-            )
+        self._user_input[CONF_API_KEY] = await octoprint.request_app_key(
+            "Home Assistant", self._user_input[CONF_USERNAME], 300
+        )
 
     def _get_octoprint_client(self, user_input: dict) -> OctoprintClient:
         """Build an octoprint client from the user_input."""
diff --git a/tests/components/octoprint/test_config_flow.py b/tests/components/octoprint/test_config_flow.py
index e3cf45708fa..8e20983a791 100644
--- a/tests/components/octoprint/test_config_flow.py
+++ b/tests/components/octoprint/test_config_flow.py
@@ -95,8 +95,9 @@ async def test_form_cannot_connect(hass: HomeAssistant) -> None:
         result = await hass.config_entries.flow.async_configure(
             result["flow_id"],
         )
+        await hass.async_block_till_done()
+    assert result["type"] == "progress"
 
-    assert result["type"] == "progress_done"
     with patch(
         "pyoctoprintapi.OctoprintClient.get_discovery_info",
         side_effect=ApiError,
@@ -144,8 +145,9 @@ async def test_form_unknown_exception(hass: HomeAssistant) -> None:
         result = await hass.config_entries.flow.async_configure(
             result["flow_id"],
         )
+        await hass.async_block_till_done()
+    assert result["type"] == "progress"
 
-    assert result["type"] == "progress_done"
     with patch(
         "pyoctoprintapi.OctoprintClient.get_discovery_info",
         side_effect=Exception,
@@ -203,7 +205,7 @@ async def test_show_zerconf_form(hass: HomeAssistant) -> None:
         )
         await hass.async_block_till_done()
 
-    assert result["type"] == "progress_done"
+    assert result["type"] == "progress"
 
     with patch(
         "pyoctoprintapi.OctoprintClient.get_server_info",
@@ -269,7 +271,7 @@ async def test_show_ssdp_form(hass: HomeAssistant) -> None:
         )
         await hass.async_block_till_done()
 
-    assert result["type"] == "progress_done"
+    assert result["type"] == "progress"
 
     with patch(
         "pyoctoprintapi.OctoprintClient.get_server_info",
@@ -390,10 +392,11 @@ async def test_failed_auth(hass: HomeAssistant) -> None:
         result = await hass.config_entries.flow.async_configure(
             result["flow_id"],
         )
+        await hass.async_block_till_done()
 
-    assert result["type"] == "progress_done"
-    result = await hass.config_entries.flow.async_configure(result["flow_id"])
+    assert result["type"] == "progress"
 
+    result = await hass.config_entries.flow.async_configure(result["flow_id"])
     assert result["type"] == "abort"
     assert result["reason"] == "auth_failed"
 
@@ -421,10 +424,11 @@ async def test_failed_auth_unexpected_error(hass: HomeAssistant) -> None:
         result = await hass.config_entries.flow.async_configure(
             result["flow_id"],
         )
+        await hass.async_block_till_done()
 
-    assert result["type"] == "progress_done"
-    result = await hass.config_entries.flow.async_configure(result["flow_id"])
+    assert result["type"] == "progress"
 
+    result = await hass.config_entries.flow.async_configure(result["flow_id"])
     assert result["type"] == "abort"
     assert result["reason"] == "auth_failed"
 
-- 
GitLab