From 2abcd7cd947e487fbeb1817fb5e6653617e03069 Mon Sep 17 00:00:00 2001
From: Erik Montnemery <erik@montnemery.com>
Date: Fri, 18 Feb 2022 11:35:44 +0100
Subject: [PATCH] Correct state restoring for MQTT temperature sensors (#66741)

* Correct state restoring for MQTT temperature sensors

* Adjust test

* Adjust test
---
 homeassistant/components/mqtt/sensor.py |  9 +++----
 tests/components/mqtt/test_sensor.py    | 32 +++++++++++++++++++------
 2 files changed, 30 insertions(+), 11 deletions(-)

diff --git a/homeassistant/components/mqtt/sensor.py b/homeassistant/components/mqtt/sensor.py
index 31a784a259e..c24535ebd1f 100644
--- a/homeassistant/components/mqtt/sensor.py
+++ b/homeassistant/components/mqtt/sensor.py
@@ -13,8 +13,8 @@ from homeassistant.components.sensor import (
     DEVICE_CLASSES_SCHEMA,
     ENTITY_ID_FORMAT,
     STATE_CLASSES_SCHEMA,
+    RestoreSensor,
     SensorDeviceClass,
-    SensorEntity,
 )
 from homeassistant.config_entries import ConfigEntry
 from homeassistant.const import (
@@ -30,7 +30,6 @@ from homeassistant.core import HomeAssistant, callback
 import homeassistant.helpers.config_validation as cv
 from homeassistant.helpers.entity_platform import AddEntitiesCallback
 from homeassistant.helpers.event import async_track_point_in_utc_time
-from homeassistant.helpers.restore_state import RestoreEntity
 from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
 from homeassistant.util import dt as dt_util
 
@@ -144,7 +143,7 @@ async def _async_setup_entity(
     async_add_entities([MqttSensor(hass, config, config_entry, discovery_data)])
 
 
-class MqttSensor(MqttEntity, SensorEntity, RestoreEntity):
+class MqttSensor(MqttEntity, RestoreSensor):
     """Representation of a sensor that can be updated using MQTT."""
 
     _entity_id_format = ENTITY_ID_FORMAT
@@ -172,6 +171,8 @@ class MqttSensor(MqttEntity, SensorEntity, RestoreEntity):
             and expire_after > 0
             and (last_state := await self.async_get_last_state()) is not None
             and last_state.state not in [STATE_UNKNOWN, STATE_UNAVAILABLE]
+            and (last_sensor_data := await self.async_get_last_sensor_data())
+            is not None
             # We might have set up a trigger already after subscribing from
             # super().async_added_to_hass(), then we should not restore state
             and not self._expiration_trigger
@@ -182,7 +183,7 @@ class MqttSensor(MqttEntity, SensorEntity, RestoreEntity):
                 _LOGGER.debug("Skip state recovery after reload for %s", self.entity_id)
                 return
             self._expired = False
-            self._state = last_state.state
+            self._state = last_sensor_data.native_value
 
             self._expiration_trigger = async_track_point_in_utc_time(
                 self.hass, self._value_is_expired, expiration_at
diff --git a/tests/components/mqtt/test_sensor.py b/tests/components/mqtt/test_sensor.py
index 8a1be6b11e2..b653e04c82e 100644
--- a/tests/components/mqtt/test_sensor.py
+++ b/tests/components/mqtt/test_sensor.py
@@ -2,13 +2,19 @@
 import copy
 from datetime import datetime, timedelta
 import json
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import pytest
 
 from homeassistant.components.mqtt.sensor import MQTT_SENSOR_ATTRIBUTES_BLOCKED
 import homeassistant.components.sensor as sensor
-from homeassistant.const import EVENT_STATE_CHANGED, STATE_UNAVAILABLE, STATE_UNKNOWN
+from homeassistant.const import (
+    EVENT_STATE_CHANGED,
+    STATE_UNAVAILABLE,
+    STATE_UNKNOWN,
+    TEMP_CELSIUS,
+    TEMP_FAHRENHEIT,
+)
 import homeassistant.core as ha
 from homeassistant.helpers import device_registry as dr
 from homeassistant.setup import async_setup_component
@@ -989,10 +995,15 @@ async def test_cleanup_triggers_and_restoring_state(
     config1["name"] = "test1"
     config1["expire_after"] = 30
     config1["state_topic"] = "test-topic1"
+    config1["device_class"] = "temperature"
+    config1["unit_of_measurement"] = TEMP_FAHRENHEIT
+
     config2 = copy.deepcopy(DEFAULT_CONFIG[domain])
     config2["name"] = "test2"
     config2["expire_after"] = 5
     config2["state_topic"] = "test-topic2"
+    config2["device_class"] = "temperature"
+    config2["unit_of_measurement"] = TEMP_CELSIUS
 
     freezer.move_to("2022-02-02 12:01:00+01:00")
 
@@ -1004,7 +1015,7 @@ async def test_cleanup_triggers_and_restoring_state(
     await hass.async_block_till_done()
     async_fire_mqtt_message(hass, "test-topic1", "100")
     state = hass.states.get("sensor.test1")
-    assert state.state == "100"
+    assert state.state == "38"  # 100 °F -> 38 °C
 
     async_fire_mqtt_message(hass, "test-topic2", "200")
     state = hass.states.get("sensor.test2")
@@ -1026,14 +1037,14 @@ async def test_cleanup_triggers_and_restoring_state(
     assert "State recovered after reload for sensor.test2" not in caplog.text
 
     state = hass.states.get("sensor.test1")
-    assert state.state == "100"
+    assert state.state == "38"  # 100 °F -> 38 °C
 
     state = hass.states.get("sensor.test2")
     assert state.state == STATE_UNAVAILABLE
 
-    async_fire_mqtt_message(hass, "test-topic1", "101")
+    async_fire_mqtt_message(hass, "test-topic1", "80")
     state = hass.states.get("sensor.test1")
-    assert state.state == "101"
+    assert state.state == "27"  # 80 °F -> 27 °C
 
     async_fire_mqtt_message(hass, "test-topic2", "201")
     state = hass.states.get("sensor.test2")
@@ -1057,10 +1068,16 @@ async def test_skip_restoring_state_with_over_due_expire_trigger(
         {},
         last_changed=datetime.fromisoformat("2022-02-02 12:01:35+01:00"),
     )
+    fake_extra_data = MagicMock()
     with patch(
         "homeassistant.helpers.restore_state.RestoreEntity.async_get_last_state",
         return_value=fake_state,
-    ), assert_setup_component(1, domain):
+    ), patch(
+        "homeassistant.helpers.restore_state.RestoreEntity.async_get_last_extra_data",
+        return_value=fake_extra_data,
+    ), assert_setup_component(
+        1, domain
+    ):
         assert await async_setup_component(hass, domain, {domain: config3})
         await hass.async_block_till_done()
     assert "Skip state recovery after reload for sensor.test3" in caplog.text
@@ -1087,4 +1104,5 @@ async def test_encoding_subscribable_topics(
         value,
         attribute,
         attribute_value,
+        skip_raw_test=True,
     )
-- 
GitLab