From b3403d7fca602603fe92c862b852ad25fcc64788 Mon Sep 17 00:00:00 2001
From: Jan Bouwhuis <jbouwh@users.noreply.github.com>
Date: Thu, 3 Nov 2022 13:06:53 +0100
Subject: [PATCH] Improve MQTT type hints part 3 (#80542)

* Improve typing debug_info

* Improve typing device_automation

* Improve typing device_trigger

* Improve typing fan

* Additional type hints device_trigger

* Set fan type hints to class level

* Cleanup and mypy

* Follow up and missed hint

* Follow up comment
---
 homeassistant/components/mqtt/debug_info.py   |   2 +-
 .../components/mqtt/device_automation.py      |  16 ++-
 .../components/mqtt/device_trigger.py         |  21 ++--
 homeassistant/components/mqtt/fan.py          | 109 ++++++++++--------
 4 files changed, 90 insertions(+), 58 deletions(-)

diff --git a/homeassistant/components/mqtt/debug_info.py b/homeassistant/components/mqtt/debug_info.py
index 5fae98eaea5..bdbdd74de96 100644
--- a/homeassistant/components/mqtt/debug_info.py
+++ b/homeassistant/components/mqtt/debug_info.py
@@ -28,7 +28,7 @@ def log_messages(
 
     debug_info_entities = get_mqtt_data(hass).debug_info_entities
 
-    def _log_message(msg):
+    def _log_message(msg: Any) -> None:
         """Log message."""
         messages = debug_info_entities[entity_id]["subscriptions"][
             msg.subscribed_topic
diff --git a/homeassistant/components/mqtt/device_automation.py b/homeassistant/components/mqtt/device_automation.py
index 0646a5bda0c..a1bc2cdeb3e 100644
--- a/homeassistant/components/mqtt/device_automation.py
+++ b/homeassistant/components/mqtt/device_automation.py
@@ -1,9 +1,14 @@
 """Provides device automations for MQTT."""
+from __future__ import annotations
+
 import functools
 
 import voluptuous as vol
 
+from homeassistant.config_entries import ConfigEntry
+from homeassistant.core import HomeAssistant
 import homeassistant.helpers.config_validation as cv
+from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
 
 from . import device_trigger
 from .config import MQTT_BASE_SCHEMA
@@ -20,14 +25,19 @@ PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend(
 ).extend(MQTT_BASE_SCHEMA.schema)
 
 
-async def async_setup_entry(hass, config_entry):
+async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> None:
     """Set up MQTT device automation dynamically through MQTT discovery."""
 
     setup = functools.partial(_async_setup_automation, hass, config_entry=config_entry)
     await async_setup_entry_helper(hass, "device_automation", setup, PLATFORM_SCHEMA)
 
 
-async def _async_setup_automation(hass, config, config_entry, discovery_data):
+async def _async_setup_automation(
+    hass: HomeAssistant,
+    config: ConfigType,
+    config_entry: ConfigEntry,
+    discovery_data: DiscoveryInfoType,
+) -> None:
     """Set up an MQTT device automation."""
     if config[CONF_AUTOMATION_TYPE] == AUTOMATION_TYPE_TRIGGER:
         await device_trigger.async_setup_trigger(
@@ -35,6 +45,6 @@ async def _async_setup_automation(hass, config, config_entry, discovery_data):
         )
 
 
-async def async_removed_from_device(hass, device_id):
+async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None:
     """Handle Mqtt removed from a device."""
     await device_trigger.async_removed_from_device(hass, device_id)
diff --git a/homeassistant/components/mqtt/device_trigger.py b/homeassistant/components/mqtt/device_trigger.py
index f51731284cc..e8bcf1abc48 100644
--- a/homeassistant/components/mqtt/device_trigger.py
+++ b/homeassistant/components/mqtt/device_trigger.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 
 from collections.abc import Callable
 import logging
-from typing import cast
+from typing import Any, cast
 
 import attr
 import voluptuous as vol
@@ -23,7 +23,7 @@ from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers import config_validation as cv
 from homeassistant.helpers.dispatcher import async_dispatcher_send
 from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
-from homeassistant.helpers.typing import ConfigType
+from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
 
 from . import debug_info, trigger as mqtt_trigger
 from .config import MQTT_BASE_SCHEMA
@@ -35,7 +35,7 @@ from .const import (
     CONF_TOPIC,
     DOMAIN,
 )
-from .discovery import MQTT_DISCOVERY_DONE
+from .discovery import MQTT_DISCOVERY_DONE, MQTTDiscoveryPayload
 from .mixins import (
     MQTT_ENTITY_DEVICE_INFO_SCHEMA,
     MqttDiscoveryDeviceUpdate,
@@ -96,7 +96,7 @@ class TriggerInstance:
 
     async def async_attach_trigger(self) -> None:
         """Attach MQTT trigger."""
-        mqtt_config = {
+        mqtt_config: dict[str, Any] = {
             CONF_PLATFORM: DOMAIN,
             CONF_TOPIC: self.trigger.topic,
             CONF_ENCODING: DEFAULT_ENCODING,
@@ -123,7 +123,7 @@ class Trigger:
     """Device trigger settings."""
 
     device_id: str = attr.ib()
-    discovery_data: dict | None = attr.ib()
+    discovery_data: DiscoveryInfoType | None = attr.ib()
     hass: HomeAssistant = attr.ib()
     payload: str | None = attr.ib()
     qos: int | None = attr.ib()
@@ -193,7 +193,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
         hass: HomeAssistant,
         config: ConfigType,
         device_id: str,
-        discovery_data: dict,
+        discovery_data: DiscoveryInfoType,
         config_entry: ConfigEntry,
     ) -> None:
         """Initialize."""
@@ -237,7 +237,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
             self.hass, discovery_hash, self.discovery_data, self.device_id
         )
 
-    async def async_update(self, discovery_data: dict) -> None:
+    async def async_update(self, discovery_data: MQTTDiscoveryPayload) -> None:
         """Handle MQTT device trigger discovery updates."""
         discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
         discovery_id = discovery_hash[1]
@@ -261,11 +261,14 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
 
 
 async def async_setup_trigger(
-    hass, config: ConfigType, config_entry: ConfigEntry, discovery_data: dict
+    hass: HomeAssistant,
+    config: ConfigType,
+    config_entry: ConfigEntry,
+    discovery_data: DiscoveryInfoType,
 ) -> None:
     """Set up the MQTT device trigger."""
     config = TRIGGER_DISCOVERY_SCHEMA(config)
-    discovery_hash = discovery_data[ATTR_DISCOVERY_HASH]
+    discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH]
 
     if (device_id := update_device(hass, config_entry, config)) is None:
         async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None)
diff --git a/homeassistant/components/mqtt/fan.py b/homeassistant/components/mqtt/fan.py
index 1b6c3e425a4..8a65b909eb8 100644
--- a/homeassistant/components/mqtt/fan.py
+++ b/homeassistant/components/mqtt/fan.py
@@ -1,6 +1,7 @@
 """Support for MQTT fans."""
 from __future__ import annotations
 
+from collections.abc import Callable
 import functools
 import logging
 import math
@@ -27,6 +28,7 @@ from homeassistant.const import (
 from homeassistant.core import HomeAssistant, callback
 import homeassistant.helpers.config_validation as cv
 from homeassistant.helpers.entity_platform import AddEntitiesCallback
+from homeassistant.helpers.template import Template
 from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
 from homeassistant.util.percentage import (
     int_states_in_range,
@@ -54,7 +56,13 @@ from .mixins import (
     async_setup_platform_helper,
     warn_for_legacy_schema,
 )
-from .models import MqttCommandTemplate, MqttValueTemplate
+from .models import (
+    MqttCommandTemplate,
+    MqttValueTemplate,
+    PublishPayloadType,
+    ReceiveMessage,
+    ReceivePayloadType,
+)
 from .util import get_mqtt_data, valid_publish_topic, valid_subscribe_topic
 
 CONF_PERCENTAGE_STATE_TOPIC = "percentage_state_topic"
@@ -110,18 +118,18 @@ MQTT_FAN_ATTRIBUTES_BLOCKED = frozenset(
 _LOGGER = logging.getLogger(__name__)
 
 
-def valid_speed_range_configuration(config):
+def valid_speed_range_configuration(config: ConfigType) -> ConfigType:
     """Validate that the fan speed_range configuration is valid, throws if it isn't."""
-    if config.get(CONF_SPEED_RANGE_MIN) == 0:
+    if config[CONF_SPEED_RANGE_MIN] == 0:
         raise ValueError("speed_range_min must be > 0")
-    if config.get(CONF_SPEED_RANGE_MIN) >= config.get(CONF_SPEED_RANGE_MAX):
+    if config[CONF_SPEED_RANGE_MIN] >= config[CONF_SPEED_RANGE_MAX]:
         raise ValueError("speed_range_max must be > speed_range_min")
     return config
 
 
-def valid_preset_mode_configuration(config):
+def valid_preset_mode_configuration(config: ConfigType) -> ConfigType:
     """Validate that the preset mode reset payload is not one of the preset modes."""
-    if config.get(CONF_PAYLOAD_RESET_PRESET_MODE) in config.get(CONF_PRESET_MODES_LIST):
+    if config[CONF_PAYLOAD_RESET_PRESET_MODE] in config[CONF_PRESET_MODES_LIST]:
         raise ValueError("preset_modes must not contain payload_reset_preset_mode")
     return config
 
@@ -250,8 +258,8 @@ async def _async_setup_entity(
     hass: HomeAssistant,
     async_add_entities: AddEntitiesCallback,
     config: ConfigType,
-    config_entry: ConfigEntry | None = None,
-    discovery_data: dict | None = None,
+    config_entry: ConfigEntry,
+    discovery_data: DiscoveryInfoType | None = None,
 ) -> None:
     """Set up the MQTT fan."""
     async_add_entities([MqttFan(hass, config, config_entry, discovery_data)])
@@ -263,32 +271,41 @@ class MqttFan(MqttEntity, FanEntity):
     _entity_id_format = fan.ENTITY_ID_FORMAT
     _attributes_extra_blocked = MQTT_FAN_ATTRIBUTES_BLOCKED
 
-    def __init__(self, hass, config, config_entry, discovery_data):
+    _command_templates: dict[str, Callable[[PublishPayloadType], PublishPayloadType]]
+    _value_templates: dict[str, Callable[[ReceivePayloadType], ReceivePayloadType]]
+    _feature_percentage: bool
+    _feature_preset_mode: bool
+    _topic: dict[str, Any]
+    _optimistic: bool
+    _optimistic_oscillation: bool
+    _optimistic_percentage: bool
+    _optimistic_preset_mode: bool
+    _payload: dict[str, Any]
+    _speed_range: tuple[int, int]
+
+    def __init__(
+        self,
+        hass: HomeAssistant,
+        config: ConfigType,
+        config_entry: ConfigEntry,
+        discovery_data: DiscoveryInfoType | None,
+    ) -> None:
         """Initialize the MQTT fan."""
         self._attr_percentage = None
         self._attr_preset_mode = None
 
-        self._topic = None
-        self._payload = None
-        self._value_templates = None
-        self._command_templates = None
-        self._optimistic = None
-        self._optimistic_oscillation = None
-        self._optimistic_percentage = None
-        self._optimistic_preset_mode = None
-
         MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
 
     @staticmethod
-    def config_schema():
+    def config_schema() -> vol.Schema:
         """Return the config schema."""
         return DISCOVERY_SCHEMA
 
-    def _setup_from_config(self, config):
+    def _setup_from_config(self, config: ConfigType) -> None:
         """(Re)Setup the entity."""
         self._speed_range = (
-            config.get(CONF_SPEED_RANGE_MIN),
-            config.get(CONF_SPEED_RANGE_MAX),
+            config[CONF_SPEED_RANGE_MIN],
+            config[CONF_SPEED_RANGE_MAX],
         )
         self._topic = {
             key: config.get(key)
@@ -303,18 +320,6 @@ class MqttFan(MqttEntity, FanEntity):
                 CONF_OSCILLATION_COMMAND_TOPIC,
             )
         }
-        self._value_templates = {
-            CONF_STATE: config.get(CONF_STATE_VALUE_TEMPLATE),
-            ATTR_PERCENTAGE: config.get(CONF_PERCENTAGE_VALUE_TEMPLATE),
-            ATTR_PRESET_MODE: config.get(CONF_PRESET_MODE_VALUE_TEMPLATE),
-            ATTR_OSCILLATING: config.get(CONF_OSCILLATION_VALUE_TEMPLATE),
-        }
-        self._command_templates = {
-            CONF_STATE: config.get(CONF_COMMAND_TEMPLATE),
-            ATTR_PERCENTAGE: config.get(CONF_PERCENTAGE_COMMAND_TEMPLATE),
-            ATTR_PRESET_MODE: config.get(CONF_PRESET_MODE_COMMAND_TEMPLATE),
-            ATTR_OSCILLATING: config.get(CONF_OSCILLATION_COMMAND_TEMPLATE),
-        }
         self._payload = {
             "STATE_ON": config[CONF_PAYLOAD_ON],
             "STATE_OFF": config[CONF_PAYLOAD_OFF],
@@ -359,24 +364,38 @@ class MqttFan(MqttEntity, FanEntity):
         if self._feature_preset_mode:
             self._attr_supported_features |= FanEntityFeature.PRESET_MODE
 
-        for key, tpl in self._command_templates.items():
+        command_templates: dict[str, Template | None] = {
+            CONF_STATE: config.get(CONF_COMMAND_TEMPLATE),
+            ATTR_PERCENTAGE: config.get(CONF_PERCENTAGE_COMMAND_TEMPLATE),
+            ATTR_PRESET_MODE: config.get(CONF_PRESET_MODE_COMMAND_TEMPLATE),
+            ATTR_OSCILLATING: config.get(CONF_OSCILLATION_COMMAND_TEMPLATE),
+        }
+        self._command_templates = {}
+        for key, tpl in command_templates.items():
             self._command_templates[key] = MqttCommandTemplate(
                 tpl, entity=self
             ).async_render
 
-        for key, tpl in self._value_templates.items():
+        self._value_templates = {}
+        value_templates: dict[str, Template | None] = {
+            CONF_STATE: config.get(CONF_STATE_VALUE_TEMPLATE),
+            ATTR_PERCENTAGE: config.get(CONF_PERCENTAGE_VALUE_TEMPLATE),
+            ATTR_PRESET_MODE: config.get(CONF_PRESET_MODE_VALUE_TEMPLATE),
+            ATTR_OSCILLATING: config.get(CONF_OSCILLATION_VALUE_TEMPLATE),
+        }
+        for key, tpl in value_templates.items():
             self._value_templates[key] = MqttValueTemplate(
                 tpl,
                 entity=self,
             ).async_render_with_possible_json_value
 
-    def _prepare_subscribe_topics(self):
+    def _prepare_subscribe_topics(self) -> None:
         """(Re)Subscribe to topics."""
-        topics = {}
+        topics: dict[str, Any] = {}
 
         @callback
         @log_messages(self.hass, self.entity_id)
-        def state_received(msg):
+        def state_received(msg: ReceiveMessage) -> None:
             """Handle new received MQTT message."""
             payload = self._value_templates[CONF_STATE](msg.payload)
             if not payload:
@@ -400,7 +419,7 @@ class MqttFan(MqttEntity, FanEntity):
 
         @callback
         @log_messages(self.hass, self.entity_id)
-        def percentage_received(msg):
+        def percentage_received(msg: ReceiveMessage) -> None:
             """Handle new received MQTT message for the percentage."""
             rendered_percentage_payload = self._value_templates[ATTR_PERCENTAGE](
                 msg.payload
@@ -446,9 +465,9 @@ class MqttFan(MqttEntity, FanEntity):
 
         @callback
         @log_messages(self.hass, self.entity_id)
-        def preset_mode_received(msg):
+        def preset_mode_received(msg: ReceiveMessage) -> None:
             """Handle new received MQTT message for preset mode."""
-            preset_mode = self._value_templates[ATTR_PRESET_MODE](msg.payload)
+            preset_mode = str(self._value_templates[ATTR_PRESET_MODE](msg.payload))
             if preset_mode == self._payload["PRESET_MODE_RESET"]:
                 self._attr_preset_mode = None
                 self.async_write_ha_state()
@@ -456,7 +475,7 @@ class MqttFan(MqttEntity, FanEntity):
             if not preset_mode:
                 _LOGGER.debug("Ignoring empty preset_mode from '%s'", msg.topic)
                 return
-            if preset_mode not in self.preset_modes:
+            if not self.preset_modes or preset_mode not in self.preset_modes:
                 _LOGGER.warning(
                     "'%s' received on topic %s. '%s' is not a valid preset mode",
                     msg.payload,
@@ -479,7 +498,7 @@ class MqttFan(MqttEntity, FanEntity):
 
         @callback
         @log_messages(self.hass, self.entity_id)
-        def oscillation_received(msg):
+        def oscillation_received(msg: ReceiveMessage) -> None:
             """Handle new received MQTT message for the oscillation."""
             payload = self._value_templates[ATTR_OSCILLATING](msg.payload)
             if not payload:
@@ -504,7 +523,7 @@ class MqttFan(MqttEntity, FanEntity):
             self.hass, self._sub_state, topics
         )
 
-    async def _subscribe_topics(self):
+    async def _subscribe_topics(self) -> None:
         """(Re)Subscribe to topics."""
         await subscription.async_subscribe_topics(self.hass, self._sub_state)
 
-- 
GitLab