diff --git a/homeassistant/components/script/__init__.py b/homeassistant/components/script/__init__.py index e80dcfa802726370b3ce17a37ab71f74445d86ee..740a5a21a5f3c804e4e2672e46e10346477cabd1 100644 --- a/homeassistant/components/script/__init__.py +++ b/homeassistant/components/script/__init__.py @@ -10,6 +10,7 @@ from homeassistant.const import ( ATTR_NAME, CONF_ALIAS, CONF_ICON, + CONF_MODE, SERVICE_RELOAD, SERVICE_TOGGLE, SERVICE_TURN_OFF, @@ -21,7 +22,13 @@ import homeassistant.helpers.config_validation as cv from homeassistant.helpers.config_validation import make_entity_service_schema from homeassistant.helpers.entity import ToggleEntity from homeassistant.helpers.entity_component import EntityComponent -from homeassistant.helpers.script import Script +from homeassistant.helpers.script import ( + DEFAULT_QUEUE_MAX, + SCRIPT_MODE_CHOICES, + SCRIPT_MODE_LEGACY, + SCRIPT_MODE_QUEUE, + Script, +) from homeassistant.helpers.service import async_set_service_schema from homeassistant.loader import bind_hass @@ -37,11 +44,47 @@ CONF_DESCRIPTION = "description" CONF_EXAMPLE = "example" CONF_FIELDS = "fields" CONF_SEQUENCE = "sequence" +CONF_QUEUE_MAX = "queue_size" ENTITY_ID_FORMAT = DOMAIN + ".{}" EVENT_SCRIPT_STARTED = "script_started" + +def _deprecated_legacy_mode(config): + legacy_scripts = [] + for object_id, cfg in config.items(): + mode = cfg.get(CONF_MODE) + if mode is None: + legacy_scripts.append(object_id) + cfg[CONF_MODE] = SCRIPT_MODE_LEGACY + if legacy_scripts: + _LOGGER.warning( + "Script behavior has changed. " + "To continue using previous behavior, which is now deprecated, " + "add '%s: %s' to script(s): %s.", + CONF_MODE, + SCRIPT_MODE_LEGACY, + ", ".join(legacy_scripts), + ) + return config + + +def _queue_max(config): + for object_id, cfg in config.items(): + mode = cfg[CONF_MODE] + queue_max = cfg.get(CONF_QUEUE_MAX) + if mode == SCRIPT_MODE_QUEUE: + if queue_max is None: + cfg[CONF_QUEUE_MAX] = DEFAULT_QUEUE_MAX + elif queue_max is not None: + raise vol.Invalid( + f"{CONF_QUEUE_MAX} not valid with {mode} {CONF_MODE} " + f"for script '{object_id}'" + ) + return config + + SCRIPT_ENTRY_SCHEMA = vol.Schema( { vol.Optional(CONF_ALIAS): cv.string, @@ -54,11 +97,20 @@ SCRIPT_ENTRY_SCHEMA = vol.Schema( vol.Optional(CONF_EXAMPLE): cv.string, } }, + vol.Optional(CONF_MODE): vol.In(SCRIPT_MODE_CHOICES), + vol.Optional(CONF_QUEUE_MAX): vol.All(vol.Coerce(int), vol.Range(min=2)), } ) CONFIG_SCHEMA = vol.Schema( - {DOMAIN: cv.schema_with_slug_keys(SCRIPT_ENTRY_SCHEMA)}, extra=vol.ALLOW_EXTRA + { + DOMAIN: vol.All( + cv.schema_with_slug_keys(SCRIPT_ENTRY_SCHEMA), + _deprecated_legacy_mode, + _queue_max, + ) + }, + extra=vol.ALLOW_EXTRA, ) SCRIPT_SERVICE_SCHEMA = vol.Schema(dict) @@ -91,7 +143,7 @@ def scripts_with_entity(hass: HomeAssistant, entity_id: str) -> List[str]: @callback def entities_in_script(hass: HomeAssistant, entity_id: str) -> List[str]: - """Return all entities in a scene.""" + """Return all entities in script.""" if DOMAIN not in hass.data: return [] @@ -122,7 +174,7 @@ def scripts_with_device(hass: HomeAssistant, device_id: str) -> List[str]: @callback def devices_in_script(hass: HomeAssistant, entity_id: str) -> List[str]: - """Return all devices in a scene.""" + """Return all devices in script.""" if DOMAIN not in hass.data: return [] @@ -152,13 +204,16 @@ async def async_setup(hass, config): async def turn_on_service(service): """Call a service to turn script on.""" - # We could turn on script directly here, but we only want to offer - # one way to do it. Otherwise no easy way to detect invocations. - var = service.data.get(ATTR_VARIABLES) - for script in await component.async_extract_from_service(service): - await hass.services.async_call( - DOMAIN, script.object_id, var, context=service.context - ) + variables = service.data.get(ATTR_VARIABLES) + for script_entity in await component.async_extract_from_service(service): + if script_entity.script.is_legacy: + await hass.services.async_call( + DOMAIN, script_entity.object_id, variables, context=service.context + ) + else: + await script_entity.async_turn_on( + variables=variables, context=service.context, wait=False + ) async def turn_off_service(service): """Cancel a script.""" @@ -172,8 +227,8 @@ async def async_setup(hass, config): async def toggle_service(service): """Toggle a script.""" - for script in await component.async_extract_from_service(service): - await script.async_toggle(context=service.context) + for script_entity in await component.async_extract_from_service(service): + await script_entity.async_toggle(context=service.context) hass.services.async_register( DOMAIN, SERVICE_RELOAD, reload_service, schema=RELOAD_SERVICE_SCHEMA @@ -197,24 +252,40 @@ async def _async_process_config(hass, config, component): async def service_handler(service): """Execute a service call to script.<script name>.""" entity_id = ENTITY_ID_FORMAT.format(service.service) - script = component.get_entity(entity_id) - if script.is_on: + script_entity = component.get_entity(entity_id) + if script_entity.script.is_legacy and script_entity.is_on: _LOGGER.warning("Script %s already running.", entity_id) return - await script.async_turn_on(variables=service.data, context=service.context) + await script_entity.async_turn_on( + variables=service.data, context=service.context + ) - scripts = [] + script_entities = [] for object_id, cfg in config.get(DOMAIN, {}).items(): - scripts.append( + script_entities.append( ScriptEntity( hass, object_id, cfg.get(CONF_ALIAS, object_id), cfg.get(CONF_ICON), cfg[CONF_SEQUENCE], + cfg[CONF_MODE], + cfg.get(CONF_QUEUE_MAX, 0), ) ) + + await component.async_add_entities(script_entities) + + # Register services for all entities that were created successfully. + for script_entity in script_entities: + object_id = script_entity.object_id + if component.get_entity(script_entity.entity_id) is None: + _LOGGER.error("Couldn't load script %s", object_id) + continue + + cfg = config[DOMAIN][object_id] + hass.services.async_register( DOMAIN, object_id, service_handler, schema=SCRIPT_SERVICE_SCHEMA ) @@ -226,22 +297,27 @@ async def _async_process_config(hass, config, component): } async_set_service_schema(hass, DOMAIN, object_id, service_desc) - await component.async_add_entities(scripts) - class ScriptEntity(ToggleEntity): """Representation of a script entity.""" icon = None - def __init__(self, hass, object_id, name, icon, sequence): + def __init__(self, hass, object_id, name, icon, sequence, mode, queue_max): """Initialize the script.""" self.object_id = object_id self.icon = icon self.entity_id = ENTITY_ID_FORMAT.format(object_id) self.script = Script( - hass, sequence, name, self.async_write_ha_state, logger=_LOGGER + hass, + sequence, + name, + self.async_change_listener, + mode, + queue_max, + logging.getLogger(f"{__name__}.{object_id}"), ) + self._changed = asyncio.Event() @property def should_poll(self): @@ -268,16 +344,37 @@ class ScriptEntity(ToggleEntity): """Return true if script is on.""" return self.script.is_running + @callback + def async_change_listener(self): + """Update state.""" + self.async_write_ha_state() + self._changed.set() + async def async_turn_on(self, **kwargs): """Turn the script on.""" + variables = kwargs.get("variables") context = kwargs.get("context") + wait = kwargs.get("wait", True) self.async_set_context(context) self.hass.bus.async_fire( EVENT_SCRIPT_STARTED, {ATTR_NAME: self.script.name, ATTR_ENTITY_ID: self.entity_id}, context=context, ) - await self.script.async_run(kwargs.get(ATTR_VARIABLES), context) + coro = self.script.async_run(variables, context) + if wait: + await coro + return + + # Caller does not want to wait for called script to finish so let script run in + # separate Task. However, wait for first state change so we can guarantee that + # it is written to the State Machine before we return. Only do this for + # non-legacy scripts, since legacy scripts don't necessarily change state + # immediately. + self._changed.clear() + self.hass.async_create_task(coro) + if not self.script.is_legacy: + await self._changed.wait() async def async_turn_off(self, **kwargs): """Turn script off.""" diff --git a/tests/components/script/test_init.py b/tests/components/script/test_init.py index bb7340a08da8e57a2bf2aa0decd89fced1b01fcb..8faa29363526e8da57d73611cd32d4dcb7922120 100644 --- a/tests/components/script/test_init.py +++ b/tests/components/script/test_init.py @@ -1,5 +1,6 @@ """The tests for the Script component.""" # pylint: disable=protected-access +import asyncio import unittest import pytest @@ -79,26 +80,6 @@ class TestScriptComponent(unittest.TestCase): """Stop down everything that was started.""" self.hass.stop() - def test_setup_with_invalid_configs(self): - """Test setup with invalid configs.""" - for value in ( - {"test": {}}, - {"test hello world": {"sequence": [{"event": "bla"}]}}, - { - "test": { - "sequence": { - "event": "test_event", - "service": "homeassistant.turn_on", - } - } - }, - ): - assert not setup_component( - self.hass, "script", {"script": value} - ), f"Script loaded with wrong config {value}" - - assert 0 == len(self.hass.states.entity_ids("script")) - def test_turn_on_service(self): """Verify that the turn_on service.""" event = "test_event" @@ -213,31 +194,60 @@ class TestScriptComponent(unittest.TestCase): assert calls[1].context is context assert calls[1].data["hello"] == "universe" - def test_reload_service(self): - """Verify that the turn_on service.""" - assert setup_component( - self.hass, - "script", - {"script": {"test": {"sequence": [{"delay": {"seconds": 5}}]}}}, - ) - assert self.hass.states.get(ENTITY_ID) is not None - assert self.hass.services.has_service(script.DOMAIN, "test") +invalid_configs = [ + {"test": {}}, + {"test hello world": {"sequence": [{"event": "bla"}]}}, + {"test": {"sequence": {"event": "test_event", "service": "homeassistant.turn_on"}}}, + {"test": {"sequence": [], "mode": "parallel", "queue_size": 5}}, +] - with patch( - "homeassistant.config.load_yaml_config_file", - return_value={ - "script": {"test2": {"sequence": [{"delay": {"seconds": 5}}]}} - }, - ): - reload(self.hass) - self.hass.block_till_done() - assert self.hass.states.get(ENTITY_ID) is None - assert not self.hass.services.has_service(script.DOMAIN, "test") +@pytest.mark.parametrize("value", invalid_configs) +async def test_setup_with_invalid_configs(hass, value): + """Test setup with invalid configs.""" + assert not await async_setup_component( + hass, "script", {"script": value} + ), f"Script loaded with wrong config {value}" + + assert 0 == len(hass.states.async_entity_ids("script")) + + +@pytest.mark.parametrize("running", ["no", "same", "different"]) +async def test_reload_service(hass, running): + """Verify the reload service.""" + assert await async_setup_component( + hass, "script", {"script": {"test": {"sequence": [{"delay": {"seconds": 5}}]}}} + ) + + assert hass.states.get(ENTITY_ID) is not None + assert hass.services.has_service(script.DOMAIN, "test") + + if running != "no": + _, object_id = split_entity_id(ENTITY_ID) + await hass.services.async_call(DOMAIN, object_id) + await hass.async_block_till_done() + + assert script.is_on(hass, ENTITY_ID) + + object_id = "test" if running == "same" else "test2" + with patch( + "homeassistant.config.load_yaml_config_file", + return_value={"script": {object_id: {"sequence": [{"delay": {"seconds": 5}}]}}}, + ): + await hass.services.async_call(DOMAIN, SERVICE_RELOAD, blocking=True) + await hass.async_block_till_done() - assert self.hass.states.get("script.test2") is not None - assert self.hass.services.has_service(script.DOMAIN, "test2") + if running != "same": + assert hass.states.get(ENTITY_ID) is None + assert not hass.services.has_service(script.DOMAIN, "test") + + assert hass.states.get("script.test2") is not None + assert hass.services.has_service(script.DOMAIN, "test2") + + else: + assert hass.states.get(ENTITY_ID) is not None + assert hass.services.has_service(script.DOMAIN, "test") async def test_service_descriptions(hass): @@ -449,7 +459,7 @@ async def test_extraction_functions(hass): } -async def test_config(hass): +async def test_config_basic(hass): """Test passing info in config.""" assert await async_setup_component( hass, @@ -470,6 +480,14 @@ async def test_config(hass): assert test_script.attributes["icon"] == "mdi:party" +async def test_config_legacy(hass, caplog): + """Test config defaulting to legacy mode.""" + assert await async_setup_component( + hass, "script", {"script": {"test_script": {"sequence": []}}} + ) + assert "To continue using previous behavior, which is now deprecated" in caplog.text + + async def test_logbook_humanify_script_started_event(hass): """Test humanifying script started event.""" hass.config.components.add("recorder") @@ -503,3 +521,89 @@ async def test_logbook_humanify_script_started_event(hass): assert event2["domain"] == "script" assert event2["message"] == "started" assert event2["entity_id"] == "script.bye" + + +@pytest.mark.parametrize("concurrently", [False, True]) +async def test_concurrent_script(hass, concurrently): + """Test calling script concurrently or not.""" + if concurrently: + call_script_2 = { + "service": "script.turn_on", + "data": {"entity_id": "script.script2"}, + } + else: + call_script_2 = {"service": "script.script2"} + assert await async_setup_component( + hass, + "script", + { + "script": { + "script1": { + "mode": "parallel", + "sequence": [ + call_script_2, + { + "wait_template": "{{ is_state('input_boolean.test1', 'on') }}" + }, + {"service": "test.script", "data": {"value": "script1"}}, + ], + }, + "script2": { + "mode": "parallel", + "sequence": [ + {"service": "test.script", "data": {"value": "script2a"}}, + { + "wait_template": "{{ is_state('input_boolean.test2', 'on') }}" + }, + {"service": "test.script", "data": {"value": "script2b"}}, + ], + }, + } + }, + ) + + service_called = asyncio.Event() + service_values = [] + + async def async_service_handler(service): + nonlocal service_values + service_values.append(service.data.get("value")) + service_called.set() + + hass.services.async_register("test", "script", async_service_handler) + hass.states.async_set("input_boolean.test1", "off") + hass.states.async_set("input_boolean.test2", "off") + + await hass.services.async_call("script", "script1") + await asyncio.wait_for(service_called.wait(), 1) + service_called.clear() + + assert "script2a" == service_values[-1] + assert script.is_on(hass, "script.script1") + assert script.is_on(hass, "script.script2") + + if not concurrently: + hass.states.async_set("input_boolean.test2", "on") + await asyncio.wait_for(service_called.wait(), 1) + service_called.clear() + + assert "script2b" == service_values[-1] + + hass.states.async_set("input_boolean.test1", "on") + await asyncio.wait_for(service_called.wait(), 1) + service_called.clear() + + assert "script1" == service_values[-1] + assert concurrently == script.is_on(hass, "script.script2") + + if concurrently: + hass.states.async_set("input_boolean.test2", "on") + await asyncio.wait_for(service_called.wait(), 1) + service_called.clear() + + assert "script2b" == service_values[-1] + + await hass.async_block_till_done() + + assert not script.is_on(hass, "script.script1") + assert not script.is_on(hass, "script.script2")