From ee486c269c4fb7ab151a4b219d90329b26be4939 Mon Sep 17 00:00:00 2001
From: cs12ag <70966712+cs12ag@users.noreply.github.com>
Date: Mon, 3 Mar 2025 13:06:25 +0000
Subject: [PATCH] Fix unique identifiers where multiple IKEA Tradfri gateways
 are in use (#136060)

* Create unique identifiers where multiple gateways are in use

Resolving issue https://github.com/home-assistant/core/issues/134497

* Added migration function to __init__.py

Added migration function to execute upon initialisation, to:
a) remove the erroneously-added config)_entry added to the device (gateway B gets added as a config_entry to a device associated to gateway A), and
b) swap out the non-unique identifiers for genuinely unique identifiers.

* Added tests to simulate migration from bad data scenario (i.e. explicitly executing migrate_entity_unique_ids() from __init__.py)

* Ammendments suggested in first review

* Changes after second review

* Rewrite of test_migrate_config_entry_and_identifiers after feedback

* Converted migrate function into major version, updated tests

* Finalised variable naming convention per feedback, added test to validate config entry migrated to v2

* Hopefully final changes for cosmetic / comment stucture

* Further code-coverage in test_migrate_config_entry_and_identifiers()

* Minor test corrections

* Added test for non-tradfri identifiers
---
 homeassistant/components/tradfri/__init__.py  |  94 ++++++++-
 .../components/tradfri/config_flow.py         |   2 +-
 homeassistant/components/tradfri/entity.py    |   2 +-
 tests/components/tradfri/__init__.py          |   2 +
 tests/components/tradfri/test_init.py         | 186 +++++++++++++++++-
 5 files changed, 280 insertions(+), 6 deletions(-)

