Skip to content
Snippets Groups Projects
Unverified Commit 61ff40c2 authored by epenet's avatar epenet Committed by GitHub
Browse files

Add base Entity classes to enforce-class-module pylint plugin (#126473)

parent 31200040
No related branches found
No related tags found
No related merge requests found
......@@ -98,7 +98,7 @@ class PassiveBluetoothDataUpdateCoordinator(
self.async_update_listeners()
class PassiveBluetoothCoordinatorEntity(
class PassiveBluetoothCoordinatorEntity( # pylint: disable=hass-enforce-class-module
BaseCoordinatorEntity[_PassiveBluetoothDataUpdateCoordinatorT]
):
"""A class for entities using DataUpdateCoordinator."""
......
......@@ -28,7 +28,9 @@ async def async_setup_entry(
@dataclass(frozen=True, kw_only=True)
class StarlinkDeviceTrackerEntityDescription(EntityDescription):
class StarlinkDeviceTrackerEntityDescription( # pylint: disable=hass-enforce-class-module
EntityDescription
):
"""Describes a Starlink button entity."""
latitude_fn: Callable[[StarlinkData], float]
......
......@@ -2,14 +2,23 @@
from __future__ import annotations
from ast import ClassDef
from astroid import nodes
from pylint.checkers import BaseChecker
from pylint.lint import PyLinter
from homeassistant.const import Platform
_BASE_ENTITY_MODULES: set[str] = {
"BaseCoordinatorEntity",
"CoordinatorEntity",
"Entity",
"EntityDescription",
"ManualTriggerEntity",
"RestoreEntity",
"ToggleEntity",
"ToggleEntityDescription",
"TriggerBaseEntity",
}
_MODULES: dict[str, set[str]] = {
"air_quality": {"AirQualityEntity"},
"alarm_control_panel": {
......@@ -82,6 +91,11 @@ _ENTITY_COMPONENTS: set[str] = {platform.value for platform in Platform}.union(
)
_MODULE_CLASSES = {
class_name for classes in _MODULES.values() for class_name in classes
}
class HassEnforceClassModule(BaseChecker):
"""Checker for class in correct module."""
......@@ -106,11 +120,15 @@ class HassEnforceClassModule(BaseChecker):
current_integration = parts[2]
current_module = parts[3] if len(parts) > 3 else ""
ancestors = list(node.ancestors())
if current_module != "entity" and current_integration not in _ENTITY_COMPONENTS:
top_level_ancestors = list(node.ancestors(recurs=False))
for ancestor in top_level_ancestors:
if ancestor.name == "Entity":
if ancestor.name in _BASE_ENTITY_MODULES and not any(
anc.name in _MODULE_CLASSES for anc in ancestors
):
self.add_message(
"hass-enforce-class-module",
node=node,
......@@ -118,15 +136,10 @@ class HassEnforceClassModule(BaseChecker):
)
return
ancestors: list[ClassDef] | None = None
for expected_module, classes in _MODULES.items():
if expected_module in (current_module, current_integration):
continue
if ancestors is None:
ancestors = list(node.ancestors()) # cache result for other modules
for ancestor in ancestors:
if ancestor.name in classes:
self.add_message(
......
......@@ -84,6 +84,12 @@ def test_enforce_class_platform_good(
class CustomSensorEntity(SensorEntity):
pass
class CoordinatorEntity:
pass
class CustomCoordinatorSensorEntity(CoordinatorEntity, SensorEntity):
pass
"""
root_node = astroid.parse(code, path)
walker = ASTWalker(linter)
......@@ -115,6 +121,12 @@ def test_enforce_class_module_bad_simple(
class TestCoordinator(DataUpdateCoordinator):
pass
class CoordinatorEntity:
pass
class CustomCoordinatorSensorEntity(CoordinatorEntity):
pass
""",
path,
)
......@@ -133,6 +145,16 @@ def test_enforce_class_module_bad_simple(
end_line=5,
end_col_offset=21,
),
MessageTest(
msg_id="hass-enforce-class-module",
line=11,
node=root_node.body[3],
args=("CoordinatorEntity", "entity"),
confidence=UNDEFINED,
col_offset=0,
end_line=11,
end_col_offset=35,
),
):
walker.walk(root_node)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment