Skip to content
Snippets Groups Projects
Unverified Commit b3403d7f authored by Jan Bouwhuis's avatar Jan Bouwhuis Committed by GitHub
Browse files

Improve MQTT type hints part 3 (#80542)

* Improve typing debug_info

* Improve typing device_automation

* Improve typing device_trigger

* Improve typing fan

* Additional type hints device_trigger

* Set fan type hints to class level

* Cleanup and mypy

* Follow up and missed hint

* Follow up comment
parent dcd1ab7e
No related branches found
No related tags found
No related merge requests found
...@@ -28,7 +28,7 @@ def log_messages( ...@@ -28,7 +28,7 @@ def log_messages(
debug_info_entities = get_mqtt_data(hass).debug_info_entities debug_info_entities = get_mqtt_data(hass).debug_info_entities
def _log_message(msg): def _log_message(msg: Any) -> None:
"""Log message.""" """Log message."""
messages = debug_info_entities[entity_id]["subscriptions"][ messages = debug_info_entities[entity_id]["subscriptions"][
msg.subscribed_topic msg.subscribed_topic
......
"""Provides device automations for MQTT.""" """Provides device automations for MQTT."""
from __future__ import annotations
import functools import functools
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import device_trigger from . import device_trigger
from .config import MQTT_BASE_SCHEMA from .config import MQTT_BASE_SCHEMA
...@@ -20,14 +25,19 @@ PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend( ...@@ -20,14 +25,19 @@ PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend(
).extend(MQTT_BASE_SCHEMA.schema) ).extend(MQTT_BASE_SCHEMA.schema)
async def async_setup_entry(hass, config_entry): async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> None:
"""Set up MQTT device automation dynamically through MQTT discovery.""" """Set up MQTT device automation dynamically through MQTT discovery."""
setup = functools.partial(_async_setup_automation, hass, config_entry=config_entry) setup = functools.partial(_async_setup_automation, hass, config_entry=config_entry)
await async_setup_entry_helper(hass, "device_automation", setup, PLATFORM_SCHEMA) await async_setup_entry_helper(hass, "device_automation", setup, PLATFORM_SCHEMA)
async def _async_setup_automation(hass, config, config_entry, discovery_data): async def _async_setup_automation(
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType,
) -> None:
"""Set up an MQTT device automation.""" """Set up an MQTT device automation."""
if config[CONF_AUTOMATION_TYPE] == AUTOMATION_TYPE_TRIGGER: if config[CONF_AUTOMATION_TYPE] == AUTOMATION_TYPE_TRIGGER:
await device_trigger.async_setup_trigger( await device_trigger.async_setup_trigger(
...@@ -35,6 +45,6 @@ async def _async_setup_automation(hass, config, config_entry, discovery_data): ...@@ -35,6 +45,6 @@ async def _async_setup_automation(hass, config, config_entry, discovery_data):
) )
async def async_removed_from_device(hass, device_id): async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None:
"""Handle Mqtt removed from a device.""" """Handle Mqtt removed from a device."""
await device_trigger.async_removed_from_device(hass, device_id) await device_trigger.async_removed_from_device(hass, device_id)
...@@ -3,7 +3,7 @@ from __future__ import annotations ...@@ -3,7 +3,7 @@ from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
import logging import logging
from typing import cast from typing import Any, cast
import attr import attr
import voluptuous as vol import voluptuous as vol
...@@ -23,7 +23,7 @@ from homeassistant.exceptions import HomeAssistantError ...@@ -23,7 +23,7 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import debug_info, trigger as mqtt_trigger from . import debug_info, trigger as mqtt_trigger
from .config import MQTT_BASE_SCHEMA from .config import MQTT_BASE_SCHEMA
...@@ -35,7 +35,7 @@ from .const import ( ...@@ -35,7 +35,7 @@ from .const import (
CONF_TOPIC, CONF_TOPIC,
DOMAIN, DOMAIN,
) )
from .discovery import MQTT_DISCOVERY_DONE from .discovery import MQTT_DISCOVERY_DONE, MQTTDiscoveryPayload
from .mixins import ( from .mixins import (
MQTT_ENTITY_DEVICE_INFO_SCHEMA, MQTT_ENTITY_DEVICE_INFO_SCHEMA,
MqttDiscoveryDeviceUpdate, MqttDiscoveryDeviceUpdate,
...@@ -96,7 +96,7 @@ class TriggerInstance: ...@@ -96,7 +96,7 @@ class TriggerInstance:
async def async_attach_trigger(self) -> None: async def async_attach_trigger(self) -> None:
"""Attach MQTT trigger.""" """Attach MQTT trigger."""
mqtt_config = { mqtt_config: dict[str, Any] = {
CONF_PLATFORM: DOMAIN, CONF_PLATFORM: DOMAIN,
CONF_TOPIC: self.trigger.topic, CONF_TOPIC: self.trigger.topic,
CONF_ENCODING: DEFAULT_ENCODING, CONF_ENCODING: DEFAULT_ENCODING,
...@@ -123,7 +123,7 @@ class Trigger: ...@@ -123,7 +123,7 @@ class Trigger:
"""Device trigger settings.""" """Device trigger settings."""
device_id: str = attr.ib() device_id: str = attr.ib()
discovery_data: dict | None = attr.ib() discovery_data: DiscoveryInfoType | None = attr.ib()
hass: HomeAssistant = attr.ib() hass: HomeAssistant = attr.ib()
payload: str | None = attr.ib() payload: str | None = attr.ib()
qos: int | None = attr.ib() qos: int | None = attr.ib()
...@@ -193,7 +193,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate): ...@@ -193,7 +193,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
hass: HomeAssistant, hass: HomeAssistant,
config: ConfigType, config: ConfigType,
device_id: str, device_id: str,
discovery_data: dict, discovery_data: DiscoveryInfoType,
config_entry: ConfigEntry, config_entry: ConfigEntry,
) -> None: ) -> None:
"""Initialize.""" """Initialize."""
...@@ -237,7 +237,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate): ...@@ -237,7 +237,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
self.hass, discovery_hash, self.discovery_data, self.device_id self.hass, discovery_hash, self.discovery_data, self.device_id
) )
async def async_update(self, discovery_data: dict) -> None: async def async_update(self, discovery_data: MQTTDiscoveryPayload) -> None:
"""Handle MQTT device trigger discovery updates.""" """Handle MQTT device trigger discovery updates."""
discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH] discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
discovery_id = discovery_hash[1] discovery_id = discovery_hash[1]
...@@ -261,11 +261,14 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate): ...@@ -261,11 +261,14 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
async def async_setup_trigger( async def async_setup_trigger(
hass, config: ConfigType, config_entry: ConfigEntry, discovery_data: dict hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType,
) -> None: ) -> None:
"""Set up the MQTT device trigger.""" """Set up the MQTT device trigger."""
config = TRIGGER_DISCOVERY_SCHEMA(config) config = TRIGGER_DISCOVERY_SCHEMA(config)
discovery_hash = discovery_data[ATTR_DISCOVERY_HASH] discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH]
if (device_id := update_device(hass, config_entry, config)) is None: if (device_id := update_device(hass, config_entry, config)) is None:
async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None) async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None)
......
"""Support for MQTT fans.""" """Support for MQTT fans."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
import functools import functools
import logging import logging
import math import math
...@@ -27,6 +28,7 @@ from homeassistant.const import ( ...@@ -27,6 +28,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.template import Template
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.util.percentage import ( from homeassistant.util.percentage import (
int_states_in_range, int_states_in_range,
...@@ -54,7 +56,13 @@ from .mixins import ( ...@@ -54,7 +56,13 @@ from .mixins import (
async_setup_platform_helper, async_setup_platform_helper,
warn_for_legacy_schema, warn_for_legacy_schema,
) )
from .models import MqttCommandTemplate, MqttValueTemplate from .models import (
MqttCommandTemplate,
MqttValueTemplate,
PublishPayloadType,
ReceiveMessage,
ReceivePayloadType,
)
from .util import get_mqtt_data, valid_publish_topic, valid_subscribe_topic from .util import get_mqtt_data, valid_publish_topic, valid_subscribe_topic
CONF_PERCENTAGE_STATE_TOPIC = "percentage_state_topic" CONF_PERCENTAGE_STATE_TOPIC = "percentage_state_topic"
...@@ -110,18 +118,18 @@ MQTT_FAN_ATTRIBUTES_BLOCKED = frozenset( ...@@ -110,18 +118,18 @@ MQTT_FAN_ATTRIBUTES_BLOCKED = frozenset(
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def valid_speed_range_configuration(config): def valid_speed_range_configuration(config: ConfigType) -> ConfigType:
"""Validate that the fan speed_range configuration is valid, throws if it isn't.""" """Validate that the fan speed_range configuration is valid, throws if it isn't."""
if config.get(CONF_SPEED_RANGE_MIN) == 0: if config[CONF_SPEED_RANGE_MIN] == 0:
raise ValueError("speed_range_min must be > 0") raise ValueError("speed_range_min must be > 0")
if config.get(CONF_SPEED_RANGE_MIN) >= config.get(CONF_SPEED_RANGE_MAX): if config[CONF_SPEED_RANGE_MIN] >= config[CONF_SPEED_RANGE_MAX]:
raise ValueError("speed_range_max must be > speed_range_min") raise ValueError("speed_range_max must be > speed_range_min")
return config return config
def valid_preset_mode_configuration(config): def valid_preset_mode_configuration(config: ConfigType) -> ConfigType:
"""Validate that the preset mode reset payload is not one of the preset modes.""" """Validate that the preset mode reset payload is not one of the preset modes."""
if config.get(CONF_PAYLOAD_RESET_PRESET_MODE) in config.get(CONF_PRESET_MODES_LIST): if config[CONF_PAYLOAD_RESET_PRESET_MODE] in config[CONF_PRESET_MODES_LIST]:
raise ValueError("preset_modes must not contain payload_reset_preset_mode") raise ValueError("preset_modes must not contain payload_reset_preset_mode")
return config return config
...@@ -250,8 +258,8 @@ async def _async_setup_entity( ...@@ -250,8 +258,8 @@ async def _async_setup_entity(
hass: HomeAssistant, hass: HomeAssistant,
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
config: ConfigType, config: ConfigType,
config_entry: ConfigEntry | None = None, config_entry: ConfigEntry,
discovery_data: dict | None = None, discovery_data: DiscoveryInfoType | None = None,
) -> None: ) -> None:
"""Set up the MQTT fan.""" """Set up the MQTT fan."""
async_add_entities([MqttFan(hass, config, config_entry, discovery_data)]) async_add_entities([MqttFan(hass, config, config_entry, discovery_data)])
...@@ -263,32 +271,41 @@ class MqttFan(MqttEntity, FanEntity): ...@@ -263,32 +271,41 @@ class MqttFan(MqttEntity, FanEntity):
_entity_id_format = fan.ENTITY_ID_FORMAT _entity_id_format = fan.ENTITY_ID_FORMAT
_attributes_extra_blocked = MQTT_FAN_ATTRIBUTES_BLOCKED _attributes_extra_blocked = MQTT_FAN_ATTRIBUTES_BLOCKED
def __init__(self, hass, config, config_entry, discovery_data): _command_templates: dict[str, Callable[[PublishPayloadType], PublishPayloadType]]
_value_templates: dict[str, Callable[[ReceivePayloadType], ReceivePayloadType]]
_feature_percentage: bool
_feature_preset_mode: bool
_topic: dict[str, Any]
_optimistic: bool
_optimistic_oscillation: bool
_optimistic_percentage: bool
_optimistic_preset_mode: bool
_payload: dict[str, Any]
_speed_range: tuple[int, int]
def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None,
) -> None:
"""Initialize the MQTT fan.""" """Initialize the MQTT fan."""
self._attr_percentage = None self._attr_percentage = None
self._attr_preset_mode = None self._attr_preset_mode = None
self._topic = None
self._payload = None
self._value_templates = None
self._command_templates = None
self._optimistic = None
self._optimistic_oscillation = None
self._optimistic_percentage = None
self._optimistic_preset_mode = None
MqttEntity.__init__(self, hass, config, config_entry, discovery_data) MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
@staticmethod @staticmethod
def config_schema(): def config_schema() -> vol.Schema:
"""Return the config schema.""" """Return the config schema."""
return DISCOVERY_SCHEMA return DISCOVERY_SCHEMA
def _setup_from_config(self, config): def _setup_from_config(self, config: ConfigType) -> None:
"""(Re)Setup the entity.""" """(Re)Setup the entity."""
self._speed_range = ( self._speed_range = (
config.get(CONF_SPEED_RANGE_MIN), config[CONF_SPEED_RANGE_MIN],
config.get(CONF_SPEED_RANGE_MAX), config[CONF_SPEED_RANGE_MAX],
) )
self._topic = { self._topic = {
key: config.get(key) key: config.get(key)
...@@ -303,18 +320,6 @@ class MqttFan(MqttEntity, FanEntity): ...@@ -303,18 +320,6 @@ class MqttFan(MqttEntity, FanEntity):
CONF_OSCILLATION_COMMAND_TOPIC, CONF_OSCILLATION_COMMAND_TOPIC,
) )
} }
self._value_templates = {
CONF_STATE: config.get(CONF_STATE_VALUE_TEMPLATE),
ATTR_PERCENTAGE: config.get(CONF_PERCENTAGE_VALUE_TEMPLATE),
ATTR_PRESET_MODE: config.get(CONF_PRESET_MODE_VALUE_TEMPLATE),
ATTR_OSCILLATING: config.get(CONF_OSCILLATION_VALUE_TEMPLATE),
}
self._command_templates = {
CONF_STATE: config.get(CONF_COMMAND_TEMPLATE),
ATTR_PERCENTAGE: config.get(CONF_PERCENTAGE_COMMAND_TEMPLATE),
ATTR_PRESET_MODE: config.get(CONF_PRESET_MODE_COMMAND_TEMPLATE),
ATTR_OSCILLATING: config.get(CONF_OSCILLATION_COMMAND_TEMPLATE),
}
self._payload = { self._payload = {
"STATE_ON": config[CONF_PAYLOAD_ON], "STATE_ON": config[CONF_PAYLOAD_ON],
"STATE_OFF": config[CONF_PAYLOAD_OFF], "STATE_OFF": config[CONF_PAYLOAD_OFF],
...@@ -359,24 +364,38 @@ class MqttFan(MqttEntity, FanEntity): ...@@ -359,24 +364,38 @@ class MqttFan(MqttEntity, FanEntity):
if self._feature_preset_mode: if self._feature_preset_mode:
self._attr_supported_features |= FanEntityFeature.PRESET_MODE self._attr_supported_features |= FanEntityFeature.PRESET_MODE
for key, tpl in self._command_templates.items(): command_templates: dict[str, Template | None] = {
CONF_STATE: config.get(CONF_COMMAND_TEMPLATE),
ATTR_PERCENTAGE: config.get(CONF_PERCENTAGE_COMMAND_TEMPLATE),
ATTR_PRESET_MODE: config.get(CONF_PRESET_MODE_COMMAND_TEMPLATE),
ATTR_OSCILLATING: config.get(CONF_OSCILLATION_COMMAND_TEMPLATE),
}
self._command_templates = {}
for key, tpl in command_templates.items():
self._command_templates[key] = MqttCommandTemplate( self._command_templates[key] = MqttCommandTemplate(
tpl, entity=self tpl, entity=self
).async_render ).async_render
for key, tpl in self._value_templates.items(): self._value_templates = {}
value_templates: dict[str, Template | None] = {
CONF_STATE: config.get(CONF_STATE_VALUE_TEMPLATE),
ATTR_PERCENTAGE: config.get(CONF_PERCENTAGE_VALUE_TEMPLATE),
ATTR_PRESET_MODE: config.get(CONF_PRESET_MODE_VALUE_TEMPLATE),
ATTR_OSCILLATING: config.get(CONF_OSCILLATION_VALUE_TEMPLATE),
}
for key, tpl in value_templates.items():
self._value_templates[key] = MqttValueTemplate( self._value_templates[key] = MqttValueTemplate(
tpl, tpl,
entity=self, entity=self,
).async_render_with_possible_json_value ).async_render_with_possible_json_value
def _prepare_subscribe_topics(self): def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
topics = {} topics: dict[str, Any] = {}
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def state_received(msg): def state_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message.""" """Handle new received MQTT message."""
payload = self._value_templates[CONF_STATE](msg.payload) payload = self._value_templates[CONF_STATE](msg.payload)
if not payload: if not payload:
...@@ -400,7 +419,7 @@ class MqttFan(MqttEntity, FanEntity): ...@@ -400,7 +419,7 @@ class MqttFan(MqttEntity, FanEntity):
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def percentage_received(msg): def percentage_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the percentage.""" """Handle new received MQTT message for the percentage."""
rendered_percentage_payload = self._value_templates[ATTR_PERCENTAGE]( rendered_percentage_payload = self._value_templates[ATTR_PERCENTAGE](
msg.payload msg.payload
...@@ -446,9 +465,9 @@ class MqttFan(MqttEntity, FanEntity): ...@@ -446,9 +465,9 @@ class MqttFan(MqttEntity, FanEntity):
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def preset_mode_received(msg): def preset_mode_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for preset mode.""" """Handle new received MQTT message for preset mode."""
preset_mode = self._value_templates[ATTR_PRESET_MODE](msg.payload) preset_mode = str(self._value_templates[ATTR_PRESET_MODE](msg.payload))
if preset_mode == self._payload["PRESET_MODE_RESET"]: if preset_mode == self._payload["PRESET_MODE_RESET"]:
self._attr_preset_mode = None self._attr_preset_mode = None
self.async_write_ha_state() self.async_write_ha_state()
...@@ -456,7 +475,7 @@ class MqttFan(MqttEntity, FanEntity): ...@@ -456,7 +475,7 @@ class MqttFan(MqttEntity, FanEntity):
if not preset_mode: if not preset_mode:
_LOGGER.debug("Ignoring empty preset_mode from '%s'", msg.topic) _LOGGER.debug("Ignoring empty preset_mode from '%s'", msg.topic)
return return
if preset_mode not in self.preset_modes: if not self.preset_modes or preset_mode not in self.preset_modes:
_LOGGER.warning( _LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid preset mode", "'%s' received on topic %s. '%s' is not a valid preset mode",
msg.payload, msg.payload,
...@@ -479,7 +498,7 @@ class MqttFan(MqttEntity, FanEntity): ...@@ -479,7 +498,7 @@ class MqttFan(MqttEntity, FanEntity):
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def oscillation_received(msg): def oscillation_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the oscillation.""" """Handle new received MQTT message for the oscillation."""
payload = self._value_templates[ATTR_OSCILLATING](msg.payload) payload = self._value_templates[ATTR_OSCILLATING](msg.payload)
if not payload: if not payload:
...@@ -504,7 +523,7 @@ class MqttFan(MqttEntity, FanEntity): ...@@ -504,7 +523,7 @@ class MqttFan(MqttEntity, FanEntity):
self.hass, self._sub_state, topics self.hass, self._sub_state, topics
) )
async def _subscribe_topics(self): async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) await subscription.async_subscribe_topics(self.hass, self._sub_state)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment