diff --git a/homeassistant/components/config/area_registry.py b/homeassistant/components/config/area_registry.py index f40ed7834e3d6f247862755a908a3740df1040ff..09cd1c1c8ce9a0e8ee71c2245eb78946a0e720af 100644 --- a/homeassistant/components/config/area_registry.py +++ b/homeassistant/components/config/area_registry.py @@ -2,124 +2,102 @@ import voluptuous as vol from homeassistant.components import websocket_api -from homeassistant.components.websocket_api.decorators import ( - async_response, - require_admin, -) from homeassistant.core import callback -from homeassistant.helpers.area_registry import async_get_registry - -WS_TYPE_LIST = "config/area_registry/list" -SCHEMA_WS_LIST = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( - {vol.Required("type"): WS_TYPE_LIST} -) - -WS_TYPE_CREATE = "config/area_registry/create" -SCHEMA_WS_CREATE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( - {vol.Required("type"): WS_TYPE_CREATE, vol.Required("name"): str} -) - -WS_TYPE_DELETE = "config/area_registry/delete" -SCHEMA_WS_DELETE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( - {vol.Required("type"): WS_TYPE_DELETE, vol.Required("area_id"): str} -) - -WS_TYPE_UPDATE = "config/area_registry/update" -SCHEMA_WS_UPDATE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( - { - vol.Required("type"): WS_TYPE_UPDATE, - vol.Required("area_id"): str, - vol.Required("name"): str, - } -) +from homeassistant.helpers.area_registry import async_get async def async_setup(hass): """Enable the Area Registry views.""" - hass.components.websocket_api.async_register_command( - WS_TYPE_LIST, websocket_list_areas, SCHEMA_WS_LIST - ) - hass.components.websocket_api.async_register_command( - WS_TYPE_CREATE, websocket_create_area, SCHEMA_WS_CREATE - ) - hass.components.websocket_api.async_register_command( - WS_TYPE_DELETE, websocket_delete_area, SCHEMA_WS_DELETE - ) - hass.components.websocket_api.async_register_command( - WS_TYPE_UPDATE, websocket_update_area, SCHEMA_WS_UPDATE - ) + hass.components.websocket_api.async_register_command(websocket_list_areas) + hass.components.websocket_api.async_register_command(websocket_create_area) + hass.components.websocket_api.async_register_command(websocket_delete_area) + hass.components.websocket_api.async_register_command(websocket_update_area) return True -@async_response -async def websocket_list_areas(hass, connection, msg): +@websocket_api.websocket_command({vol.Required("type"): "config/area_registry/list"}) +@callback +def websocket_list_areas(hass, connection, msg): """Handle list areas command.""" - registry = await async_get_registry(hass) - connection.send_message( - websocket_api.result_message( - msg["id"], - [ - {"name": entry.name, "area_id": entry.id} - for entry in registry.async_list_areas() - ], - ) + registry = async_get(hass) + connection.send_result( + msg["id"], + [_entry_dict(entry) for entry in registry.async_list_areas()], ) -@require_admin -@async_response -async def websocket_create_area(hass, connection, msg): +@websocket_api.websocket_command( + { + vol.Required("type"): "config/area_registry/create", + vol.Required("name"): str, + vol.Optional("picture"): vol.Any(str, None), + } +) +@websocket_api.require_admin +@callback +def websocket_create_area(hass, connection, msg): """Create area command.""" - registry = await async_get_registry(hass) + registry = async_get(hass) + + data = dict(msg) + data.pop("type") + data.pop("id") + try: - entry = registry.async_create(msg["name"]) + entry = registry.async_create(**data) except ValueError as err: - connection.send_message( - websocket_api.error_message(msg["id"], "invalid_info", str(err)) - ) + connection.send_error(msg["id"], "invalid_info", str(err)) else: - connection.send_message( - websocket_api.result_message(msg["id"], _entry_dict(entry)) - ) + connection.send_result(msg["id"], _entry_dict(entry)) -@require_admin -@async_response -async def websocket_delete_area(hass, connection, msg): +@websocket_api.websocket_command( + { + vol.Required("type"): "config/area_registry/delete", + vol.Required("area_id"): str, + } +) +@websocket_api.require_admin +@callback +def websocket_delete_area(hass, connection, msg): """Delete area command.""" - registry = await async_get_registry(hass) + registry = async_get(hass) try: registry.async_delete(msg["area_id"]) except KeyError: - connection.send_message( - websocket_api.error_message( - msg["id"], "invalid_info", "Area ID doesn't exist" - ) - ) + connection.send_error(msg["id"], "invalid_info", "Area ID doesn't exist") else: connection.send_message(websocket_api.result_message(msg["id"], "success")) -@require_admin -@async_response -async def websocket_update_area(hass, connection, msg): +@websocket_api.websocket_command( + { + vol.Required("type"): "config/area_registry/update", + vol.Required("area_id"): str, + vol.Optional("name"): str, + vol.Optional("picture"): vol.Any(str, None), + } +) +@websocket_api.require_admin +@callback +def websocket_update_area(hass, connection, msg): """Handle update area websocket command.""" - registry = await async_get_registry(hass) + registry = async_get(hass) + + data = dict(msg) + data.pop("type") + data.pop("id") try: - entry = registry.async_update(msg["area_id"], msg["name"]) + entry = registry.async_update(**data) except ValueError as err: - connection.send_message( - websocket_api.error_message(msg["id"], "invalid_info", str(err)) - ) + connection.send_error(msg["id"], "invalid_info", str(err)) else: - connection.send_message( - websocket_api.result_message(msg["id"], _entry_dict(entry)) - ) + connection.send_result(msg["id"], _entry_dict(entry)) @callback def _entry_dict(entry): """Convert entry to API format.""" - return {"area_id": entry.id, "name": entry.name} + return {"area_id": entry.id, "name": entry.name, "picture": entry.picture} diff --git a/homeassistant/helpers/area_registry.py b/homeassistant/helpers/area_registry.py index 67d713e50878820cd1bf5e8698cf194a85a11e58..0073ecfb44b045e14219ab85db8ce89698c6e438 100644 --- a/homeassistant/helpers/area_registry.py +++ b/homeassistant/helpers/area_registry.py @@ -12,6 +12,8 @@ from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.loader import bind_hass from homeassistant.util import slugify +from .typing import UNDEFINED, UndefinedType + # mypy: disallow-any-generics DATA_REGISTRY = "area_registry" @@ -27,6 +29,7 @@ class AreaEntry: name: str = attr.ib() normalized_name: str = attr.ib() + picture: str | None = attr.ib(default=None) id: str | None = attr.ib(default=None) def generate_id(self, existing_ids: Container[str]) -> None: @@ -76,14 +79,14 @@ class AreaRegistry: return self.async_create(name) @callback - def async_create(self, name: str) -> AreaEntry: + def async_create(self, name: str, picture: str | None = None) -> AreaEntry: """Create a new area.""" normalized_name = normalize_area_name(name) if self.async_get_area_by_name(name): raise ValueError(f"The name {name} ({normalized_name}) is already in use") - area = AreaEntry(name=name, normalized_name=normalized_name) + area = AreaEntry(name=name, normalized_name=normalized_name, picture=picture) area.generate_id(self.areas) assert area.id is not None self.areas[area.id] = area @@ -113,36 +116,57 @@ class AreaRegistry: self.async_schedule_save() @callback - def async_update(self, area_id: str, name: str) -> AreaEntry: + def async_update( + self, + area_id: str, + name: str | UndefinedType = UNDEFINED, + picture: str | None | UndefinedType = UNDEFINED, + ) -> AreaEntry: """Update name of area.""" - updated = self._async_update(area_id, name) + updated = self._async_update(area_id, name=name, picture=picture) self.hass.bus.async_fire( EVENT_AREA_REGISTRY_UPDATED, {"action": "update", "area_id": area_id} ) return updated @callback - def _async_update(self, area_id: str, name: str) -> AreaEntry: + def _async_update( + self, + area_id: str, + name: str | UndefinedType = UNDEFINED, + picture: str | None | UndefinedType = UNDEFINED, + ) -> AreaEntry: """Update name of area.""" old = self.areas[area_id] changes = {} - if name == old.name: - return old + if picture is not UNDEFINED: + changes["picture"] = picture - normalized_name = normalize_area_name(name) + normalized_name = None - if normalized_name != old.normalized_name and self.async_get_area_by_name(name): - raise ValueError(f"The name {name} ({normalized_name}) is already in use") + if name is not UNDEFINED: + normalized_name = normalize_area_name(name) - changes["name"] = name - changes["normalized_name"] = normalized_name + if normalized_name != old.normalized_name and self.async_get_area_by_name( + name + ): + raise ValueError( + f"The name {name} ({normalized_name}) is already in use" + ) + + changes["name"] = name + changes["normalized_name"] = normalized_name + + if not changes: + return old new = self.areas[area_id] = attr.evolve(old, **changes) - self._normalized_name_area_idx[ - normalized_name - ] = self._normalized_name_area_idx.pop(old.normalized_name) + if normalized_name is not None: + self._normalized_name_area_idx[ + normalized_name + ] = self._normalized_name_area_idx.pop(old.normalized_name) self.async_schedule_save() return new @@ -157,7 +181,11 @@ class AreaRegistry: for area in data["areas"]: normalized_name = normalize_area_name(area["name"]) areas[area["id"]] = AreaEntry( - name=area["name"], id=area["id"], normalized_name=normalized_name + name=area["name"], + id=area["id"], + # New in 2021.11 + picture=area.get("picture"), + normalized_name=normalized_name, ) self._normalized_name_area_idx[normalized_name] = area["id"] @@ -174,10 +202,7 @@ class AreaRegistry: data = {} data["areas"] = [ - { - "name": entry.name, - "id": entry.id, - } + {"name": entry.name, "id": entry.id, "picture": entry.picture} for entry in self.areas.values() ] diff --git a/tests/components/config/test_area_registry.py b/tests/components/config/test_area_registry.py index 35176cc79f90fd3608227075d48fd3afb46bd32d..497a395ed309629f7a6f1055933793f345c55058 100644 --- a/tests/components/config/test_area_registry.py +++ b/tests/components/config/test_area_registry.py @@ -22,13 +22,17 @@ def registry(hass): async def test_list_areas(hass, client, registry): """Test list entries.""" registry.async_create("mock 1") - registry.async_create("mock 2") + registry.async_create("mock 2", "/image/example.png") await client.send_json({"id": 1, "type": "config/area_registry/list"}) msg = await client.receive_json() assert len(msg["result"]) == len(registry.areas) + assert msg["result"][0]["name"] == "mock 1" + assert msg["result"][0]["picture"] is None + assert msg["result"][1]["name"] == "mock 2" + assert msg["result"][1]["picture"] == "/image/example.png" async def test_create_area(hass, client, registry): @@ -98,6 +102,23 @@ async def test_update_area(hass, client, registry): "id": 1, "area_id": area.id, "name": "mock 2", + "picture": "/image/example.png", + "type": "config/area_registry/update", + } + ) + + msg = await client.receive_json() + + assert msg["result"]["area_id"] == area.id + assert msg["result"]["name"] == "mock 2" + assert msg["result"]["picture"] == "/image/example.png" + assert len(registry.areas) == 1 + + await client.send_json( + { + "id": 2, + "area_id": area.id, + "picture": None, "type": "config/area_registry/update", } ) @@ -106,6 +127,7 @@ async def test_update_area(hass, client, registry): assert msg["result"]["area_id"] == area.id assert msg["result"]["name"] == "mock 2" + assert msg["result"]["picture"] is None assert len(registry.areas) == 1