diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index c7825918752f12a0f85d0b79277290461c9f0b98..57497df2d7ecc9d6f32e64fd749dc9372b784739 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -1,11 +1,11 @@ """Provide a way to connect entities belonging to one device.""" from __future__ import annotations -from collections import OrderedDict +from collections import UserDict from collections.abc import Coroutine import logging import time -from typing import TYPE_CHECKING, Any, NamedTuple, cast +from typing import TYPE_CHECKING, Any, TypeVar, cast import attr @@ -48,11 +48,6 @@ ORPHANED_DEVICE_KEEP_SECONDS = 86400 * 30 RUNTIME_ONLY_ATTRS = {"suggested_area"} -class _DeviceIndex(NamedTuple): - identifiers: dict[tuple[str, str], str] - connections: dict[tuple[str, str], str] - - class DeviceEntryDisabler(StrEnum): """What disabled a device entry.""" @@ -149,23 +144,6 @@ def format_mac(mac: str) -> str: return mac -def _async_get_device_id_from_index( - devices_index: _DeviceIndex, - identifiers: set[tuple[str, str]], - connections: set[tuple[str, str]] | None, -) -> str | None: - """Check if device has previously been registered.""" - for identifier in identifiers: - if identifier in devices_index.identifiers: - return devices_index.identifiers[identifier] - if not connections: - return None - for connection in _normalize_connections(connections): - if connection in devices_index.connections: - return devices_index.connections[connection] - return None - - class DeviceRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]): """Store entity registry data.""" @@ -210,13 +188,69 @@ class DeviceRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]): return old_data +_EntryTypeT = TypeVar("_EntryTypeT", DeviceEntry, DeletedDeviceEntry) + + +class DeviceRegistryItems(UserDict[str, _EntryTypeT]): + """Container for device registry items, maps device id -> entry. + + Maintains two additional indexes: + - (connection_type, connection identifier) -> entry + - (DOMAIN, identifier) -> entry + """ + + def __init__(self) -> None: + """Initialize the container.""" + super().__init__() + self._connections: dict[tuple[str, str], _EntryTypeT] = {} + self._identifiers: dict[tuple[str, str], _EntryTypeT] = {} + + def __setitem__(self, key: str, entry: _EntryTypeT) -> None: + """Add an item.""" + if key in self: + old_entry = self[key] + for connection in old_entry.connections: + del self._connections[connection] + for identifier in old_entry.identifiers: + del self._identifiers[identifier] + # type ignore linked to mypy issue: https://github.com/python/mypy/issues/13596 + super().__setitem__(key, entry) # type: ignore[assignment] + for connection in entry.connections: + self._connections[connection] = entry + for identifier in entry.identifiers: + self._identifiers[identifier] = entry + + def __delitem__(self, key: str) -> None: + """Remove an item.""" + entry = self[key] + for connection in entry.connections: + del self._connections[connection] + for identifier in entry.identifiers: + del self._identifiers[identifier] + super().__delitem__(key) + + def get_entry( + self, + identifiers: set[tuple[str, str]], + connections: set[tuple[str, str]] | None, + ) -> _EntryTypeT | None: + """Get entry from identifiers or connections.""" + for identifier in identifiers: + if identifier in self._identifiers: + return self._identifiers[identifier] + if not connections: + return None + for connection in _normalize_connections(connections): + if connection in self._connections: + return self._connections[connection] + return None + + class DeviceRegistry: """Class to hold a registry of devices.""" - devices: dict[str, DeviceEntry] - deleted_devices: dict[str, DeletedDeviceEntry] - _registered_index: _DeviceIndex - _deleted_index: _DeviceIndex + devices: DeviceRegistryItems[DeviceEntry] + deleted_devices: DeviceRegistryItems[DeletedDeviceEntry] def __init__(self, hass: HomeAssistant) -> None: """Initialize the device registry.""" @@ -228,7 +262,6 @@ class DeviceRegistry: atomic_writes=True, minor_version=STORAGE_VERSION_MINOR, ) - self._clear_index() @callback def async_get(self, device_id: str) -> DeviceEntry | None: @@ -242,12 +275,7 @@ class DeviceRegistry: connections: set[tuple[str, str]] | None = None, ) -> DeviceEntry | None: """Check if device is registered.""" - device_id = _async_get_device_id_from_index( - self._registered_index, identifiers, connections - ) - if device_id is None: - return None - return self.devices[device_id] + return self.devices.get_entry(identifiers, connections) def _async_get_deleted_device( self, @@ -255,55 +283,7 @@ class DeviceRegistry: connections: set[tuple[str, str]] | None, ) -> DeletedDeviceEntry | None: """Check if device is deleted.""" - device_id = _async_get_device_id_from_index( - self._deleted_index, identifiers, connections - ) - if device_id is None: - return None - return self.deleted_devices[device_id] - - def _add_device(self, device: DeviceEntry | DeletedDeviceEntry) -> None: - """Add a device and index it.""" - if isinstance(device, DeletedDeviceEntry): - devices_index = self._deleted_index - self.deleted_devices[device.id] = device - else: - devices_index = self._registered_index - self.devices[device.id] = device - - _add_device_to_index(devices_index, device) - - def _remove_device(self, device: DeviceEntry | DeletedDeviceEntry) -> None: - """Remove a device and remove it from the index.""" - if isinstance(device, DeletedDeviceEntry): - devices_index = self._deleted_index - self.deleted_devices.pop(device.id) - else: - devices_index = self._registered_index - self.devices.pop(device.id) - - _remove_device_from_index(devices_index, device) - - def _update_device(self, old_device: DeviceEntry, new_device: DeviceEntry) -> None: - """Update a device and the index.""" - self.devices[new_device.id] = new_device - - devices_index = self._registered_index - _remove_device_from_index(devices_index, old_device) - _add_device_to_index(devices_index, new_device) - - def _clear_index(self) -> None: - """Clear the index.""" - self._registered_index = _DeviceIndex(identifiers={}, connections={}) - self._deleted_index = _DeviceIndex(identifiers={}, connections={}) - - def _rebuild_index(self) -> None: - """Create the index after loading devices.""" - self._clear_index() - for device in self.devices.values(): - _add_device_to_index(self._registered_index, device) - for deleted_device in self.deleted_devices.values(): - _add_device_to_index(self._deleted_index, deleted_device) + return self.deleted_devices.get_entry(identifiers, connections) @callback def async_get_or_create( @@ -346,11 +326,11 @@ class DeviceRegistry: if deleted_device is None: device = DeviceEntry(is_new=True) else: - self._remove_device(deleted_device) + self.deleted_devices.pop(deleted_device.id) device = deleted_device.to_device_entry( config_entry_id, connections, identifiers ) - self._add_device(device) + self.devices[device.id] = device if default_manufacturer is not UNDEFINED and device.manufacturer is None: manufacturer = default_manufacturer @@ -516,7 +496,7 @@ class DeviceRegistry: return old new = attr.evolve(old, **new_values) - self._update_device(old, new) + self.devices[device_id] = new # If its only run time attributes (suggested_area) # that do not get saved we do not want to write @@ -542,16 +522,13 @@ class DeviceRegistry: @callback def async_remove_device(self, device_id: str) -> None: """Remove a device from the device registry.""" - device = self.devices[device_id] - self._remove_device(device) - self._add_device( - DeletedDeviceEntry( - config_entries=device.config_entries, - connections=device.connections, - identifiers=device.identifiers, - id=device.id, - orphaned_timestamp=None, - ) + device = self.devices.pop(device_id) + self.deleted_devices[device_id] = DeletedDeviceEntry( + config_entries=device.config_entries, + connections=device.connections, + identifiers=device.identifiers, + id=device.id, + orphaned_timestamp=None, ) for other_device in list(self.devices.values()): if other_device.via_device_id == device_id: @@ -567,8 +544,8 @@ class DeviceRegistry: data = await self._store.async_load() - devices = OrderedDict() - deleted_devices = OrderedDict() + devices: DeviceRegistryItems[DeviceEntry] = DeviceRegistryItems() + deleted_devices: DeviceRegistryItems[DeletedDeviceEntry] = DeviceRegistryItems() if data is not None: for device in data["devices"]: @@ -607,7 +584,6 @@ class DeviceRegistry: self.devices = devices self.deleted_devices = deleted_devices - self._rebuild_index() @callback def async_schedule_save(self) -> None: @@ -692,7 +668,7 @@ class DeviceRegistry: deleted_device.orphaned_timestamp + ORPHANED_DEVICE_KEEP_SECONDS < now_time ): - self._remove_device(deleted_device) + del self.deleted_devices[deleted_device.id] @callback def async_clear_area_id(self, area_id: str) -> None: @@ -879,27 +855,3 @@ def _normalize_connections(connections: set[tuple[str, str]]) -> set[tuple[str, (key, format_mac(value)) if key == CONNECTION_NETWORK_MAC else (key, value) for key, value in connections } - - -def _add_device_to_index( - devices_index: _DeviceIndex, - device: DeviceEntry | DeletedDeviceEntry, -) -> None: - """Add a device to the index.""" - for identifier in device.identifiers: - devices_index.identifiers[identifier] = device.id - for connection in device.connections: - devices_index.connections[connection] = device.id - - -def _remove_device_from_index( - devices_index: _DeviceIndex, - device: DeviceEntry | DeletedDeviceEntry, -) -> None: - """Remove a device from the index.""" - for identifier in device.identifiers: - if identifier in devices_index.identifiers: - del devices_index.identifiers[identifier] - for connection in device.connections: - if connection in devices_index.connections: - del devices_index.connections[connection] diff --git a/tests/common.py b/tests/common.py index 89d1a1d9116648981111e4a511720e2bab3b7eb0..232701bd7465207bff290bd95abbf5b5222e81cf 100644 --- a/tests/common.py +++ b/tests/common.py @@ -469,12 +469,15 @@ def mock_area_registry(hass, mock_entries=None): return registry -def mock_device_registry(hass, mock_entries=None, mock_deleted_entries=None): +def mock_device_registry(hass, mock_entries=None): """Mock the Device Registry.""" registry = device_registry.DeviceRegistry(hass) - registry.devices = mock_entries or OrderedDict() - registry.deleted_devices = mock_deleted_entries or OrderedDict() - registry._rebuild_index() + registry.devices = device_registry.DeviceRegistryItems() + if mock_entries is None: + mock_entries = {} + for key, entry in mock_entries.items(): + registry.devices[key] = entry + registry.deleted_devices = device_registry.DeviceRegistryItems() hass.data[device_registry.DATA_REGISTRY] = registry return registry