diff --git a/homeassistant/components/tradfri/__init__.py b/homeassistant/components/tradfri/__init__.py
index 2073829e021..c3e8938b244 100644
--- a/homeassistant/components/tradfri/__init__.py
+++ b/homeassistant/components/tradfri/__init__.py
@@ -159,7 +159,7 @@ def remove_stale_devices(
     device_entries = dr.async_entries_for_config_entry(
         device_registry, config_entry.entry_id
     )
-    all_device_ids = {device.id for device in devices}
+    all_device_ids = {str(device.id) for device in devices}
 
     for device_entry in device_entries:
         device_id: str | None = None
@@ -176,7 +176,7 @@ def remove_stale_devices(
                 gateway_id = _id
                 break
 
-            device_id = _id
+            device_id = _id.replace(f"{config_entry.data[CONF_GATEWAY_ID]}-", "")
             break
 
         if gateway_id is not None:
@@ -190,3 +190,93 @@ def remove_stale_devices(
             device_registry.async_update_device(
                 device_entry.id, remove_config_entry_id=config_entry.entry_id
             )
+
+
+async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
+    """Migrate old entry."""
+    LOGGER.debug(
+        "Migrating Tradfri configuration from version %s.%s",
+        config_entry.version,
+        config_entry.minor_version,
+    )
+
+    if config_entry.version > 1:
+        # This means the user has downgraded from a future version
+        return False
+
+    if config_entry.version == 1:
+        # Migrate to version 2
+        migrate_config_entry_and_identifiers(hass, config_entry)
+
+        hass.config_entries.async_update_entry(config_entry, version=2)
+
+    LOGGER.debug(
+        "Migration to Tradfri configuration version %s.%s successful",
+        config_entry.version,
+        config_entry.minor_version,
+    )
+
+    return True
+
+
+def migrate_config_entry_and_identifiers(
+    hass: HomeAssistant, config_entry: ConfigEntry
+) -> None:
+    """Migrate old non-unique identifiers to new unique identifiers."""
+
+    related_device_flag: bool
+    device_id: str
+
+    device_reg = dr.async_get(hass)
+    # Get all devices associated to contextual gateway config_entry
+    # and loop through list of devices.
+    for device in dr.async_entries_for_config_entry(device_reg, config_entry.entry_id):
+        related_device_flag = False
+        for identifier in device.identifiers:
+            if identifier[0] != DOMAIN:
+                continue
+
+            related_device_flag = True
+
+            _id = identifier[1]
+
+            # Identify gateway device.
+            if _id == config_entry.data[CONF_GATEWAY_ID]:
+                # Using this to avoid updating gateway's own device registry entry
+                related_device_flag = False
+                break
+
+            device_id = str(_id)
+            break
+
+        # Check that device is related to tradfri domain (and is not the gateway itself)
+        if not related_device_flag:
+            continue
+
+        # Loop through list of config_entry_ids for device
+        config_entry_ids = device.config_entries
+        for config_entry_id in config_entry_ids:
+            # Check that the config entry in list is not the device's primary config entry
+            if config_entry_id == device.primary_config_entry:
+                continue
+
+            # Check that the 'other' config entry is also a tradfri config entry
+            other_entry = hass.config_entries.async_get_entry(config_entry_id)
+
+            if other_entry is None or other_entry.domain != DOMAIN:
+                continue
+
+            # Remove non-primary 'tradfri' config entry from device's config_entry_ids
+            device_reg.async_update_device(
+                device.id, remove_config_entry_id=config_entry_id
+            )
+
+        if config_entry.data[CONF_GATEWAY_ID] in device_id:
+            continue
+
+        device_reg.async_update_device(
+            device.id,
+            new_identifiers={
+                (DOMAIN, f"{config_entry.data[CONF_GATEWAY_ID]}-{device_id}")
+            },
+        )
diff --git a/homeassistant/components/tradfri/config_flow.py b/homeassistant/components/tradfri/config_flow.py
index 29d876346a7..9f5b39a9657 100644
--- a/homeassistant/components/tradfri/config_flow.py
+++ b/homeassistant/components/tradfri/config_flow.py
@@ -35,7 +35,7 @@ class AuthError(Exception):
 class FlowHandler(ConfigFlow, domain=DOMAIN):
     """Handle a config flow."""
 
-    VERSION = 1
+    VERSION = 2
 
     def __init__(self) -> None:
         """Initialize flow."""
diff --git a/homeassistant/components/tradfri/entity.py b/homeassistant/components/tradfri/entity.py
index b06d0081477..41c20b19de5 100644
--- a/homeassistant/components/tradfri/entity.py
+++ b/homeassistant/components/tradfri/entity.py
@@ -58,7 +58,7 @@ class TradfriBaseEntity(CoordinatorEntity[TradfriDeviceDataUpdateCoordinator]):
 
         info = self._device.device_info
         self._attr_device_info = DeviceInfo(
-            identifiers={(DOMAIN, self._device_id)},
+            identifiers={(DOMAIN, f"{gateway_id}-{self._device_id}")},
             manufacturer=info.manufacturer,
             model=info.model_number,
             name=self._device.name,
diff --git a/tests/components/tradfri/__init__.py b/tests/components/tradfri/__init__.py
index 37792ae7e32..f73d887d16c 100644
--- a/tests/components/tradfri/__init__.py
+++ b/tests/components/tradfri/__init__.py
@@ -1,4 +1,6 @@
 """Tests for the tradfri component."""
 
 GATEWAY_ID = "mock-gateway-id"
+GATEWAY_ID1 = "mockgatewayid1"
+GATEWAY_ID2 = "mockgatewayid2"
 TRADFRI_PATH = "homeassistant.components.tradfri"
diff --git a/tests/components/tradfri/test_init.py b/tests/components/tradfri/test_init.py
index 54ce469f3c5..a1a4b8d9627 100644
--- a/tests/components/tradfri/test_init.py
+++ b/tests/components/tradfri/test_init.py
@@ -2,13 +2,19 @@
 
 from unittest.mock import MagicMock
 
+from pytradfri.const import ATTR_FIRMWARE_VERSION, ATTR_GATEWAY_ID
+from pytradfri.gateway import Gateway
+
 from homeassistant.components import tradfri
+from homeassistant.components.tradfri.const import DOMAIN
 from homeassistant.core import HomeAssistant
 from homeassistant.helpers import device_registry as dr
+from homeassistant.setup import async_setup_component
 
-from . import GATEWAY_ID
+from . import GATEWAY_ID, GATEWAY_ID1, GATEWAY_ID2
+from .common import CommandStore
 
-from tests.common import MockConfigEntry
+from tests.common import MockConfigEntry, load_json_object_fixture
 
 
 async def test_entry_setup_unload(
@@ -66,6 +72,7 @@ async def test_remove_stale_devices(
     device_registry.async_get_or_create(
         config_entry_id=config_entry.entry_id,
         identifiers={(tradfri.DOMAIN, "stale_device_id")},
+        name="stale-device",
     )
     device_entries = dr.async_entries_for_config_entry(
         device_registry, config_entry.entry_id
@@ -91,3 +98,178 @@ async def test_remove_stale_devices(
     assert device_entry.manufacturer == "IKEA of Sweden"
     assert device_entry.name == "Gateway"
     assert device_entry.model == "E1526"
+
+
+async def test_migrate_config_entry_and_identifiers(
+    hass: HomeAssistant,
+    device_registry: dr.DeviceRegistry,
+    command_store: CommandStore,
+) -> None:
+    """Test correction of device registry entries."""
+    config_entry1 = MockConfigEntry(
+        domain=tradfri.DOMAIN,
+        data={
+            tradfri.CONF_HOST: "mock-host1",
+            tradfri.CONF_IDENTITY: "mock-identity1",
+            tradfri.CONF_KEY: "mock-key1",
+            tradfri.CONF_GATEWAY_ID: GATEWAY_ID1,
+        },
+    )
+
+    gateway1 = mock_gateway_fixture(command_store, GATEWAY_ID1)
+    command_store.register_device(
+        gateway1, load_json_object_fixture("bulb_w.json", DOMAIN)
+    )
+    config_entry1.add_to_hass(hass)
+
+    config_entry2 = MockConfigEntry(
+        domain=tradfri.DOMAIN,
+        data={
+            tradfri.CONF_HOST: "mock-host2",
+            tradfri.CONF_IDENTITY: "mock-identity2",
+            tradfri.CONF_KEY: "mock-key2",
+            tradfri.CONF_GATEWAY_ID: GATEWAY_ID2,
+        },
+    )
+
+    config_entry2.add_to_hass(hass)
+
+    # Add non-tradfri config entry for use in testing negation logic
+    config_entry3 = MockConfigEntry(
+        domain="test_domain",
+    )
+
+    config_entry3.add_to_hass(hass)
+
+    # Create gateway device for config entry 1
+    gateway1_device = device_registry.async_get_or_create(
+        config_entry_id=config_entry1.entry_id,
+        identifiers={(config_entry1.domain, config_entry1.data["gateway_id"])},
+        name="Gateway",
+    )
+
+    # Create bulb 1 on gateway 1 in Device Registry - this has the old identifiers format
+    gateway1_bulb1 = device_registry.async_get_or_create(
+        config_entry_id=config_entry1.entry_id,
+        identifiers={(tradfri.DOMAIN, 65537)},
+        name="bulb1",
+    )
+
+    # Update bulb 1 device to have both config entry IDs
+    # This is to simulate existing data scenario with older version of tradfri component
+    device_registry.async_update_device(
+        gateway1_bulb1.id,
+        add_config_entry_id=config_entry2.entry_id,
+    )
+
+    # Create bulb 2 on gateway 1 in Device Registry - this has the new identifiers format
+    gateway1_bulb2 = device_registry.async_get_or_create(
+        config_entry_id=config_entry1.entry_id,
+        identifiers={(tradfri.DOMAIN, f"{GATEWAY_ID1}-65538")},
+        name="bulb2",
+    )
+
+    # Update bulb 2 device to have an additional config entry from config_entry3
+    # This is to simulate scenario whereby a device entry
+    # is shared by multiple config entries
+    # and where at least one of those config entries is not the 'tradfri' domain
+    device_registry.async_update_device(
+        gateway1_bulb2.id,
+        add_config_entry_id=config_entry3.entry_id,
+        merge_identifiers={("test_domain", "config_entry_3-device2")},
+    )
+
+    # Create a device on config entry 3 in Device Registry
+    config_entry3_device = device_registry.async_get_or_create(
+        config_entry_id=config_entry3.entry_id,
+        identifiers={("test_domain", "config_entry_3-device1")},
+        name="device",
+    )
+
+    # Set up all tradfri config entries.
+    await async_setup_component(hass, DOMAIN, {})
+    await hass.async_block_till_done()
+
+    # Validate that gateway 1 bulb 1 is still the same device entry
+    # This inherently also validates that the device's identifiers
+    # have been updated to the new unique format
+    device_entries = dr.async_entries_for_config_entry(
+        device_registry, config_entry1.entry_id
+    )
+    assert (
+        device_registry.async_get_device(
+            identifiers={(tradfri.DOMAIN, f"{GATEWAY_ID1}-65537")}
+        ).id
+        == gateway1_bulb1.id
+    )
+
+    # Validate that gateway 1 bulb 1 only has gateway 1's config ID associated to it
+    # (Device at index 0 is the gateway)
+    assert device_entries[1].config_entries == {config_entry1.entry_id}
+
+    # Validate that the gateway 1 device is unchanged
+    assert device_entries[0].id == gateway1_device.id
+    assert device_entries[0].identifiers == gateway1_device.identifiers
+    assert device_entries[0].config_entries == gateway1_device.config_entries
+
+    # Validate that gateway 1 bulb 2 now only exists associated to config entry 3.
+    # The device will have had its identifiers updated to the new format (for the tradfri
+    # domain) per migrate_config_entry_and_identifiers().
+    # The device will have then been removed from config entry 1 (gateway1)
+    # due to it not matching a device in the command store.
+    device_entry = device_registry.async_get_device(
+        identifiers={(tradfri.DOMAIN, f"{GATEWAY_ID1}-65538")}
+    )
+
+    assert device_entry.id == gateway1_bulb2.id
+    # Assert that the only config entry associated to this device is config entry 3
+    assert device_entry.config_entries == {config_entry3.entry_id}
+    # Assert that that device's other identifiers remain untouched
+    assert device_entry.identifiers == {
+        (tradfri.DOMAIN, f"{GATEWAY_ID1}-65538"),
+        ("test_domain", "config_entry_3-device2"),
+    }
+
+    # Validate that gateway 2 bulb 1 has been added to device registry and with correct unique identifiers
+    # (This bulb device exists on gateway 2 because the command_store created above will be executed
+    # for each gateway being set up.)
+    device_entries = dr.async_entries_for_config_entry(
+        device_registry, config_entry2.entry_id
+    )
+    assert len(device_entries) == 2
+    assert device_entries[1].identifiers == {(tradfri.DOMAIN, f"{GATEWAY_ID2}-65537")}
+
+    # Validate that gateway 2 bulb 1 only has gateway 2's config ID associated to it
+    assert device_entries[1].config_entries == {config_entry2.entry_id}
+
+    # Validate that config entry 3 device 1 is still present,
+    # and has not had its config entries or identifiers changed
+    # N.B. The gateway1_bulb2 device will qualify in this set
+    # because the config entry 3 was added to it above
+    device_entries = dr.async_entries_for_config_entry(
+        device_registry, config_entry3.entry_id
+    )
+    assert len(device_entries) == 2
+    assert device_entries[0].id == config_entry3_device.id
+    assert device_entries[0].identifiers == {("test_domain", "config_entry_3-device1")}
+    assert device_entries[0].config_entries == {config_entry3.entry_id}
+
+    # Assert that the tradfri config entries have been migrated to v2 and
+    # the non-tradfri config entry remains at v1
+    assert config_entry1.version == 2
+    assert config_entry2.version == 2
+    assert config_entry3.version == 1
+
+
+def mock_gateway_fixture(command_store: CommandStore, gateway_id: str) -> Gateway:
+    """Mock a Tradfri gateway."""
+    gateway = Gateway()
+    command_store.register_response(
+        gateway.get_gateway_info(),
+        {ATTR_GATEWAY_ID: gateway_id, ATTR_FIRMWARE_VERSION: "1.2.1234"},
+    )
+    command_store.register_response(
+        gateway.get_devices(),
+        [],
+    )
+    return gateway
-- 
GitLab