Skip to content
Snippets Groups Projects
Unverified Commit b9be4910 authored by Luke Lashley's avatar Luke Lashley Committed by GitHub
Browse files

Add options flow to Roborock (#104345)


Co-authored-by: default avatarRobert Resch <robert@resch.dev>
parent ec16fc23
No related branches found
No related tags found
No related merge requests found
...@@ -31,6 +31,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ...@@ -31,6 +31,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up roborock from a config entry.""" """Set up roborock from a config entry."""
_LOGGER.debug("Integration async setup entry: %s", entry.as_dict()) _LOGGER.debug("Integration async setup entry: %s", entry.as_dict())
entry.async_on_unload(entry.add_update_listener(update_listener))
user_data = UserData.from_dict(entry.data[CONF_USER_DATA]) user_data = UserData.from_dict(entry.data[CONF_USER_DATA])
api_client = RoborockApiClient(entry.data[CONF_USERNAME], entry.data[CONF_BASE_URL]) api_client = RoborockApiClient(entry.data[CONF_USERNAME], entry.data[CONF_BASE_URL])
...@@ -50,8 +51,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ...@@ -50,8 +51,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
translation_key="home_data_fail", translation_key="home_data_fail",
) from err ) from err
_LOGGER.debug("Got home data %s", home_data) _LOGGER.debug("Got home data %s", home_data)
all_devices: list[HomeDataDevice] = home_data.devices + home_data.received_devices
device_map: dict[str, HomeDataDevice] = { device_map: dict[str, HomeDataDevice] = {
device.duid: device for device in home_data.devices + home_data.received_devices device.duid: device for device in all_devices
} }
product_info: dict[str, HomeDataProduct] = { product_info: dict[str, HomeDataProduct] = {
product.id: product for product in home_data.products product.id: product for product in home_data.products
...@@ -177,3 +179,9 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ...@@ -177,3 +179,9 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
hass.data[DOMAIN].pop(entry.entry_id) hass.data[DOMAIN].pop(entry.entry_id)
await asyncio.gather(*release_tasks) await asyncio.gather(*release_tasks)
return unload_ok return unload_ok
async def update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Handle options update."""
# Reload entry to update data
await hass.config_entries.async_reload(entry.entry_id)
...@@ -17,10 +17,24 @@ from roborock.exceptions import ( ...@@ -17,10 +17,24 @@ from roborock.exceptions import (
from roborock.web_api import RoborockApiClient from roborock.web_api import RoborockApiClient
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ConfigEntry, ConfigFlow, ConfigFlowResult from homeassistant.config_entries import (
ConfigEntry,
ConfigFlow,
ConfigFlowResult,
OptionsFlow,
OptionsFlowWithConfigEntry,
)
from homeassistant.const import CONF_USERNAME from homeassistant.const import CONF_USERNAME
from homeassistant.core import callback
from .const import CONF_BASE_URL, CONF_ENTRY_CODE, CONF_USER_DATA, DOMAIN
from .const import (
CONF_BASE_URL,
CONF_ENTRY_CODE,
CONF_USER_DATA,
DEFAULT_DRAWABLES,
DOMAIN,
DRAWABLES,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
...@@ -107,9 +121,6 @@ class RoborockFlowHandler(ConfigFlow, domain=DOMAIN): ...@@ -107,9 +121,6 @@ class RoborockFlowHandler(ConfigFlow, domain=DOMAIN):
CONF_USER_DATA: login_data.as_dict(), CONF_USER_DATA: login_data.as_dict(),
}, },
) )
await self.hass.config_entries.async_reload(
self.reauth_entry.entry_id
)
return self.async_abort(reason="reauth_successful") return self.async_abort(reason="reauth_successful")
return self._create_entry(self._client, self._username, login_data) return self._create_entry(self._client, self._username, login_data)
...@@ -154,3 +165,43 @@ class RoborockFlowHandler(ConfigFlow, domain=DOMAIN): ...@@ -154,3 +165,43 @@ class RoborockFlowHandler(ConfigFlow, domain=DOMAIN):
CONF_BASE_URL: client.base_url, CONF_BASE_URL: client.base_url,
}, },
) )
@staticmethod
@callback
def async_get_options_flow(
config_entry: ConfigEntry,
) -> OptionsFlow:
"""Create the options flow."""
return RoborockOptionsFlowHandler(config_entry)
class RoborockOptionsFlowHandler(OptionsFlowWithConfigEntry):
"""Handle an option flow for Roborock."""
async def async_step_init(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Manage the options."""
return await self.async_step_drawables()
async def async_step_drawables(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Manage the map object drawable options."""
if user_input is not None:
self.options.setdefault(DRAWABLES, {}).update(user_input)
return self.async_create_entry(title="", data=self.options)
data_schema = {}
for drawable, default_value in DEFAULT_DRAWABLES.items():
data_schema[
vol.Required(
drawable.value,
default=self.config_entry.options.get(DRAWABLES, {}).get(
drawable, default_value
),
)
] = bool
return self.async_show_form(
step_id=DRAWABLES,
data_schema=vol.Schema(data_schema),
)
...@@ -9,6 +9,28 @@ CONF_ENTRY_CODE = "code" ...@@ -9,6 +9,28 @@ CONF_ENTRY_CODE = "code"
CONF_BASE_URL = "base_url" CONF_BASE_URL = "base_url"
CONF_USER_DATA = "user_data" CONF_USER_DATA = "user_data"
# Option Flow steps
DRAWABLES = "drawables"
DEFAULT_DRAWABLES = {
Drawable.CHARGER: True,
Drawable.CLEANED_AREA: False,
Drawable.GOTO_PATH: False,
Drawable.IGNORED_OBSTACLES: False,
Drawable.IGNORED_OBSTACLES_WITH_PHOTO: False,
Drawable.MOP_PATH: False,
Drawable.NO_CARPET_AREAS: False,
Drawable.NO_GO_AREAS: False,
Drawable.NO_MOPPING_AREAS: False,
Drawable.OBSTACLES: False,
Drawable.OBSTACLES_WITH_PHOTO: False,
Drawable.PATH: True,
Drawable.PREDICTED_PATH: False,
Drawable.VACUUM_POSITION: True,
Drawable.VIRTUAL_WALLS: False,
Drawable.ZONES: False,
}
PLATFORMS = [ PLATFORMS = [
Platform.BINARY_SENSOR, Platform.BINARY_SENSOR,
Platform.BUTTON, Platform.BUTTON,
...@@ -21,11 +43,6 @@ PLATFORMS = [ ...@@ -21,11 +43,6 @@ PLATFORMS = [
Platform.VACUUM, Platform.VACUUM,
] ]
IMAGE_DRAWABLES: list[Drawable] = [
Drawable.PATH,
Drawable.CHARGER,
Drawable.VACUUM_POSITION,
]
IMAGE_CACHE_INTERVAL = 90 IMAGE_CACHE_INTERVAL = 90
......
...@@ -7,6 +7,7 @@ from itertools import chain ...@@ -7,6 +7,7 @@ from itertools import chain
from roborock import RoborockCommand from roborock import RoborockCommand
from vacuum_map_parser_base.config.color import ColorsPalette from vacuum_map_parser_base.config.color import ColorsPalette
from vacuum_map_parser_base.config.drawable import Drawable
from vacuum_map_parser_base.config.image_config import ImageConfig from vacuum_map_parser_base.config.image_config import ImageConfig
from vacuum_map_parser_base.config.size import Sizes from vacuum_map_parser_base.config.size import Sizes
from vacuum_map_parser_roborock.map_data_parser import RoborockMapDataParser from vacuum_map_parser_roborock.map_data_parser import RoborockMapDataParser
...@@ -20,7 +21,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback ...@@ -20,7 +21,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import slugify from homeassistant.util import slugify
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .const import DOMAIN, IMAGE_CACHE_INTERVAL, IMAGE_DRAWABLES, MAP_SLEEP from .const import DEFAULT_DRAWABLES, DOMAIN, DRAWABLES, IMAGE_CACHE_INTERVAL, MAP_SLEEP
from .coordinator import RoborockDataUpdateCoordinator from .coordinator import RoborockDataUpdateCoordinator
from .device import RoborockCoordinatedEntity from .device import RoborockCoordinatedEntity
...@@ -35,10 +36,18 @@ async def async_setup_entry( ...@@ -35,10 +36,18 @@ async def async_setup_entry(
coordinators: dict[str, RoborockDataUpdateCoordinator] = hass.data[DOMAIN][ coordinators: dict[str, RoborockDataUpdateCoordinator] = hass.data[DOMAIN][
config_entry.entry_id config_entry.entry_id
] ]
drawables = [
drawable
for drawable, default_value in DEFAULT_DRAWABLES.items()
if config_entry.options.get(DRAWABLES, {}).get(drawable, default_value)
]
entities = list( entities = list(
chain.from_iterable( chain.from_iterable(
await asyncio.gather( await asyncio.gather(
*(create_coordinator_maps(coord) for coord in coordinators.values()) *(
create_coordinator_maps(coord, drawables)
for coord in coordinators.values()
)
) )
) )
) )
...@@ -58,13 +67,14 @@ class RoborockMap(RoborockCoordinatedEntity, ImageEntity): ...@@ -58,13 +67,14 @@ class RoborockMap(RoborockCoordinatedEntity, ImageEntity):
map_flag: int, map_flag: int,
starting_map: bytes, starting_map: bytes,
map_name: str, map_name: str,
drawables: list[Drawable],
) -> None: ) -> None:
"""Initialize a Roborock map.""" """Initialize a Roborock map."""
RoborockCoordinatedEntity.__init__(self, unique_id, coordinator) RoborockCoordinatedEntity.__init__(self, unique_id, coordinator)
ImageEntity.__init__(self, coordinator.hass) ImageEntity.__init__(self, coordinator.hass)
self._attr_name = map_name self._attr_name = map_name
self.parser = RoborockMapDataParser( self.parser = RoborockMapDataParser(
ColorsPalette(), Sizes(), IMAGE_DRAWABLES, ImageConfig(), [] ColorsPalette(), Sizes(), drawables, ImageConfig(), []
) )
self._attr_image_last_updated = dt_util.utcnow() self._attr_image_last_updated = dt_util.utcnow()
self.map_flag = map_flag self.map_flag = map_flag
...@@ -140,7 +150,7 @@ class RoborockMap(RoborockCoordinatedEntity, ImageEntity): ...@@ -140,7 +150,7 @@ class RoborockMap(RoborockCoordinatedEntity, ImageEntity):
async def create_coordinator_maps( async def create_coordinator_maps(
coord: RoborockDataUpdateCoordinator, coord: RoborockDataUpdateCoordinator, drawables: list[Drawable]
) -> list[RoborockMap]: ) -> list[RoborockMap]:
"""Get the starting map information for all maps for this device. """Get the starting map information for all maps for this device.
...@@ -148,7 +158,6 @@ async def create_coordinator_maps( ...@@ -148,7 +158,6 @@ async def create_coordinator_maps(
Only one map can be loaded at a time per device. Only one map can be loaded at a time per device.
""" """
entities = [] entities = []
cur_map = coord.current_map cur_map = coord.current_map
# This won't be None at this point as the coordinator will have run first. # This won't be None at this point as the coordinator will have run first.
assert cur_map is not None assert cur_map is not None
...@@ -180,6 +189,7 @@ async def create_coordinator_maps( ...@@ -180,6 +189,7 @@ async def create_coordinator_maps(
map_flag, map_flag,
api_data, api_data,
map_info.name, map_info.name,
drawables,
) )
) )
if len(coord.maps) != 1: if len(coord.maps) != 1:
......
...@@ -31,6 +31,32 @@ ...@@ -31,6 +31,32 @@
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]" "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]"
} }
}, },
"options": {
"step": {
"drawables": {
"description": "Specify which features to draw on the map.",
"data": {
"charger": "Charger",
"cleaned_area": "Cleaned area",
"goto_path": "Go-to path",
"ignored_obstacles": "Ignored obstacles",
"ignored_obstacles_with_photo": "Ignored obstacles with photo",
"mop_path": "Mop path",
"no_carpet_zones": "No carpet zones",
"no_go_zones": "No-go zones",
"no_mopping_zones": "No mopping zones",
"obstacles": "Obstacles",
"obstacles_with_photo": "Obstacles with photo",
"path": "Path",
"predicted_path": "Predicted path",
"room_names": "Room names",
"vacuum_position": "Vacuum position",
"virtual_walls": "Virtual walls",
"zones": "Zones"
}
}
}
},
"entity": { "entity": {
"binary_sensor": { "binary_sensor": {
"in_cleaning": { "in_cleaning": {
......
...@@ -11,9 +11,10 @@ from roborock.exceptions import ( ...@@ -11,9 +11,10 @@ from roborock.exceptions import (
RoborockInvalidEmail, RoborockInvalidEmail,
RoborockUrlException, RoborockUrlException,
) )
from vacuum_map_parser_base.config.drawable import Drawable
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.roborock.const import CONF_ENTRY_CODE, DOMAIN from homeassistant.components.roborock.const import CONF_ENTRY_CODE, DOMAIN, DRAWABLES
from homeassistant.const import CONF_USERNAME from homeassistant.const import CONF_USERNAME
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
...@@ -185,6 +186,28 @@ async def test_config_flow_failures_code_login( ...@@ -185,6 +186,28 @@ async def test_config_flow_failures_code_login(
assert len(mock_setup.mock_calls) == 1 assert len(mock_setup.mock_calls) == 1
async def test_options_flow_drawables(
hass: HomeAssistant, setup_entry: MockConfigEntry
) -> None:
"""Test that the options flow works."""
result = await hass.config_entries.options.async_init(setup_entry.entry_id)
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == DRAWABLES
with patch(
"homeassistant.components.roborock.async_setup_entry", return_value=True
) as mock_setup:
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={Drawable.PREDICTED_PATH: True},
)
await hass.async_block_till_done()
assert result["type"] == FlowResultType.CREATE_ENTRY
assert setup_entry.options[DRAWABLES][Drawable.PREDICTED_PATH] is True
assert len(mock_setup.mock_calls) == 1
async def test_reauth_flow( async def test_reauth_flow(
hass: HomeAssistant, bypass_api_fixture, mock_roborock_entry: MockConfigEntry hass: HomeAssistant, bypass_api_fixture, mock_roborock_entry: MockConfigEntry
) -> None: ) -> None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment