From 5b422daf36c06636e5d5b00f5237b65e3b66fcfd Mon Sep 17 00:00:00 2001
From: Jan Bouwhuis <jbouwh@users.noreply.github.com>
Date: Fri, 22 Sep 2023 13:32:30 +0200
Subject: [PATCH] Avoid redundant calls to `async_write_ha_state` in MQTT light
 (#100690)

* Limit state writes for mqtt light

* Additional tests and review follow up
---
 .../components/mqtt/light/schema_basic.py     | 31 +++---
 .../components/mqtt/light/schema_json.py      | 21 ++++-
 .../components/mqtt/light/schema_template.py  | 16 +++-
 tests/components/mqtt/test_light.py           | 57 +++++++++++
 tests/components/mqtt/test_light_json.py      | 94 +++++++++++++++++++
 tests/components/mqtt/test_light_template.py  | 66 +++++++++++++
 6 files changed, 264 insertions(+), 21 deletions(-)

diff --git a/homeassistant/components/mqtt/light/schema_basic.py b/homeassistant/components/mqtt/light/schema_basic.py
index c12d8719d7a..ab8d9921161 100644
--- a/homeassistant/components/mqtt/light/schema_basic.py
+++ b/homeassistant/components/mqtt/light/schema_basic.py
@@ -55,7 +55,7 @@ from ..const import (
     PAYLOAD_NONE,
 )
 from ..debug_info import log_messages
-from ..mixins import MQTT_ENTITY_COMMON_SCHEMA, MqttEntity
+from ..mixins import MQTT_ENTITY_COMMON_SCHEMA, MqttEntity, write_state_on_attr_change
 from ..models import (
     MessageCallbackType,
     MqttCommandTemplate,
@@ -66,7 +66,7 @@ from ..models import (
     ReceivePayloadType,
     TemplateVarsType,
 )
-from ..util import get_mqtt_data, valid_publish_topic, valid_subscribe_topic
+from ..util import valid_publish_topic, valid_subscribe_topic
 from .schema import MQTT_LIGHT_SCHEMA_SCHEMA
 
 _LOGGER = logging.getLogger(__name__)
@@ -415,6 +415,7 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
 
         @callback
         @log_messages(self.hass, self.entity_id)
+        @write_state_on_attr_change(self, {"_attr_is_on"})
         def state_received(msg: ReceiveMessage) -> None:
             """Handle new MQTT messages."""
             payload = self._value_templates[CONF_STATE_VALUE_TEMPLATE](
@@ -430,7 +431,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
                 self._attr_is_on = False
             elif payload == PAYLOAD_NONE:
                 self._attr_is_on = None
-            get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
 
         if self._topic[CONF_STATE_TOPIC] is not None:
             topics[CONF_STATE_TOPIC] = {
@@ -442,6 +442,7 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
 
         @callback
         @log_messages(self.hass, self.entity_id)
+        @write_state_on_attr_change(self, {"_attr_brightness"})
         def brightness_received(msg: ReceiveMessage) -> None:
             """Handle new MQTT messages for the brightness."""
             payload = self._value_templates[CONF_BRIGHTNESS_VALUE_TEMPLATE](
@@ -459,8 +460,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
             percent_bright = device_value / self._config[CONF_BRIGHTNESS_SCALE]
             self._attr_brightness = min(round(percent_bright * 255), 255)
 
-            get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
-
         add_topic(CONF_BRIGHTNESS_STATE_TOPIC, brightness_received)
 
         @callback
@@ -501,6 +500,9 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
 
         @callback
         @log_messages(self.hass, self.entity_id)
+        @write_state_on_attr_change(
+            self, {"_attr_brightness", "_attr_color_mode", "_attr_rgb_color"}
+        )
         def rgb_received(msg: ReceiveMessage) -> None:
             """Handle new MQTT messages for RGB."""
             rgb = _rgbx_received(
@@ -509,12 +511,14 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
             if rgb is None:
                 return
             self._attr_rgb_color = cast(tuple[int, int, int], rgb)
-            get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
 
         add_topic(CONF_RGB_STATE_TOPIC, rgb_received)
 
         @callback
         @log_messages(self.hass, self.entity_id)
+        @write_state_on_attr_change(
+            self, {"_attr_brightness", "_attr_color_mode", "_attr_rgbw_color"}
+        )
         def rgbw_received(msg: ReceiveMessage) -> None:
             """Handle new MQTT messages for RGBW."""
             rgbw = _rgbx_received(
@@ -526,12 +530,14 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
             if rgbw is None:
                 return
             self._attr_rgbw_color = cast(tuple[int, int, int, int], rgbw)
-            get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
 
         add_topic(CONF_RGBW_STATE_TOPIC, rgbw_received)
 
         @callback
         @log_messages(self.hass, self.entity_id)
+        @write_state_on_attr_change(
+            self, {"_attr_brightness", "_attr_color_mode", "_attr_rgbww_color"}
+        )
         def rgbww_received(msg: ReceiveMessage) -> None:
             """Handle new MQTT messages for RGBWW."""
 
@@ -558,12 +564,12 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
             if rgbww is None:
                 return
             self._attr_rgbww_color = cast(tuple[int, int, int, int, int], rgbww)
-            get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
 
         add_topic(CONF_RGBWW_STATE_TOPIC, rgbww_received)
 
         @callback
         @log_messages(self.hass, self.entity_id)
+        @write_state_on_attr_change(self, {"_attr_color_mode"})
         def color_mode_received(msg: ReceiveMessage) -> None:
             """Handle new MQTT messages for color mode."""
             payload = self._value_templates[CONF_COLOR_MODE_VALUE_TEMPLATE](
@@ -574,12 +580,12 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
                 return
 
             self._attr_color_mode = ColorMode(str(payload))
-            get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
 
         add_topic(CONF_COLOR_MODE_STATE_TOPIC, color_mode_received)
 
         @callback
         @log_messages(self.hass, self.entity_id)
+        @write_state_on_attr_change(self, {"_attr_color_mode", "_attr_color_temp"})
         def color_temp_received(msg: ReceiveMessage) -> None:
             """Handle new MQTT messages for color temperature."""
             payload = self._value_templates[CONF_COLOR_TEMP_VALUE_TEMPLATE](
@@ -592,12 +598,12 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
             if self._optimistic_color_mode:
                 self._attr_color_mode = ColorMode.COLOR_TEMP
             self._attr_color_temp = int(payload)
-            get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
 
         add_topic(CONF_COLOR_TEMP_STATE_TOPIC, color_temp_received)
 
         @callback
         @log_messages(self.hass, self.entity_id)
+        @write_state_on_attr_change(self, {"_attr_effect"})
         def effect_received(msg: ReceiveMessage) -> None:
             """Handle new MQTT messages for effect."""
             payload = self._value_templates[CONF_EFFECT_VALUE_TEMPLATE](
@@ -608,12 +614,12 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
                 return
 
             self._attr_effect = str(payload)
-            get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
 
         add_topic(CONF_EFFECT_STATE_TOPIC, effect_received)
 
         @callback
         @log_messages(self.hass, self.entity_id)
+        @write_state_on_attr_change(self, {"_attr_color_mode", "_attr_hs_color"})
         def hs_received(msg: ReceiveMessage) -> None:
             """Handle new MQTT messages for hs color."""
             payload = self._value_templates[CONF_HS_VALUE_TEMPLATE](
@@ -627,7 +633,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
                 if self._optimistic_color_mode:
                     self._attr_color_mode = ColorMode.HS
                 self._attr_hs_color = cast(tuple[float, float], hs_color)
-                get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
             except ValueError:
                 _LOGGER.warning("Failed to parse hs state update: '%s'", payload)
 
@@ -635,6 +640,7 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
 
         @callback
         @log_messages(self.hass, self.entity_id)
+        @write_state_on_attr_change(self, {"_attr_color_mode", "_attr_xy_color"})
         def xy_received(msg: ReceiveMessage) -> None:
             """Handle new MQTT messages for xy color."""
             payload = self._value_templates[CONF_XY_VALUE_TEMPLATE](
@@ -648,7 +654,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
             if self._optimistic_color_mode:
                 self._attr_color_mode = ColorMode.XY
             self._attr_xy_color = cast(tuple[float, float], xy_color)
-            get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
 
         add_topic(CONF_XY_STATE_TOPIC, xy_received)
 
diff --git a/homeassistant/components/mqtt/light/schema_json.py b/homeassistant/components/mqtt/light/schema_json.py
index 11574b88798..ee7e78b0028 100644
--- a/homeassistant/components/mqtt/light/schema_json.py
+++ b/homeassistant/components/mqtt/light/schema_json.py
@@ -63,9 +63,9 @@ from ..const import (
     CONF_STATE_TOPIC,
 )
 from ..debug_info import log_messages
-from ..mixins import MQTT_ENTITY_COMMON_SCHEMA, MqttEntity
+from ..mixins import MQTT_ENTITY_COMMON_SCHEMA, MqttEntity, write_state_on_attr_change
 from ..models import ReceiveMessage
-from ..util import get_mqtt_data, valid_subscribe_topic
+from ..util import valid_subscribe_topic
 from .schema import MQTT_LIGHT_SCHEMA_SCHEMA
 from .schema_basic import (
     CONF_BRIGHTNESS_SCALE,
@@ -347,6 +347,21 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
 
         @callback
         @log_messages(self.hass, self.entity_id)
+        @write_state_on_attr_change(
+            self,
+            {
+                "_attr_brightness",
+                "_attr_color_temp",
+                "_attr_effect",
+                "_attr_hs_color",
+                "_attr_is_on",
+                "_attr_rgb_color",
+                "_attr_rgbw_color",
+                "_attr_rgbww_color",
+                "_attr_xy_color",
+                "color_mode",
+            },
+        )
         def state_received(msg: ReceiveMessage) -> None:
             """Handle new MQTT messages."""
             values = json_loads_object(msg.payload)
@@ -419,8 +434,6 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
                 with suppress(KeyError):
                     self._attr_effect = cast(str, values["effect"])
 
-            get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
-
         if self._topic[CONF_STATE_TOPIC] is not None:
             self._sub_state = subscription.async_prepare_subscribe_topics(
                 self.hass,
diff --git a/homeassistant/components/mqtt/light/schema_template.py b/homeassistant/components/mqtt/light/schema_template.py
index e811c45fc67..ecbcdcd18d7 100644
--- a/homeassistant/components/mqtt/light/schema_template.py
+++ b/homeassistant/components/mqtt/light/schema_template.py
@@ -46,7 +46,7 @@ from ..const import (
     PAYLOAD_NONE,
 )
 from ..debug_info import log_messages
-from ..mixins import MQTT_ENTITY_COMMON_SCHEMA, MqttEntity
+from ..mixins import MQTT_ENTITY_COMMON_SCHEMA, MqttEntity, write_state_on_attr_change
 from ..models import (
     MqttCommandTemplate,
     MqttValueTemplate,
@@ -54,7 +54,6 @@ from ..models import (
     ReceiveMessage,
     ReceivePayloadType,
 )
-from ..util import get_mqtt_data
 from .schema import MQTT_LIGHT_SCHEMA_SCHEMA
 from .schema_basic import MQTT_LIGHT_ATTRIBUTES_BLOCKED
 
@@ -215,6 +214,17 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity):
 
         @callback
         @log_messages(self.hass, self.entity_id)
+        @write_state_on_attr_change(
+            self,
+            {
+                "_attr_brightness",
+                "_attr_color_mode",
+                "_attr_color_temp",
+                "_attr_effect",
+                "_attr_hs_color",
+                "_attr_is_on",
+            },
+        )
         def state_received(msg: ReceiveMessage) -> None:
             """Handle new MQTT messages."""
             state = self._value_templates[CONF_STATE_TEMPLATE](msg.payload)
@@ -283,8 +293,6 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity):
                 else:
                     _LOGGER.warning("Unsupported effect value received")
 
-            get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
-
         if self._topics[CONF_STATE_TOPIC] is not None:
             self._sub_state = subscription.async_prepare_subscribe_topics(
                 self.hass,
diff --git a/tests/components/mqtt/test_light.py b/tests/components/mqtt/test_light.py
index 0199ee19772..58d37943403 100644
--- a/tests/components/mqtt/test_light.py
+++ b/tests/components/mqtt/test_light.py
@@ -221,6 +221,7 @@ from .test_common import (
     help_test_setting_attribute_via_mqtt_json_message,
     help_test_setting_attribute_with_template,
     help_test_setting_blocked_attribute_via_mqtt_json_message,
+    help_test_skipped_async_ha_write_state,
     help_test_unique_id,
     help_test_unload_config_entry_with_platform,
     help_test_update_with_json_attrs_bad_json,
@@ -3635,3 +3636,59 @@ async def test_unload_entry(
     await help_test_unload_config_entry_with_platform(
         hass, mqtt_mock_entry, domain, config
     )
+
+
+@pytest.mark.parametrize(
+    "hass_config",
+    [
+        help_custom_config(
+            light.DOMAIN,
+            DEFAULT_CONFIG,
+            (
+                {
+                    "availability_topic": "availability-topic",
+                    "json_attributes_topic": "json-attributes-topic",
+                    "state_topic": "test-topic",
+                    "state_value_template": "{{ value_json.state }}",
+                    "brightness_state_topic": "brightness-state-topic",
+                    "color_mode_state_topic": "color-mode-state-topic",
+                    "color_temp_state_topic": "color-temp-state-topic",
+                    "effect_state_topic": "effect-state-topic",
+                    "effect_list": ["effect1", "effect2"],
+                    "hs_state_topic": "hs-state-topic",
+                    "xy_state_topic": "xy-state-topic",
+                    "rgb_state_topic": "rgb-state-topic",
+                    "rgbw_state_topic": "rgbw-state-topic",
+                    "rgbww_state_topic": "rgbww-state-topic",
+                },
+            ),
+        )
+    ],
+)
+@pytest.mark.parametrize(
+    ("topic", "payload1", "payload2"),
+    [
+        ("test-topic", '{"state":"ON"}', '{"state":"OFF"}'),
+        ("availability-topic", "online", "offline"),
+        ("json-attributes-topic", '{"attr1": "val1"}', '{"attr1": "val2"}'),
+        ("brightness-state-topic", "50", "100"),
+        ("color-mode-state-topic", "rgb", "color_temp"),
+        ("color-temp-state-topic", "800", "200"),
+        ("effect-state-topic", "effect1", "effect2"),
+        ("hs-state-topic", "210,50", "200,50"),
+        ("xy-state-topic", "128,128", "96,96"),
+        ("rgb-state-topic", "128,128,128", "128,128,64"),
+        ("rgbw-state-topic", "128,128,128,255", "128,128,128,128"),
+        ("rgbww-state-topic", "128,128,128,32,255", "128,128,128,64,255"),
+    ],
+)
+async def test_skipped_async_ha_write_state(
+    hass: HomeAssistant,
+    mqtt_mock_entry: MqttMockHAClientGenerator,
+    topic: str,
+    payload1: str,
+    payload2: str,
+) -> None:
+    """Test a write state command is only called when there is change."""
+    await mqtt_mock_entry()
+    await help_test_skipped_async_ha_write_state(hass, topic, payload1, payload2)
diff --git a/tests/components/mqtt/test_light_json.py b/tests/components/mqtt/test_light_json.py
index 7ff4ccbab85..3b44f86460f 100644
--- a/tests/components/mqtt/test_light_json.py
+++ b/tests/components/mqtt/test_light_json.py
@@ -124,6 +124,7 @@ from .test_common import (
     help_test_setting_attribute_via_mqtt_json_message,
     help_test_setting_attribute_with_template,
     help_test_setting_blocked_attribute_via_mqtt_json_message,
+    help_test_skipped_async_ha_write_state,
     help_test_unique_id,
     help_test_update_with_json_attrs_bad_json,
     help_test_update_with_json_attrs_not_dict,
@@ -2453,3 +2454,96 @@ async def test_setup_manual_entity_from_yaml(
     await mqtt_mock_entry()
     platform = light.DOMAIN
     assert hass.states.get(f"{platform}.test")
+
+
+@pytest.mark.parametrize(
+    "hass_config",
+    [
+        help_custom_config(
+            light.DOMAIN,
+            DEFAULT_CONFIG,
+            (
+                {
+                    "color_mode": True,
+                    "effect": True,
+                    "supported_color_modes": [
+                        "color_temp",
+                        "hs",
+                        "xy",
+                        "rgb",
+                        "rgbw",
+                        "rgbww",
+                        "white",
+                    ],
+                    "effect_list": ["effect1", "effect2"],
+                    "availability_topic": "availability-topic",
+                    "json_attributes_topic": "json-attributes-topic",
+                    "state_topic": "test-topic",
+                },
+            ),
+        )
+    ],
+)
+@pytest.mark.parametrize(
+    ("topic", "payload1", "payload2"),
+    [
+        ("test-topic", '{"state":"ON"}', '{"state":"OFF"}'),
+        ("availability-topic", "online", "offline"),
+        ("json-attributes-topic", '{"attr1": "val1"}', '{"attr1": "val2"}'),
+        (
+            "test-topic",
+            '{"state":"ON","effect":"effect1"}',
+            '{"state":"ON","effect":"effect2"}',
+        ),
+        (
+            "test-topic",
+            '{"state":"ON","brightness":255}',
+            '{"state":"ON","brightness":96}',
+        ),
+        (
+            "test-topic",
+            '{"state":"ON","brightness":96}',
+            '{"state":"ON","color_mode":"white","brightness":96}',
+        ),
+        (
+            "test-topic",
+            '{"state":"ON","color_mode":"color_temp", "color_temp": 200}',
+            '{"state":"ON","color_mode":"color_temp", "color_temp": 2400}',
+        ),
+        (
+            "test-topic",
+            '{"state":"ON","color_mode":"hs", "color": {"h":24.0,"s":100.0}}',
+            '{"state":"ON","color_mode":"hs", "color": {"h":24.0,"s":90.0}}',
+        ),
+        (
+            "test-topic",
+            '{"state":"ON","color_mode":"xy","color": {"x":0.14,"y":0.131}}',
+            '{"state":"ON","color_mode":"xy","color": {"x":0.16,"y": 0.100}}',
+        ),
+        (
+            "test-topic",
+            '{"state":"ON","brightness":255,"color_mode":"rgb","color":{"r":128,"g":128,"b":255}}',
+            '{"state":"ON","brightness":255,"color_mode":"rgb","color": {"r":255,"g":128,"b":255}}',
+        ),
+        (
+            "test-topic",
+            '{"state":"ON","color_mode":"rgbw","color":{"r":128,"g":128,"b":255,"w":128}}',
+            '{"state":"ON","color_mode":"rgbw","color": {"r":128,"g":128,"b":255,"w":255}}',
+        ),
+        (
+            "test-topic",
+            '{"state":"ON","color_mode":"rgbww","color":{"r":128,"g":128,"b":255,"c":32,"w":128}}',
+            '{"state":"ON","color_mode":"rgbww","color": {"r":128,"g":128,"b":255,"c":16,"w":128}}',
+        ),
+    ],
+)
+async def test_skipped_async_ha_write_state(
+    hass: HomeAssistant,
+    mqtt_mock_entry: MqttMockHAClientGenerator,
+    topic: str,
+    payload1: str,
+    payload2: str,
+) -> None:
+    """Test a write state command is only called when there is change."""
+    await mqtt_mock_entry()
+    await help_test_skipped_async_ha_write_state(hass, topic, payload1, payload2)
diff --git a/tests/components/mqtt/test_light_template.py b/tests/components/mqtt/test_light_template.py
index 0583a1176b6..f9f355025e9 100644
--- a/tests/components/mqtt/test_light_template.py
+++ b/tests/components/mqtt/test_light_template.py
@@ -46,6 +46,7 @@ from homeassistant.const import (
 from homeassistant.core import HomeAssistant, State
 
 from .test_common import (
+    help_custom_config,
     help_test_availability_when_connection_lost,
     help_test_availability_without_topic,
     help_test_custom_availability_payload,
@@ -68,6 +69,7 @@ from .test_common import (
     help_test_setting_attribute_via_mqtt_json_message,
     help_test_setting_attribute_with_template,
     help_test_setting_blocked_attribute_via_mqtt_json_message,
+    help_test_skipped_async_ha_write_state,
     help_test_unique_id,
     help_test_unload_config_entry_with_platform,
     help_test_update_with_json_attrs_bad_json,
@@ -1378,3 +1380,67 @@ async def test_unload_entry(
     await help_test_unload_config_entry_with_platform(
         hass, mqtt_mock_entry, domain, config
     )
+
+
+@pytest.mark.parametrize(
+    "hass_config",
+    [
+        help_custom_config(
+            light.DOMAIN,
+            DEFAULT_CONFIG,
+            (
+                {
+                    "availability_topic": "availability-topic",
+                    "json_attributes_topic": "json-attributes-topic",
+                    "state_topic": "test-topic",
+                    "state_template": "{{ value_json.state }}",
+                    "brightness_template": "{{ value_json.brightness }}",
+                    "color_temp_template": "{{ value_json.color_temp }}",
+                    "effect_template": "{{ value_json.effect }}",
+                    "red_template": "{{ value_json.r }}",
+                    "green_template": "{{ value_json.g }}",
+                    "blue_template": "{{ value_json.b }}",
+                    "effect_list": ["effect1", "effect2"],
+                },
+            ),
+        )
+    ],
+)
+@pytest.mark.parametrize(
+    ("topic", "payload1", "payload2"),
+    [
+        ("test-topic", '{"state":"on"}', '{"state":"off"}'),
+        ("availability-topic", "online", "offline"),
+        ("json-attributes-topic", '{"attr1": "val1"}', '{"attr1": "val2"}'),
+        (
+            "test-topic",
+            '{"state":"on", "brightness":50}',
+            '{"state":"on", "brightness":100}',
+        ),
+        (
+            "test-topic",
+            '{"state":"on", "brightness":50,"color_temp":200}',
+            '{"state":"on", "brightness":50,"color_temp":1600}',
+        ),
+        (
+            "test-topic",
+            '{"state":"on", "r":128, "g":128, "b":128}',
+            '{"state":"on", "r":128, "g":128, "b":255}',
+        ),
+        (
+            "test-topic",
+            '{"state":"on", "effect":"effect1"}',
+            '{"state":"on", "effect":"effect2"}',
+        ),
+    ],
+)
+async def test_skipped_async_ha_write_state(
+    hass: HomeAssistant,
+    mqtt_mock_entry: MqttMockHAClientGenerator,
+    topic: str,
+    payload1: str,
+    payload2: str,
+) -> None:
+    """Test a write state command is only called when there is change."""
+    await mqtt_mock_entry()
+    await help_test_skipped_async_ha_write_state(hass, topic, payload1, payload2)
-- 
GitLab