From c759512c70d7b260ba40b92944a82f012c76652b Mon Sep 17 00:00:00 2001
From: epenet <6771947+epenet@users.noreply.github.com>
Date: Mon, 23 Sep 2024 02:55:55 +0200
Subject: [PATCH] Prevent callback decorator on coroutine functions (#126429)

* Prevent callback decorator on async functions

* Adjust

* Adjust

* Adjust components

* Adjust tests

* Rename

* One more

* Adjust

* Adjust again

* Apply suggestions from code review

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

---------

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
---
 .../components/assist_satellite/websocket_api.py  |  2 --
 homeassistant/components/cloud/google_config.py   |  2 +-
 homeassistant/components/dsmr/sensor.py           |  2 +-
 homeassistant/components/fritz/switch.py          |  3 +--
 homeassistant/components/lcn/websocket.py         |  3 +--
 homeassistant/components/lifx/sensor.py           |  1 -
 homeassistant/components/madvr/__init__.py        |  3 +--
 homeassistant/components/onkyo/media_player.py    |  2 --
 homeassistant/components/plaato/sensor.py         |  2 +-
 homeassistant/components/wake_word/__init__.py    |  1 -
 homeassistant/components/zha/helpers.py           |  2 +-
 pylint/plugins/hass_enforce_type_hints.py         | 15 +++++++++++++--
 tests/test_core.py                                |  4 ++--
 13 files changed, 22 insertions(+), 20 deletions(-)

diff --git a/homeassistant/components/assist_satellite/websocket_api.py b/homeassistant/components/assist_satellite/websocket_api.py
index 741f4364e7f..4c95d9555aa 100644
--- a/homeassistant/components/assist_satellite/websocket_api.py
+++ b/homeassistant/components/assist_satellite/websocket_api.py
@@ -34,7 +34,6 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
     websocket_api.async_register_command(hass, websocket_test_connection)
 
 
-@callback
 @websocket_api.websocket_command(
     {
         vol.Required("type"): "assist_satellite/intercept_wake_word",
@@ -101,7 +100,6 @@ def websocket_get_configuration(
     connection.send_result(msg["id"], config_dict)
 
 
-@callback
 @websocket_api.websocket_command(
     {
         vol.Required("type"): "assist_satellite/set_wake_words",
diff --git a/homeassistant/components/cloud/google_config.py b/homeassistant/components/cloud/google_config.py
index 3586823ca11..43dd5279d35 100644
--- a/homeassistant/components/cloud/google_config.py
+++ b/homeassistant/components/cloud/google_config.py
@@ -478,7 +478,7 @@ class CloudGoogleConfig(AbstractConfig):
         self.async_schedule_google_sync_all()
 
     @callback
-    async def _handle_device_registry_updated(
+    def _handle_device_registry_updated(
         self, event: Event[dr.EventDeviceRegistryUpdatedData]
     ) -> None:
         """Handle when device registry updated."""
diff --git a/homeassistant/components/dsmr/sensor.py b/homeassistant/components/dsmr/sensor.py
index b76736a1101..a069c32be04 100644
--- a/homeassistant/components/dsmr/sensor.py
+++ b/homeassistant/components/dsmr/sensor.py
@@ -713,7 +713,7 @@ async def async_setup_entry(
     task = asyncio.create_task(connect_and_reconnect())
 
     @callback
-    async def _async_stop(_: Event) -> None:
+    def _async_stop(_: Event) -> None:
         if add_entities_handler is not None:
             add_entities_handler()
         task.cancel()
diff --git a/homeassistant/components/fritz/switch.py b/homeassistant/components/fritz/switch.py
index ce89cfc736d..dfcb1162c3e 100644
--- a/homeassistant/components/fritz/switch.py
+++ b/homeassistant/components/fritz/switch.py
@@ -9,7 +9,7 @@ from homeassistant.components.network import async_get_source_ip
 from homeassistant.components.switch import SwitchEntity, SwitchEntityDescription
 from homeassistant.config_entries import ConfigEntry
 from homeassistant.const import EntityCategory
-from homeassistant.core import HomeAssistant, callback
+from homeassistant.core import HomeAssistant
 from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC, DeviceInfo
 from homeassistant.helpers.dispatcher import async_dispatcher_connect
 from homeassistant.helpers.entity import Entity
@@ -242,7 +242,6 @@ async def async_setup_entry(
 
     async_add_entities(entities_list)
 
-    @callback
     async def async_update_avm_device() -> None:
         """Update the values of the AVM device."""
         async_add_entities(await _async_profile_entities_list(avm_wrapper, data_fritz))
diff --git a/homeassistant/components/lcn/websocket.py b/homeassistant/components/lcn/websocket.py
index 65896cc78d1..d3268dfbf91 100644
--- a/homeassistant/components/lcn/websocket.py
+++ b/homeassistant/components/lcn/websocket.py
@@ -21,7 +21,7 @@ from homeassistant.const import (
     CONF_NAME,
     CONF_RESOURCE,
 )
-from homeassistant.core import HomeAssistant, callback
+from homeassistant.core import HomeAssistant
 from homeassistant.helpers import device_registry as dr, entity_registry as er
 import homeassistant.helpers.config_validation as cv
 
@@ -102,7 +102,6 @@ def get_config_entry(
 ) -> AsyncWebSocketCommandHandler:
     """Websocket decorator to ensure the config_entry exists and return it."""
 
-    @callback
     @wraps(func)
     async def get_entry(
         hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
diff --git a/homeassistant/components/lifx/sensor.py b/homeassistant/components/lifx/sensor.py
index 2f54317f9bd..68f354024e4 100644
--- a/homeassistant/components/lifx/sensor.py
+++ b/homeassistant/components/lifx/sensor.py
@@ -65,7 +65,6 @@ class LIFXRssiSensor(LIFXEntity, SensorEntity):
         """Handle coordinator updates."""
         self._attr_native_value = self.coordinator.rssi
 
-    @callback
     async def async_added_to_hass(self) -> None:
         """Enable RSSI updates."""
         self.async_on_remove(self.coordinator.async_enable_rssi_updates())
diff --git a/homeassistant/components/madvr/__init__.py b/homeassistant/components/madvr/__init__.py
index a6ad3b2d1fd..bb42adb21fc 100644
--- a/homeassistant/components/madvr/__init__.py
+++ b/homeassistant/components/madvr/__init__.py
@@ -8,7 +8,7 @@ from madvr.madvr import Madvr
 
 from homeassistant.config_entries import ConfigEntry
 from homeassistant.const import CONF_HOST, CONF_PORT, EVENT_HOMEASSISTANT_STOP, Platform
-from homeassistant.core import Event, HomeAssistant, callback
+from homeassistant.core import Event, HomeAssistant
 
 from .coordinator import MadVRCoordinator
 
@@ -47,7 +47,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: MadVRConfigEntry) -> boo
 
     await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
 
-    @callback
     async def handle_unload(event: Event) -> None:
         """Handle unload."""
         await async_handle_unload(coordinator=coordinator)
diff --git a/homeassistant/components/onkyo/media_player.py b/homeassistant/components/onkyo/media_player.py
index 1718ecb36be..af4285e2abd 100644
--- a/homeassistant/components/onkyo/media_player.py
+++ b/homeassistant/components/onkyo/media_player.py
@@ -268,7 +268,6 @@ async def async_setup_platform(
 
         _LOGGER.debug("Manually creating receiver: %s (%s)", name, host)
 
-        @callback
         async def async_onkyo_interview_callback(conn: pyeiscp.Connection) -> None:
             """Receiver interviewed, connection not yet active."""
             info = ReceiverInfo(conn.host, conn.port, conn.name, conn.identifier)
@@ -284,7 +283,6 @@ async def async_setup_platform(
     else:
         _LOGGER.debug("Discovering receivers")
 
-        @callback
         async def async_onkyo_discovery_callback(conn: pyeiscp.Connection) -> None:
             """Receiver discovered, connection not yet active."""
             info = ReceiverInfo(conn.host, conn.port, conn.name, conn.identifier)
diff --git a/homeassistant/components/plaato/sensor.py b/homeassistant/components/plaato/sensor.py
index 7aa30dd2fe0..b11bac40144 100644
--- a/homeassistant/components/plaato/sensor.py
+++ b/homeassistant/components/plaato/sensor.py
@@ -44,7 +44,7 @@ async def async_setup_entry(
     entry_data = hass.data[DOMAIN][entry.entry_id]
 
     @callback
-    async def _async_update_from_webhook(device_id, sensor_data: PlaatoDevice):
+    def _async_update_from_webhook(device_id, sensor_data: PlaatoDevice):
         """Update/Create the sensors."""
         entry_data[SENSOR_DATA] = sensor_data
 
diff --git a/homeassistant/components/wake_word/__init__.py b/homeassistant/components/wake_word/__init__.py
index 84e59ab66d6..00db5a7355b 100644
--- a/homeassistant/components/wake_word/__init__.py
+++ b/homeassistant/components/wake_word/__init__.py
@@ -137,7 +137,6 @@ class WakeWordDetectionEntity(RestoreEntity):
     }
 )
 @websocket_api.async_response
-@callback
 async def websocket_entity_info(
     hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
 ) -> None:
diff --git a/homeassistant/components/zha/helpers.py b/homeassistant/components/zha/helpers.py
index dc999f13693..8e22e412e60 100644
--- a/homeassistant/components/zha/helpers.py
+++ b/homeassistant/components/zha/helpers.py
@@ -1107,7 +1107,7 @@ def async_cluster_exists(hass: HomeAssistant, cluster_id, skip_coordinator=True)
 
 
 @callback
-async def async_add_entities(
+def async_add_entities(
     _async_add_entities: AddEntitiesCallback,
     entity_class: type[ZHAEntity],
     entities: list[EntityData],
diff --git a/pylint/plugins/hass_enforce_type_hints.py b/pylint/plugins/hass_enforce_type_hints.py
index 7f4a7fbd485..f696bc55177 100644
--- a/pylint/plugins/hass_enforce_type_hints.py
+++ b/pylint/plugins/hass_enforce_type_hints.py
@@ -3093,6 +3093,11 @@ class HassTypeHintChecker(BaseChecker):
             "hass-consider-usefixtures-decorator",
             "Used when an argument type is None and could be a fixture",
         ),
+        "W7434": (
+            "A coroutine function should not be decorated with @callback",
+            "hass-async-callback-decorator",
+            "Used when a coroutine function has an invalid @callback decorator",
+        ),
     }
     options = (
         (
@@ -3195,6 +3200,14 @@ class HassTypeHintChecker(BaseChecker):
                 self._check_function(function_node, match, annotations)
                 checked_class_methods.add(function_node.name)
 
+    def visit_asyncfunctiondef(self, node: nodes.AsyncFunctionDef) -> None:
+        """Apply checks on an AsyncFunctionDef node."""
+        if (
+            decoratornames := node.decoratornames()
+        ) and "homeassistant.core.callback" in decoratornames:
+            self.add_message("hass-async-callback-decorator", node=node)
+        self.visit_functiondef(node)
+
     def visit_functiondef(self, node: nodes.FunctionDef) -> None:
         """Apply relevant type hint checks on a FunctionDef node."""
         annotations = _get_all_annotations(node)
@@ -3234,8 +3247,6 @@ class HassTypeHintChecker(BaseChecker):
                 continue
             self._check_function(node, match, annotations)
 
-    visit_asyncfunctiondef = visit_functiondef
-
     def _check_function(
         self,
         node: nodes.FunctionDef,
diff --git a/tests/test_core.py b/tests/test_core.py
index 9ca57d1563f..9f19a372634 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -2194,7 +2194,7 @@ async def test_async_functions_with_callback(hass: HomeAssistant) -> None:
     runs = []
 
     @ha.callback
-    async def test():
+    async def test():  # pylint: disable=hass-async-callback-decorator
         runs.append(True)
 
     await hass.async_add_job(test)
@@ -2205,7 +2205,7 @@ async def test_async_functions_with_callback(hass: HomeAssistant) -> None:
     assert len(runs) == 2
 
     @ha.callback
-    async def service_handler(call):
+    async def service_handler(call):  # pylint: disable=hass-async-callback-decorator
         runs.append(True)
 
     hass.services.async_register("test_domain", "test_service", service_handler)
-- 
GitLab