From b3403d7fca602603fe92c862b852ad25fcc64788 Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis <jbouwh@users.noreply.github.com> Date: Thu, 3 Nov 2022 13:06:53 +0100 Subject: [PATCH] Improve MQTT type hints part 3 (#80542) * Improve typing debug_info * Improve typing device_automation * Improve typing device_trigger * Improve typing fan * Additional type hints device_trigger * Set fan type hints to class level * Cleanup and mypy * Follow up and missed hint * Follow up comment --- homeassistant/components/mqtt/debug_info.py | 2 +- .../components/mqtt/device_automation.py | 16 ++- .../components/mqtt/device_trigger.py | 21 ++-- homeassistant/components/mqtt/fan.py | 109 ++++++++++-------- 4 files changed, 90 insertions(+), 58 deletions(-) diff --git a/homeassistant/components/mqtt/debug_info.py b/homeassistant/components/mqtt/debug_info.py index 5fae98eaea5..bdbdd74de96 100644 --- a/homeassistant/components/mqtt/debug_info.py +++ b/homeassistant/components/mqtt/debug_info.py @@ -28,7 +28,7 @@ def log_messages( debug_info_entities = get_mqtt_data(hass).debug_info_entities - def _log_message(msg): + def _log_message(msg: Any) -> None: """Log message.""" messages = debug_info_entities[entity_id]["subscriptions"][ msg.subscribed_topic diff --git a/homeassistant/components/mqtt/device_automation.py b/homeassistant/components/mqtt/device_automation.py index 0646a5bda0c..a1bc2cdeb3e 100644 --- a/homeassistant/components/mqtt/device_automation.py +++ b/homeassistant/components/mqtt/device_automation.py @@ -1,9 +1,14 @@ """Provides device automations for MQTT.""" +from __future__ import annotations + import functools import voluptuous as vol +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant import homeassistant.helpers.config_validation as cv +from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from . import device_trigger from .config import MQTT_BASE_SCHEMA @@ -20,14 +25,19 @@ PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend( ).extend(MQTT_BASE_SCHEMA.schema) -async def async_setup_entry(hass, config_entry): +async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> None: """Set up MQTT device automation dynamically through MQTT discovery.""" setup = functools.partial(_async_setup_automation, hass, config_entry=config_entry) await async_setup_entry_helper(hass, "device_automation", setup, PLATFORM_SCHEMA) -async def _async_setup_automation(hass, config, config_entry, discovery_data): +async def _async_setup_automation( + hass: HomeAssistant, + config: ConfigType, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType, +) -> None: """Set up an MQTT device automation.""" if config[CONF_AUTOMATION_TYPE] == AUTOMATION_TYPE_TRIGGER: await device_trigger.async_setup_trigger( @@ -35,6 +45,6 @@ async def _async_setup_automation(hass, config, config_entry, discovery_data): ) -async def async_removed_from_device(hass, device_id): +async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None: """Handle Mqtt removed from a device.""" await device_trigger.async_removed_from_device(hass, device_id) diff --git a/homeassistant/components/mqtt/device_trigger.py b/homeassistant/components/mqtt/device_trigger.py index f51731284cc..e8bcf1abc48 100644 --- a/homeassistant/components/mqtt/device_trigger.py +++ b/homeassistant/components/mqtt/device_trigger.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Callable import logging -from typing import cast +from typing import Any, cast import attr import voluptuous as vol @@ -23,7 +23,7 @@ from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import config_validation as cv from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo -from homeassistant.helpers.typing import ConfigType +from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from . import debug_info, trigger as mqtt_trigger from .config import MQTT_BASE_SCHEMA @@ -35,7 +35,7 @@ from .const import ( CONF_TOPIC, DOMAIN, ) -from .discovery import MQTT_DISCOVERY_DONE +from .discovery import MQTT_DISCOVERY_DONE, MQTTDiscoveryPayload from .mixins import ( MQTT_ENTITY_DEVICE_INFO_SCHEMA, MqttDiscoveryDeviceUpdate, @@ -96,7 +96,7 @@ class TriggerInstance: async def async_attach_trigger(self) -> None: """Attach MQTT trigger.""" - mqtt_config = { + mqtt_config: dict[str, Any] = { CONF_PLATFORM: DOMAIN, CONF_TOPIC: self.trigger.topic, CONF_ENCODING: DEFAULT_ENCODING, @@ -123,7 +123,7 @@ class Trigger: """Device trigger settings.""" device_id: str = attr.ib() - discovery_data: dict | None = attr.ib() + discovery_data: DiscoveryInfoType | None = attr.ib() hass: HomeAssistant = attr.ib() payload: str | None = attr.ib() qos: int | None = attr.ib() @@ -193,7 +193,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate): hass: HomeAssistant, config: ConfigType, device_id: str, - discovery_data: dict, + discovery_data: DiscoveryInfoType, config_entry: ConfigEntry, ) -> None: """Initialize.""" @@ -237,7 +237,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate): self.hass, discovery_hash, self.discovery_data, self.device_id ) - async def async_update(self, discovery_data: dict) -> None: + async def async_update(self, discovery_data: MQTTDiscoveryPayload) -> None: """Handle MQTT device trigger discovery updates.""" discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH] discovery_id = discovery_hash[1] @@ -261,11 +261,14 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate): async def async_setup_trigger( - hass, config: ConfigType, config_entry: ConfigEntry, discovery_data: dict + hass: HomeAssistant, + config: ConfigType, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType, ) -> None: """Set up the MQTT device trigger.""" config = TRIGGER_DISCOVERY_SCHEMA(config) - discovery_hash = discovery_data[ATTR_DISCOVERY_HASH] + discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH] if (device_id := update_device(hass, config_entry, config)) is None: async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None) diff --git a/homeassistant/components/mqtt/fan.py b/homeassistant/components/mqtt/fan.py index 1b6c3e425a4..8a65b909eb8 100644 --- a/homeassistant/components/mqtt/fan.py +++ b/homeassistant/components/mqtt/fan.py @@ -1,6 +1,7 @@ """Support for MQTT fans.""" from __future__ import annotations +from collections.abc import Callable import functools import logging import math @@ -27,6 +28,7 @@ from homeassistant.const import ( from homeassistant.core import HomeAssistant, callback import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.helpers.template import Template from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.util.percentage import ( int_states_in_range, @@ -54,7 +56,13 @@ from .mixins import ( async_setup_platform_helper, warn_for_legacy_schema, ) -from .models import MqttCommandTemplate, MqttValueTemplate +from .models import ( + MqttCommandTemplate, + MqttValueTemplate, + PublishPayloadType, + ReceiveMessage, + ReceivePayloadType, +) from .util import get_mqtt_data, valid_publish_topic, valid_subscribe_topic CONF_PERCENTAGE_STATE_TOPIC = "percentage_state_topic" @@ -110,18 +118,18 @@ MQTT_FAN_ATTRIBUTES_BLOCKED = frozenset( _LOGGER = logging.getLogger(__name__) -def valid_speed_range_configuration(config): +def valid_speed_range_configuration(config: ConfigType) -> ConfigType: """Validate that the fan speed_range configuration is valid, throws if it isn't.""" - if config.get(CONF_SPEED_RANGE_MIN) == 0: + if config[CONF_SPEED_RANGE_MIN] == 0: raise ValueError("speed_range_min must be > 0") - if config.get(CONF_SPEED_RANGE_MIN) >= config.get(CONF_SPEED_RANGE_MAX): + if config[CONF_SPEED_RANGE_MIN] >= config[CONF_SPEED_RANGE_MAX]: raise ValueError("speed_range_max must be > speed_range_min") return config -def valid_preset_mode_configuration(config): +def valid_preset_mode_configuration(config: ConfigType) -> ConfigType: """Validate that the preset mode reset payload is not one of the preset modes.""" - if config.get(CONF_PAYLOAD_RESET_PRESET_MODE) in config.get(CONF_PRESET_MODES_LIST): + if config[CONF_PAYLOAD_RESET_PRESET_MODE] in config[CONF_PRESET_MODES_LIST]: raise ValueError("preset_modes must not contain payload_reset_preset_mode") return config @@ -250,8 +258,8 @@ async def _async_setup_entity( hass: HomeAssistant, async_add_entities: AddEntitiesCallback, config: ConfigType, - config_entry: ConfigEntry | None = None, - discovery_data: dict | None = None, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType | None = None, ) -> None: """Set up the MQTT fan.""" async_add_entities([MqttFan(hass, config, config_entry, discovery_data)]) @@ -263,32 +271,41 @@ class MqttFan(MqttEntity, FanEntity): _entity_id_format = fan.ENTITY_ID_FORMAT _attributes_extra_blocked = MQTT_FAN_ATTRIBUTES_BLOCKED - def __init__(self, hass, config, config_entry, discovery_data): + _command_templates: dict[str, Callable[[PublishPayloadType], PublishPayloadType]] + _value_templates: dict[str, Callable[[ReceivePayloadType], ReceivePayloadType]] + _feature_percentage: bool + _feature_preset_mode: bool + _topic: dict[str, Any] + _optimistic: bool + _optimistic_oscillation: bool + _optimistic_percentage: bool + _optimistic_preset_mode: bool + _payload: dict[str, Any] + _speed_range: tuple[int, int] + + def __init__( + self, + hass: HomeAssistant, + config: ConfigType, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType | None, + ) -> None: """Initialize the MQTT fan.""" self._attr_percentage = None self._attr_preset_mode = None - self._topic = None - self._payload = None - self._value_templates = None - self._command_templates = None - self._optimistic = None - self._optimistic_oscillation = None - self._optimistic_percentage = None - self._optimistic_preset_mode = None - MqttEntity.__init__(self, hass, config, config_entry, discovery_data) @staticmethod - def config_schema(): + def config_schema() -> vol.Schema: """Return the config schema.""" return DISCOVERY_SCHEMA - def _setup_from_config(self, config): + def _setup_from_config(self, config: ConfigType) -> None: """(Re)Setup the entity.""" self._speed_range = ( - config.get(CONF_SPEED_RANGE_MIN), - config.get(CONF_SPEED_RANGE_MAX), + config[CONF_SPEED_RANGE_MIN], + config[CONF_SPEED_RANGE_MAX], ) self._topic = { key: config.get(key) @@ -303,18 +320,6 @@ class MqttFan(MqttEntity, FanEntity): CONF_OSCILLATION_COMMAND_TOPIC, ) } - self._value_templates = { - CONF_STATE: config.get(CONF_STATE_VALUE_TEMPLATE), - ATTR_PERCENTAGE: config.get(CONF_PERCENTAGE_VALUE_TEMPLATE), - ATTR_PRESET_MODE: config.get(CONF_PRESET_MODE_VALUE_TEMPLATE), - ATTR_OSCILLATING: config.get(CONF_OSCILLATION_VALUE_TEMPLATE), - } - self._command_templates = { - CONF_STATE: config.get(CONF_COMMAND_TEMPLATE), - ATTR_PERCENTAGE: config.get(CONF_PERCENTAGE_COMMAND_TEMPLATE), - ATTR_PRESET_MODE: config.get(CONF_PRESET_MODE_COMMAND_TEMPLATE), - ATTR_OSCILLATING: config.get(CONF_OSCILLATION_COMMAND_TEMPLATE), - } self._payload = { "STATE_ON": config[CONF_PAYLOAD_ON], "STATE_OFF": config[CONF_PAYLOAD_OFF], @@ -359,24 +364,38 @@ class MqttFan(MqttEntity, FanEntity): if self._feature_preset_mode: self._attr_supported_features |= FanEntityFeature.PRESET_MODE - for key, tpl in self._command_templates.items(): + command_templates: dict[str, Template | None] = { + CONF_STATE: config.get(CONF_COMMAND_TEMPLATE), + ATTR_PERCENTAGE: config.get(CONF_PERCENTAGE_COMMAND_TEMPLATE), + ATTR_PRESET_MODE: config.get(CONF_PRESET_MODE_COMMAND_TEMPLATE), + ATTR_OSCILLATING: config.get(CONF_OSCILLATION_COMMAND_TEMPLATE), + } + self._command_templates = {} + for key, tpl in command_templates.items(): self._command_templates[key] = MqttCommandTemplate( tpl, entity=self ).async_render - for key, tpl in self._value_templates.items(): + self._value_templates = {} + value_templates: dict[str, Template | None] = { + CONF_STATE: config.get(CONF_STATE_VALUE_TEMPLATE), + ATTR_PERCENTAGE: config.get(CONF_PERCENTAGE_VALUE_TEMPLATE), + ATTR_PRESET_MODE: config.get(CONF_PRESET_MODE_VALUE_TEMPLATE), + ATTR_OSCILLATING: config.get(CONF_OSCILLATION_VALUE_TEMPLATE), + } + for key, tpl in value_templates.items(): self._value_templates[key] = MqttValueTemplate( tpl, entity=self, ).async_render_with_possible_json_value - def _prepare_subscribe_topics(self): + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - topics = {} + topics: dict[str, Any] = {} @callback @log_messages(self.hass, self.entity_id) - def state_received(msg): + def state_received(msg: ReceiveMessage) -> None: """Handle new received MQTT message.""" payload = self._value_templates[CONF_STATE](msg.payload) if not payload: @@ -400,7 +419,7 @@ class MqttFan(MqttEntity, FanEntity): @callback @log_messages(self.hass, self.entity_id) - def percentage_received(msg): + def percentage_received(msg: ReceiveMessage) -> None: """Handle new received MQTT message for the percentage.""" rendered_percentage_payload = self._value_templates[ATTR_PERCENTAGE]( msg.payload @@ -446,9 +465,9 @@ class MqttFan(MqttEntity, FanEntity): @callback @log_messages(self.hass, self.entity_id) - def preset_mode_received(msg): + def preset_mode_received(msg: ReceiveMessage) -> None: """Handle new received MQTT message for preset mode.""" - preset_mode = self._value_templates[ATTR_PRESET_MODE](msg.payload) + preset_mode = str(self._value_templates[ATTR_PRESET_MODE](msg.payload)) if preset_mode == self._payload["PRESET_MODE_RESET"]: self._attr_preset_mode = None self.async_write_ha_state() @@ -456,7 +475,7 @@ class MqttFan(MqttEntity, FanEntity): if not preset_mode: _LOGGER.debug("Ignoring empty preset_mode from '%s'", msg.topic) return - if preset_mode not in self.preset_modes: + if not self.preset_modes or preset_mode not in self.preset_modes: _LOGGER.warning( "'%s' received on topic %s. '%s' is not a valid preset mode", msg.payload, @@ -479,7 +498,7 @@ class MqttFan(MqttEntity, FanEntity): @callback @log_messages(self.hass, self.entity_id) - def oscillation_received(msg): + def oscillation_received(msg: ReceiveMessage) -> None: """Handle new received MQTT message for the oscillation.""" payload = self._value_templates[ATTR_OSCILLATING](msg.payload) if not payload: @@ -504,7 +523,7 @@ class MqttFan(MqttEntity, FanEntity): self.hass, self._sub_state, topics ) - async def _subscribe_topics(self): + async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" await subscription.async_subscribe_topics(self.hass, self._sub_state) -- GitLab