Skip to content
Snippets Groups Projects
Unverified Commit 7fc3f8e4 authored by Erik Montnemery's avatar Erik Montnemery Committed by GitHub
Browse files

Improve calls to async_show_progress in octoprint (#107792)

parent 1c9764bc
No related branches found
No related tags found
No related merge requests found
......@@ -53,12 +53,12 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
VERSION = 1
api_key_task: asyncio.Task[None] | None = None
discovery_schema: vol.Schema | None = None
_reauth_data: dict[str, Any] | None = None
_user_input: dict[str, Any] | None = None
def __init__(self) -> None:
"""Handle a config flow for OctoPrint."""
self.discovery_schema = None
self._user_input = None
self._sessions: list[aiohttp.ClientSession] = []
async def async_step_user(self, user_input=None):
......@@ -97,17 +97,18 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
),
)
self.api_key_task = None
return await self.async_step_get_api_key(user_input)
self._user_input = user_input
return await self.async_step_get_api_key()
async def async_step_get_api_key(self, user_input):
async def async_step_get_api_key(self, user_input=None):
"""Get an Application Api Key."""
if not self.api_key_task:
self.api_key_task = self.hass.async_create_task(
self._async_get_auth_key(user_input)
)
self.api_key_task = self.hass.async_create_task(self._async_get_auth_key())
if not self.api_key_task.done():
return self.async_show_progress(
step_id="get_api_key", progress_action="get_api_key"
step_id="get_api_key",
progress_action="get_api_key",
progress_task=self.api_key_task,
)
try:
......@@ -118,9 +119,9 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
except Exception as err: # pylint: disable=broad-except
_LOGGER.exception("Failed to get an application key : %s", err)
return self.async_show_progress_done(next_step_id="auth_failed")
finally:
self.api_key_task = None
# store this off here to pick back up in the user step
self._user_input = user_input
return self.async_show_progress_done(next_step_id="user")
async def _finish_config(self, user_input: dict):
......@@ -238,26 +239,18 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
),
)
self.api_key_task = None
self._reauth_data[CONF_USERNAME] = user_input[CONF_USERNAME]
return await self.async_step_get_api_key(self._reauth_data)
self._user_input = self._reauth_data
return await self.async_step_get_api_key()
async def _async_get_auth_key(self, user_input: dict):
async def _async_get_auth_key(self):
"""Get application api key."""
octoprint = self._get_octoprint_client(user_input)
octoprint = self._get_octoprint_client(self._user_input)
try:
user_input[CONF_API_KEY] = await octoprint.request_app_key(
"Home Assistant", user_input[CONF_USERNAME], 300
)
finally:
# Continue the flow after show progress when the task is done.
self.hass.async_create_task(
self.hass.config_entries.flow.async_configure(
flow_id=self.flow_id, user_input=user_input
)
)
self._user_input[CONF_API_KEY] = await octoprint.request_app_key(
"Home Assistant", self._user_input[CONF_USERNAME], 300
)
def _get_octoprint_client(self, user_input: dict) -> OctoprintClient:
"""Build an octoprint client from the user_input."""
......
......@@ -95,8 +95,9 @@ async def test_form_cannot_connect(hass: HomeAssistant) -> None:
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
)
await hass.async_block_till_done()
assert result["type"] == "progress"
assert result["type"] == "progress_done"
with patch(
"pyoctoprintapi.OctoprintClient.get_discovery_info",
side_effect=ApiError,
......@@ -144,8 +145,9 @@ async def test_form_unknown_exception(hass: HomeAssistant) -> None:
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
)
await hass.async_block_till_done()
assert result["type"] == "progress"
assert result["type"] == "progress_done"
with patch(
"pyoctoprintapi.OctoprintClient.get_discovery_info",
side_effect=Exception,
......@@ -203,7 +205,7 @@ async def test_show_zerconf_form(hass: HomeAssistant) -> None:
)
await hass.async_block_till_done()
assert result["type"] == "progress_done"
assert result["type"] == "progress"
with patch(
"pyoctoprintapi.OctoprintClient.get_server_info",
......@@ -269,7 +271,7 @@ async def test_show_ssdp_form(hass: HomeAssistant) -> None:
)
await hass.async_block_till_done()
assert result["type"] == "progress_done"
assert result["type"] == "progress"
with patch(
"pyoctoprintapi.OctoprintClient.get_server_info",
......@@ -390,10 +392,11 @@ async def test_failed_auth(hass: HomeAssistant) -> None:
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
)
await hass.async_block_till_done()
assert result["type"] == "progress_done"
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == "progress"
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == "abort"
assert result["reason"] == "auth_failed"
......@@ -421,10 +424,11 @@ async def test_failed_auth_unexpected_error(hass: HomeAssistant) -> None:
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
)
await hass.async_block_till_done()
assert result["type"] == "progress_done"
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == "progress"
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == "abort"
assert result["reason"] == "auth_failed"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment