diff --git a/homeassistant/helpers/condition.py b/homeassistant/helpers/condition.py index b05445eff27dc557ca640ac0eb663f5682ce891c..5c7313f6716c1fa2be1d99e1f04e40cab2b83edc 100644 --- a/homeassistant/helpers/condition.py +++ b/homeassistant/helpers/condition.py @@ -5,7 +5,7 @@ from datetime import datetime, timedelta import functools as ft import logging import sys -from typing import Callable, Container, Optional, Set, Union, cast +from typing import Callable, Container, List, Optional, Set, Union, cast from homeassistant.components import zone as zone_cmp from homeassistant.components.device_automation import ( @@ -263,7 +263,7 @@ def async_numeric_state_from_config( def state( hass: HomeAssistant, entity: Union[None, str, State], - req_state: str, + req_state: Union[str, List[str]], for_period: Optional[timedelta] = None, ) -> bool: """Test if state matches requirements. @@ -277,7 +277,10 @@ def state( return False assert isinstance(entity, State) - is_state = entity.state == req_state + if isinstance(req_state, str): + req_state = [req_state] + + is_state = entity.state in req_state if for_period is None or not is_state: return is_state @@ -292,13 +295,16 @@ def state_from_config( if config_validation: config = cv.STATE_CONDITION_SCHEMA(config) entity_ids = config.get(CONF_ENTITY_ID, []) - req_state = cast(str, config.get(CONF_STATE)) + req_states: Union[str, List[str]] = config.get(CONF_STATE, []) for_period = config.get("for") + if not isinstance(req_states, list): + req_states = [req_states] + def if_state(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: """Test if condition.""" return all( - state(hass, entity_id, req_state, for_period) for entity_id in entity_ids + state(hass, entity_id, req_states, for_period) for entity_id in entity_ids ) return if_state @@ -512,11 +518,17 @@ def zone_from_config( if config_validation: config = cv.ZONE_CONDITION_SCHEMA(config) entity_ids = config.get(CONF_ENTITY_ID, []) - zone_entity_id = config.get(CONF_ZONE) + zone_entity_ids = config.get(CONF_ZONE, []) def if_in_zone(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: """Test if condition.""" - return all(zone(hass, zone_entity_id, entity_id) for entity_id in entity_ids) + return all( + any( + zone(hass, zone_entity_id, entity_id) + for zone_entity_id in zone_entity_ids + ) + for entity_id in entity_ids + ) return if_in_zone diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 69cc422da0abd7f6b45e085141b24776267d3391..24ba0d3c0f00df14c7fd3a265bc2928b2dd0969c 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -858,7 +858,7 @@ STATE_CONDITION_SCHEMA = vol.All( { vol.Required(CONF_CONDITION): "state", vol.Required(CONF_ENTITY_ID): entity_ids, - vol.Required(CONF_STATE): str, + vol.Required(CONF_STATE): vol.Any(str, [str]), vol.Optional(CONF_FOR): vol.All(time_period, positive_timedelta), # To support use_trigger_value in automation # Deprecated 2016/04/25 @@ -906,7 +906,7 @@ ZONE_CONDITION_SCHEMA = vol.Schema( { vol.Required(CONF_CONDITION): "zone", vol.Required(CONF_ENTITY_ID): entity_ids, - "zone": entity_id, + "zone": entity_ids, # To support use_trigger_value in automation # Deprecated 2016/04/25 vol.Optional("event"): vol.Any("enter", "leave"), diff --git a/tests/helpers/test_condition.py b/tests/helpers/test_condition.py index 5d81c1106350f7b8e21ae91cf67df61ac0774d27..b2cb1ff100c87779ec675649588960e83e92cce1 100644 --- a/tests/helpers/test_condition.py +++ b/tests/helpers/test_condition.py @@ -295,6 +295,32 @@ async def test_state_multiple_entities(hass): assert not test(hass) +async def test_multiple_states(hass): + """Test with multiple states in condition.""" + test = await condition.async_from_config( + hass, + { + "condition": "and", + "conditions": [ + { + "condition": "state", + "entity_id": "sensor.temperature", + "state": ["100", "200"], + }, + ], + }, + ) + + hass.states.async_set("sensor.temperature", 100) + assert test(hass) + + hass.states.async_set("sensor.temperature", 200) + assert test(hass) + + hass.states.async_set("sensor.temperature", 42) + assert not test(hass) + + async def test_numeric_state_multiple_entities(hass): """Test with multiple entities in condition.""" test = await condition.async_from_config( @@ -383,6 +409,55 @@ async def test_zone_multiple_entities(hass): assert not test(hass) +async def test_multiple_zones(hass): + """Test with multiple entities in condition.""" + test = await condition.async_from_config( + hass, + { + "condition": "and", + "conditions": [ + { + "condition": "zone", + "entity_id": "device_tracker.person", + "zone": ["zone.home", "zone.work"], + }, + ], + }, + ) + + hass.states.async_set( + "zone.home", + "zoning", + {"name": "home", "latitude": 2.1, "longitude": 1.1, "radius": 10}, + ) + hass.states.async_set( + "zone.work", + "zoning", + {"name": "work", "latitude": 20.1, "longitude": 10.1, "radius": 10}, + ) + + hass.states.async_set( + "device_tracker.person", + "home", + {"friendly_name": "person", "latitude": 2.1, "longitude": 1.1}, + ) + assert test(hass) + + hass.states.async_set( + "device_tracker.person", + "home", + {"friendly_name": "person", "latitude": 20.1, "longitude": 10.1}, + ) + assert test(hass) + + hass.states.async_set( + "device_tracker.person", + "home", + {"friendly_name": "person", "latitude": 50.1, "longitude": 20.1}, + ) + assert not test(hass) + + async def test_extract_entities(): """Test extracting entities.""" assert condition.async_extract_entities(