diff --git a/homeassistant/components/bluetooth/passive_update_coordinator.py b/homeassistant/components/bluetooth/passive_update_coordinator.py index df06a7c534b2209e90c87498c7c85cd607c86ec4..be232f87b24a3ff868a90ab59ded77392ed58680 100644 --- a/homeassistant/components/bluetooth/passive_update_coordinator.py +++ b/homeassistant/components/bluetooth/passive_update_coordinator.py @@ -98,7 +98,7 @@ class PassiveBluetoothDataUpdateCoordinator( self.async_update_listeners() -class PassiveBluetoothCoordinatorEntity( +class PassiveBluetoothCoordinatorEntity( # pylint: disable=hass-enforce-class-module BaseCoordinatorEntity[_PassiveBluetoothDataUpdateCoordinatorT] ): """A class for entities using DataUpdateCoordinator.""" diff --git a/homeassistant/components/starlink/device_tracker.py b/homeassistant/components/starlink/device_tracker.py index 34769d687ffb66547fa71cce1df6b0e4702d6b18..129efa0d025411ff8baf42c2f5c6e91a1e1c2768 100644 --- a/homeassistant/components/starlink/device_tracker.py +++ b/homeassistant/components/starlink/device_tracker.py @@ -28,7 +28,9 @@ async def async_setup_entry( @dataclass(frozen=True, kw_only=True) -class StarlinkDeviceTrackerEntityDescription(EntityDescription): +class StarlinkDeviceTrackerEntityDescription( # pylint: disable=hass-enforce-class-module + EntityDescription +): """Describes a Starlink button entity.""" latitude_fn: Callable[[StarlinkData], float] diff --git a/pylint/plugins/hass_enforce_class_module.py b/pylint/plugins/hass_enforce_class_module.py index 6491a702b7fc1c9ad87588ef869260b613cf9e53..e48cae877a558617b85a1ff2f462572b41992794 100644 --- a/pylint/plugins/hass_enforce_class_module.py +++ b/pylint/plugins/hass_enforce_class_module.py @@ -2,14 +2,23 @@ from __future__ import annotations -from ast import ClassDef - from astroid import nodes from pylint.checkers import BaseChecker from pylint.lint import PyLinter from homeassistant.const import Platform +_BASE_ENTITY_MODULES: set[str] = { + "BaseCoordinatorEntity", + "CoordinatorEntity", + "Entity", + "EntityDescription", + "ManualTriggerEntity", + "RestoreEntity", + "ToggleEntity", + "ToggleEntityDescription", + "TriggerBaseEntity", +} _MODULES: dict[str, set[str]] = { "air_quality": {"AirQualityEntity"}, "alarm_control_panel": { @@ -82,6 +91,11 @@ _ENTITY_COMPONENTS: set[str] = {platform.value for platform in Platform}.union( ) +_MODULE_CLASSES = { + class_name for classes in _MODULES.values() for class_name in classes +} + + class HassEnforceClassModule(BaseChecker): """Checker for class in correct module.""" @@ -106,11 +120,15 @@ class HassEnforceClassModule(BaseChecker): current_integration = parts[2] current_module = parts[3] if len(parts) > 3 else "" + ancestors = list(node.ancestors()) + if current_module != "entity" and current_integration not in _ENTITY_COMPONENTS: top_level_ancestors = list(node.ancestors(recurs=False)) for ancestor in top_level_ancestors: - if ancestor.name == "Entity": + if ancestor.name in _BASE_ENTITY_MODULES and not any( + anc.name in _MODULE_CLASSES for anc in ancestors + ): self.add_message( "hass-enforce-class-module", node=node, @@ -118,15 +136,10 @@ class HassEnforceClassModule(BaseChecker): ) return - ancestors: list[ClassDef] | None = None - for expected_module, classes in _MODULES.items(): if expected_module in (current_module, current_integration): continue - if ancestors is None: - ancestors = list(node.ancestors()) # cache result for other modules - for ancestor in ancestors: if ancestor.name in classes: self.add_message( diff --git a/tests/pylint/test_enforce_class_module.py b/tests/pylint/test_enforce_class_module.py index 8927147e89a2e18456fc76843e228a0f62e7829e..8b3ac563c6a1dc7cc55357a358ac6c4656f1518f 100644 --- a/tests/pylint/test_enforce_class_module.py +++ b/tests/pylint/test_enforce_class_module.py @@ -84,6 +84,12 @@ def test_enforce_class_platform_good( class CustomSensorEntity(SensorEntity): pass + + class CoordinatorEntity: + pass + + class CustomCoordinatorSensorEntity(CoordinatorEntity, SensorEntity): + pass """ root_node = astroid.parse(code, path) walker = ASTWalker(linter) @@ -115,6 +121,12 @@ def test_enforce_class_module_bad_simple( class TestCoordinator(DataUpdateCoordinator): pass + + class CoordinatorEntity: + pass + + class CustomCoordinatorSensorEntity(CoordinatorEntity): + pass """, path, ) @@ -133,6 +145,16 @@ def test_enforce_class_module_bad_simple( end_line=5, end_col_offset=21, ), + MessageTest( + msg_id="hass-enforce-class-module", + line=11, + node=root_node.body[3], + args=("CoordinatorEntity", "entity"), + confidence=UNDEFINED, + col_offset=0, + end_line=11, + end_col_offset=35, + ), ): walker.walk(root_node)