From cd2045b66d1f30c4cdae151989eb6bd725f3081e Mon Sep 17 00:00:00 2001
From: Erik Montnemery <erik@montnemery.com>
Date: Thu, 1 Sep 2022 17:45:19 +0200
Subject: [PATCH] Clean up user overridden device class in entity registry
 (#77662)

---
 homeassistant/helpers/entity_registry.py | 14 ++++-
 tests/helpers/test_entity_registry.py    | 72 ++++++++++++++++++++++++
 2 files changed, 84 insertions(+), 2 deletions(-)

diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py
index 23e9cc5f752..d495d196440 100644
--- a/homeassistant/helpers/entity_registry.py
+++ b/homeassistant/helpers/entity_registry.py
@@ -30,6 +30,7 @@ from homeassistant.const import (
     MAX_LENGTH_STATE_ENTITY_ID,
     STATE_UNAVAILABLE,
     STATE_UNKNOWN,
+    Platform,
 )
 from homeassistant.core import (
     Event,
@@ -62,7 +63,7 @@ SAVE_DELAY = 10
 _LOGGER = logging.getLogger(__name__)
 
 STORAGE_VERSION_MAJOR = 1
-STORAGE_VERSION_MINOR = 7
+STORAGE_VERSION_MINOR = 8
 STORAGE_KEY = "core.entity_registry"
 
 # Attributes relevant to describing entity
@@ -970,10 +971,19 @@ async def _async_migrate(
             entity["hidden_by"] = None
 
     if old_major_version == 1 and old_minor_version < 7:
-        # Version 1.6 adds has_entity_name
+        # Version 1.7 adds has_entity_name
         for entity in data["entities"]:
             entity["has_entity_name"] = False
 
+    if old_major_version == 1 and old_minor_version < 8:
+        # Cleanup after frontend bug which incorrectly updated device_class
+        # Fixed by frontend PR #13551
+        for entity in data["entities"]:
+            domain = split_entity_id(entity["entity_id"])[0]
+            if domain in [Platform.BINARY_SENSOR, Platform.COVER]:
+                continue
+            entity["device_class"] = None
+
     if old_major_version > 1:
         raise NotImplementedError
     return data
diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py
index 9c2592eace0..e4c371a0198 100644
--- a/tests/helpers/test_entity_registry.py
+++ b/tests/helpers/test_entity_registry.py
@@ -528,6 +528,78 @@ async def test_migration_1_1(hass, hass_storage):
     assert entry.original_device_class == "best_class"
 
 
+@pytest.mark.parametrize("load_registries", [False])
+async def test_migration_1_7(hass, hass_storage):
+    """Test migration from version 1.7.
+
+    This tests cleanup after frontend bug which incorrectly updated device_class
+    """
+    entity_dict = {
+        "area_id": None,
+        "capabilities": {},
+        "config_entry_id": None,
+        "device_id": None,
+        "disabled_by": None,
+        "entity_category": None,
+        "has_entity_name": False,
+        "hidden_by": None,
+        "icon": None,
+        "id": "12345",
+        "name": None,
+        "options": None,
+        "original_icon": None,
+        "original_name": None,
+        "platform": "super_platform",
+        "supported_features": 0,
+        "unique_id": "very_unique",
+        "unit_of_measurement": None,
+    }
+
+    hass_storage[er.STORAGE_KEY] = {
+        "version": 1,
+        "minor_version": 7,
+        "data": {
+            "entities": [
+                {
+                    **entity_dict,
+                    "device_class": "original_class_by_integration",
+                    "entity_id": "test.entity",
+                    "original_device_class": "new_class_by_integration",
+                },
+                {
+                    **entity_dict,
+                    "device_class": "class_by_user",
+                    "entity_id": "binary_sensor.entity",
+                    "original_device_class": "class_by_integration",
+                },
+                {
+                    **entity_dict,
+                    "device_class": "class_by_user",
+                    "entity_id": "cover.entity",
+                    "original_device_class": "class_by_integration",
+                },
+            ]
+        },
+    }
+
+    await er.async_load(hass)
+    registry = er.async_get(hass)
+
+    entry = registry.async_get_or_create("test", "super_platform", "very_unique")
+    assert entry.device_class is None
+    assert entry.original_device_class == "new_class_by_integration"
+
+    entry = registry.async_get_or_create(
+        "binary_sensor", "super_platform", "very_unique"
+    )
+    assert entry.device_class == "class_by_user"
+    assert entry.original_device_class == "class_by_integration"
+
+    entry = registry.async_get_or_create("cover", "super_platform", "very_unique")
+    assert entry.device_class == "class_by_user"
+    assert entry.original_device_class == "class_by_integration"
+
+
 @pytest.mark.parametrize("load_registries", [False])
 async def test_loading_invalid_entity_id(hass, hass_storage):
     """Test we skip entities with invalid entity IDs."""
-- 
GitLab