diff --git a/homeassistant/components/device_automation/__init__.py b/homeassistant/components/device_automation/__init__.py index 603e88fd8c8e038b217f776d4d4d0e908f30d3d7..5cbc8a1e678c12ad3792a5603b09c1ae7efc75a2 100644 --- a/homeassistant/components/device_automation/__init__.py +++ b/homeassistant/components/device_automation/__init__.py @@ -7,14 +7,14 @@ from enum import Enum from functools import wraps import logging from types import ModuleType -from typing import Any, NamedTuple +from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Protocol, Union, overload import voluptuous as vol import voluptuous_serialize from homeassistant.components import websocket_api from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM -from homeassistant.core import HomeAssistant +from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant from homeassistant.helpers import ( config_validation as cv, device_registry as dr, @@ -27,6 +27,13 @@ from homeassistant.requirements import async_get_integration_with_requirements from .exceptions import DeviceNotFound, InvalidDeviceAutomationConfig +if TYPE_CHECKING: + from homeassistant.components.automation import ( + AutomationActionType, + AutomationTriggerInfo, + ) + from homeassistant.helpers import condition + # mypy: allow-untyped-calls, allow-untyped-defs DOMAIN = "device_automation" @@ -76,6 +83,77 @@ TYPES = { } +class DeviceAutomationTriggerProtocol(Protocol): + """Define the format of device_trigger modules. + + Each module must define either TRIGGER_SCHEMA or async_validate_trigger_config. + """ + + TRIGGER_SCHEMA: vol.Schema + + async def async_validate_trigger_config( + self, hass: HomeAssistant, config: ConfigType + ) -> ConfigType: + """Validate config.""" + raise NotImplementedError + + async def async_attach_trigger( + self, + hass: HomeAssistant, + config: ConfigType, + action: AutomationActionType, + automation_info: AutomationTriggerInfo, + ) -> CALLBACK_TYPE: + """Attach a trigger.""" + raise NotImplementedError + + +class DeviceAutomationConditionProtocol(Protocol): + """Define the format of device_condition modules. + + Each module must define either CONDITION_SCHEMA or async_validate_condition_config. + """ + + CONDITION_SCHEMA: vol.Schema + + async def async_validate_condition_config( + self, hass: HomeAssistant, config: ConfigType + ) -> ConfigType: + """Validate config.""" + raise NotImplementedError + + def async_condition_from_config( + self, hass: HomeAssistant, config: ConfigType + ) -> condition.ConditionCheckerType: + """Evaluate state based on configuration.""" + raise NotImplementedError + + +class DeviceAutomationActionProtocol(Protocol): + """Define the format of device_action modules. + + Each module must define either ACTION_SCHEMA or async_validate_action_config. + """ + + ACTION_SCHEMA: vol.Schema + + async def async_validate_action_config( + self, hass: HomeAssistant, config: ConfigType + ) -> ConfigType: + """Validate config.""" + raise NotImplementedError + + async def async_call_action_from_config( + self, + hass: HomeAssistant, + config: ConfigType, + variables: dict[str, Any], + context: Context | None, + ) -> None: + """Execute a device action.""" + raise NotImplementedError + + @bind_hass async def async_get_device_automations( hass: HomeAssistant, @@ -115,9 +193,51 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return True +DeviceAutomationPlatformType = Union[ + ModuleType, + DeviceAutomationTriggerProtocol, + DeviceAutomationConditionProtocol, + DeviceAutomationActionProtocol, +] + + +@overload +async def async_get_device_automation_platform( # noqa: D103 + hass: HomeAssistant, + domain: str, + automation_type: Literal[DeviceAutomationType.TRIGGER], +) -> DeviceAutomationTriggerProtocol: + ... + + +@overload +async def async_get_device_automation_platform( # noqa: D103 + hass: HomeAssistant, + domain: str, + automation_type: Literal[DeviceAutomationType.CONDITION], +) -> DeviceAutomationConditionProtocol: + ... + + +@overload +async def async_get_device_automation_platform( # noqa: D103 + hass: HomeAssistant, + domain: str, + automation_type: Literal[DeviceAutomationType.ACTION], +) -> DeviceAutomationActionProtocol: + ... + + +@overload +async def async_get_device_automation_platform( # noqa: D103 + hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str +) -> DeviceAutomationPlatformType: + ... + + async def async_get_device_automation_platform( hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str -) -> ModuleType: +) -> DeviceAutomationPlatformType: """Load device automation platform for integration. Throws InvalidDeviceAutomationConfig if the integration is not found or does not support device automation. diff --git a/homeassistant/components/device_automation/trigger.py b/homeassistant/components/device_automation/trigger.py index 008a7603dba608ad674a057ab6b58b827a017aa7..f2962d6544e5dfaa48318b5bfb54ba71252c2218 100644 --- a/homeassistant/components/device_automation/trigger.py +++ b/homeassistant/components/device_automation/trigger.py @@ -1,7 +1,15 @@ """Offer device oriented automation.""" +from typing import cast + import voluptuous as vol +from homeassistant.components.automation import ( + AutomationActionType, + AutomationTriggerInfo, +) from homeassistant.const import CONF_DOMAIN +from homeassistant.core import CALLBACK_TYPE, HomeAssistant +from homeassistant.helpers.typing import ConfigType from . import ( DEVICE_TRIGGER_BASE_SCHEMA, @@ -10,26 +18,31 @@ from . import ( ) from .exceptions import InvalidDeviceAutomationConfig -# mypy: allow-untyped-defs, no-check-untyped-defs - TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA) -async def async_validate_trigger_config(hass, config): +async def async_validate_trigger_config( + hass: HomeAssistant, config: ConfigType +) -> ConfigType: """Validate config.""" platform = await async_get_device_automation_platform( hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER ) if not hasattr(platform, "async_validate_trigger_config"): - return platform.TRIGGER_SCHEMA(config) + return cast(ConfigType, platform.TRIGGER_SCHEMA(config)) try: - return await getattr(platform, "async_validate_trigger_config")(hass, config) + return await platform.async_validate_trigger_config(hass, config) except InvalidDeviceAutomationConfig as err: raise vol.Invalid(str(err) or "Invalid trigger configuration") from err -async def async_attach_trigger(hass, config, action, automation_info): +async def async_attach_trigger( + hass: HomeAssistant, + config: ConfigType, + action: AutomationActionType, + automation_info: AutomationTriggerInfo, +) -> CALLBACK_TYPE: """Listen for trigger.""" platform = await async_get_device_automation_platform( hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER diff --git a/homeassistant/helpers/condition.py b/homeassistant/helpers/condition.py index 80bed9137d0a0bd64a72221b33230656c84dd823..06853dd945064f5afef44875025d84e15c5c47f7 100644 --- a/homeassistant/helpers/condition.py +++ b/homeassistant/helpers/condition.py @@ -875,12 +875,7 @@ async def async_device_from_config( platform = await async_get_device_automation_platform( hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION ) - return trace_condition_function( - cast( - ConditionCheckerType, - platform.async_condition_from_config(hass, config), - ) - ) + return trace_condition_function(platform.async_condition_from_config(hass, config)) async def async_trigger_from_config( @@ -943,14 +938,15 @@ async def async_validate_condition_config( hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION ) if hasattr(platform, "async_validate_condition_config"): - return await platform.async_validate_condition_config(hass, config) # type: ignore + return await platform.async_validate_condition_config(hass, config) return cast(ConfigType, platform.CONDITION_SCHEMA(config)) if condition in ("numeric_state", "state"): - validator = getattr( - sys.modules[__name__], VALIDATE_CONFIG_FORMAT.format(condition) + validator = cast( + Callable[[HomeAssistant, ConfigType], ConfigType], + getattr(sys.modules[__name__], VALIDATE_CONFIG_FORMAT.format(condition)), ) - return validator(hass, config) # type: ignore + return validator(hass, config) return config