diff --git a/homeassistant/components/python_script/__init__.py b/homeassistant/components/python_script/__init__.py index 098603b94941aa723259797d397ba8a9649c0685..7b49a6b1b0d173800649dca87e5d4000932c692c 100644 --- a/homeassistant/components/python_script/__init__.py +++ b/homeassistant/components/python_script/__init__.py @@ -20,8 +20,13 @@ from RestrictedPython.Guards import ( import voluptuous as vol from homeassistant.const import CONF_DESCRIPTION, CONF_NAME, SERVICE_RELOAD -from homeassistant.core import HomeAssistant, ServiceCall -from homeassistant.exceptions import HomeAssistantError +from homeassistant.core import ( + HomeAssistant, + ServiceCall, + ServiceResponse, + SupportsResponse, +) +from homeassistant.exceptions import HomeAssistantError, ServiceValidationError from homeassistant.helpers.service import async_set_service_schema from homeassistant.helpers.typing import ConfigType from homeassistant.loader import bind_hass @@ -107,9 +112,9 @@ def discover_scripts(hass): _LOGGER.warning("Folder %s not found in configuration folder", FOLDER) return False - def python_script_service_handler(call: ServiceCall) -> None: + def python_script_service_handler(call: ServiceCall) -> ServiceResponse: """Handle python script service calls.""" - execute_script(hass, call.service, call.data) + return execute_script(hass, call.service, call.data, call.return_response) existing = hass.services.services.get(DOMAIN, {}).keys() for existing_service in existing: @@ -126,7 +131,12 @@ def discover_scripts(hass): for fil in glob.iglob(os.path.join(path, "*.py")): name = os.path.splitext(os.path.basename(fil))[0] - hass.services.register(DOMAIN, name, python_script_service_handler) + hass.services.register( + DOMAIN, + name, + python_script_service_handler, + supports_response=SupportsResponse.OPTIONAL, + ) service_desc = { CONF_NAME: services_dict.get(name, {}).get("name", name), @@ -137,17 +147,17 @@ def discover_scripts(hass): @bind_hass -def execute_script(hass, name, data=None): +def execute_script(hass, name, data=None, return_response=False): """Execute a script.""" filename = f"{name}.py" raise_if_invalid_filename(filename) with open(hass.config.path(FOLDER, filename), encoding="utf8") as fil: source = fil.read() - execute(hass, filename, source, data) + return execute(hass, filename, source, data, return_response=return_response) @bind_hass -def execute(hass, filename, source, data=None): +def execute(hass, filename, source, data=None, return_response=False): """Execute Python source.""" compiled = compile_restricted_exec(source, filename=filename) @@ -216,16 +226,39 @@ def execute(hass, filename, source, data=None): "hass": hass, "data": data or {}, "logger": logger, + "output": {}, } try: _LOGGER.info("Executing %s: %s", filename, data) # pylint: disable-next=exec-used exec(compiled.code, restricted_globals) # noqa: S102 + _LOGGER.debug( + "Output of python_script: `%s`:\n%s", + filename, + restricted_globals["output"], + ) + # Ensure that we're always returning a dictionary + if not isinstance(restricted_globals["output"], dict): + output_type = type(restricted_globals["output"]) + restricted_globals["output"] = {} + raise ScriptError( + f"Expected `output` to be a dictionary, was {output_type}" + ) except ScriptError as err: + if return_response: + raise ServiceValidationError(f"Error executing script: {err}") from err logger.error("Error executing script: %s", err) + return None except Exception as err: # pylint: disable=broad-except + if return_response: + raise HomeAssistantError( + f"Error executing script ({type(err).__name__}): {err}" + ) from err logger.exception("Error executing script: %s", err) + return None + + return restricted_globals["output"] class StubPrinter: diff --git a/tests/components/python_script/test_init.py b/tests/components/python_script/test_init.py index 4744c065ede0e9a0a47f8b6e202ecefd4e4d4409..ee7fedee0d5cf06a8a6c55826a8eb4124d4f7e5a 100644 --- a/tests/components/python_script/test_init.py +++ b/tests/components/python_script/test_init.py @@ -6,6 +6,7 @@ import pytest from homeassistant.components.python_script import DOMAIN, FOLDER, execute from homeassistant.core import HomeAssistant +from homeassistant.exceptions import HomeAssistantError, ServiceValidationError from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.setup import async_setup_component @@ -136,6 +137,19 @@ raise Exception('boom') assert "Error executing script: boom" in caplog.text +async def test_execute_runtime_error_with_response(hass: HomeAssistant) -> None: + """Test compile error logs error.""" + source = """ +raise Exception('boom') + """ + + task = hass.async_add_executor_job(execute, hass, "test.py", source, {}, True) + await hass.async_block_till_done() + + assert type(task.exception()) == HomeAssistantError + assert "Error executing script (Exception): boom" in str(task.exception()) + + async def test_accessing_async_methods( hass: HomeAssistant, caplog: pytest.LogCaptureFixture ) -> None: @@ -151,6 +165,19 @@ hass.async_stop() assert "Not allowed to access async methods" in caplog.text +async def test_accessing_async_methods_with_response(hass: HomeAssistant) -> None: + """Test compile error logs error.""" + source = """ +hass.async_stop() + """ + + task = hass.async_add_executor_job(execute, hass, "test.py", source, {}, True) + await hass.async_block_till_done() + + assert type(task.exception()) == ServiceValidationError + assert "Not allowed to access async methods" in str(task.exception()) + + async def test_using_complex_structures( hass: HomeAssistant, caplog: pytest.LogCaptureFixture ) -> None: @@ -186,6 +213,21 @@ async def test_accessing_forbidden_methods( assert f"Not allowed to access {name}" in caplog.text +async def test_accessing_forbidden_methods_with_response(hass: HomeAssistant) -> None: + """Test compile error logs error.""" + for source, name in { + "hass.stop()": "HomeAssistant.stop", + "dt_util.set_default_time_zone()": "module.set_default_time_zone", + "datetime.non_existing": "module.non_existing", + "time.tzset()": "TimeWrapper.tzset", + }.items(): + task = hass.async_add_executor_job(execute, hass, "test.py", source, {}, True) + await hass.async_block_till_done() + + assert type(task.exception()) == ServiceValidationError + assert f"Not allowed to access {name}" in str(task.exception()) + + async def test_iterating(hass: HomeAssistant) -> None: """Test compile error logs error.""" source = """ @@ -449,3 +491,108 @@ time.sleep(5) await hass.async_block_till_done() assert caplog.text.count("time.sleep") == 1 + + +async def test_execute_with_output( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test executing a script with a return value.""" + caplog.set_level(logging.WARNING) + + scripts = [ + "/some/config/dir/python_scripts/hello.py", + ] + with patch( + "homeassistant.components.python_script.os.path.isdir", return_value=True + ), patch("homeassistant.components.python_script.glob.iglob", return_value=scripts): + await async_setup_component(hass, "python_script", {}) + + source = """ +output = {"result": f"hello {data.get('name', 'World')}"} + """ + + with patch( + "homeassistant.components.python_script.open", + mock_open(read_data=source), + create=True, + ): + response = await hass.services.async_call( + "python_script", + "hello", + {"name": "paulus"}, + blocking=True, + return_response=True, + ) + + assert isinstance(response, dict) + assert len(response) == 1 + assert response["result"] == "hello paulus" + + # No errors logged = good + assert caplog.text == "" + + +async def test_execute_no_output( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test executing a script without a return value.""" + caplog.set_level(logging.WARNING) + + scripts = [ + "/some/config/dir/python_scripts/hello.py", + ] + with patch( + "homeassistant.components.python_script.os.path.isdir", return_value=True + ), patch("homeassistant.components.python_script.glob.iglob", return_value=scripts): + await async_setup_component(hass, "python_script", {}) + + source = """ +no_output = {"result": f"hello {data.get('name', 'World')}"} + """ + + with patch( + "homeassistant.components.python_script.open", + mock_open(read_data=source), + create=True, + ): + response = await hass.services.async_call( + "python_script", + "hello", + {"name": "paulus"}, + blocking=True, + return_response=True, + ) + + assert isinstance(response, dict) + assert len(response) == 0 + + # No errors logged = good + assert caplog.text == "" + + +async def test_execute_wrong_output_type(hass: HomeAssistant) -> None: + """Test executing a script without a return value.""" + scripts = [ + "/some/config/dir/python_scripts/hello.py", + ] + with patch( + "homeassistant.components.python_script.os.path.isdir", return_value=True + ), patch("homeassistant.components.python_script.glob.iglob", return_value=scripts): + await async_setup_component(hass, "python_script", {}) + + source = """ +output = f"hello {data.get('name', 'World')}" + """ + + with patch( + "homeassistant.components.python_script.open", + mock_open(read_data=source), + create=True, + ), pytest.raises(ServiceValidationError): + await hass.services.async_call( + "python_script", + "hello", + {"name": "paulus"}, + blocking=True, + return_response=True, + )