diff --git a/homeassistant/components/mqtt/cover.py b/homeassistant/components/mqtt/cover.py index 7423094d2096e76cd96481619d7861d3255e4b3f..39c4090109c7f59906e577ccd9d221c570eb72b4 100644 --- a/homeassistant/components/mqtt/cover.py +++ b/homeassistant/components/mqtt/cover.py @@ -45,9 +45,14 @@ from .const import ( DEFAULT_OPTIMISTIC, ) from .debug_info import log_messages -from .mixins import MQTT_ENTITY_COMMON_SCHEMA, MqttEntity, async_setup_entry_helper +from .mixins import ( + MQTT_ENTITY_COMMON_SCHEMA, + MqttEntity, + async_setup_entry_helper, + write_state_on_attr_change, +) from .models import MqttCommandTemplate, MqttValueTemplate, ReceiveMessage -from .util import get_mqtt_data, valid_publish_topic, valid_subscribe_topic +from .util import valid_publish_topic, valid_subscribe_topic _LOGGER = logging.getLogger(__name__) @@ -349,6 +354,7 @@ class MqttCover(MqttEntity, CoverEntity): @callback @log_messages(self.hass, self.entity_id) + @write_state_on_attr_change(self, {"_attr_current_cover_tilt_position"}) def tilt_message_received(msg: ReceiveMessage) -> None: """Handle tilt updates.""" payload = self._tilt_status_template(msg.payload) @@ -361,6 +367,9 @@ class MqttCover(MqttEntity, CoverEntity): @callback @log_messages(self.hass, self.entity_id) + @write_state_on_attr_change( + self, {"_attr_is_closed", "_attr_is_closing", "_attr_is_opening"} + ) def state_message_received(msg: ReceiveMessage) -> None: """Handle new MQTT state messages.""" payload = self._value_template(msg.payload) @@ -398,10 +407,18 @@ class MqttCover(MqttEntity, CoverEntity): return self._update_state(state) - get_mqtt_data(self.hass).state_write_requests.write_state_request(self) - @callback @log_messages(self.hass, self.entity_id) + @write_state_on_attr_change( + self, + { + "_attr_current_cover_position", + "_attr_current_cover_tilt_position", + "_attr_is_closed", + "_attr_is_closing", + "_attr_is_opening", + }, + ) def position_message_received(msg: ReceiveMessage) -> None: """Handle new MQTT position messages.""" payload: ReceivePayloadType = self._get_position_template(msg.payload) @@ -444,8 +461,6 @@ class MqttCover(MqttEntity, CoverEntity): else STATE_OPEN ) - get_mqtt_data(self.hass).state_write_requests.write_state_request(self) - if self._config.get(CONF_GET_POSITION_TOPIC): topics["get_position_topic"] = { "topic": self._config.get(CONF_GET_POSITION_TOPIC), @@ -721,7 +736,6 @@ class MqttCover(MqttEntity, CoverEntity): ): level = self.find_percentage_in_range(payload) self._attr_current_cover_tilt_position = level - get_mqtt_data(self.hass).state_write_requests.write_state_request(self) else: _LOGGER.warning( "Payload '%s' is out of range, must be between '%s' and '%s' inclusive", diff --git a/tests/components/mqtt/test_cover.py b/tests/components/mqtt/test_cover.py index 2eec5f8374b941017a7ee0c496ad6d62a1d63d73..74dc48f440242c3e8dcd12cda37d552ea0b56350 100644 --- a/tests/components/mqtt/test_cover.py +++ b/tests/components/mqtt/test_cover.py @@ -47,6 +47,7 @@ from homeassistant.const import ( from homeassistant.core import HomeAssistant from .test_common import ( + help_custom_config, help_test_availability_when_connection_lost, help_test_availability_without_topic, help_test_custom_availability_payload, @@ -69,6 +70,7 @@ from .test_common import ( help_test_setting_attribute_via_mqtt_json_message, help_test_setting_attribute_with_template, help_test_setting_blocked_attribute_via_mqtt_json_message, + help_test_skipped_async_ha_write_state, help_test_unique_id, help_test_unload_config_entry_with_platform, help_test_update_with_json_attrs_bad_json, @@ -3666,3 +3668,43 @@ async def test_unload_entry( await help_test_unload_config_entry_with_platform( hass, mqtt_mock_entry, domain, config ) + + +@pytest.mark.parametrize( + "hass_config", + [ + help_custom_config( + cover.DOMAIN, + DEFAULT_CONFIG, + ( + { + "availability_topic": "availability-topic", + "json_attributes_topic": "json-attributes-topic", + "state_topic": "test-topic", + "position_topic": "position-topic", + "tilt_status_topic": "tilt-status-topic", + }, + ), + ) + ], +) +@pytest.mark.parametrize( + ("topic", "payload1", "payload2"), + [ + ("test-topic", "open", "closed"), + ("availability-topic", "online", "offline"), + ("json-attributes-topic", '{"attr1": "val1"}', '{"attr1": "val2"}'), + ("position-topic", "50", "100"), + ("tilt-status-topic", "50", "100"), + ], +) +async def test_skipped_async_ha_write_state( + hass: HomeAssistant, + mqtt_mock_entry: MqttMockHAClientGenerator, + topic: str, + payload1: str, + payload2: str, +) -> None: + """Test a write state command is only called when there is change.""" + await mqtt_mock_entry() + await help_test_skipped_async_ha_write_state(hass, topic, payload1, payload2)