From c35f9ee35f8efb318a8ab59c1e6d9d277a415379 Mon Sep 17 00:00:00 2001
From: Santobert <tobhaase@gmail.com>
Date: Fri, 22 Nov 2019 22:21:28 +0100
Subject: [PATCH] Creating a scene by snapshotting entities (#28939)

* Initial commit

* Add tests

* service.yaml

* typo

* snapshooted -> snapshot

* snapshot_entities instead of snapshot

* Edit validator

* Fix tests

* Remove keys()

* Improve coverage

* Activate scenes

* Use pytest.raise

* snapshot -> snapshotted
---
 .../components/homeassistant/scene.py         | 46 +++++++++-
 homeassistant/components/scene/services.yaml  |  5 +
 tests/components/homeassistant/test_scene.py  | 92 +++++++++++++++++++
 3 files changed, 140 insertions(+), 3 deletions(-)

diff --git a/homeassistant/components/homeassistant/scene.py b/homeassistant/components/homeassistant/scene.py
index c505d1534de..576bf540e00 100644
--- a/homeassistant/components/homeassistant/scene.py
+++ b/homeassistant/components/homeassistant/scene.py
@@ -55,7 +55,22 @@ def _convert_states(states):
     return result
 
 
+def _ensure_no_intersection(value):
+    """Validate that entities and snapshot_entities do not overlap."""
+    if (
+        CONF_SNAPSHOT not in value
+        or CONF_ENTITIES not in value
+        or not any(
+            entity_id in value[CONF_SNAPSHOT] for entity_id in value[CONF_ENTITIES]
+        )
+    ):
+        return value
+
+    raise vol.Invalid("entities and snapshot_entities must not overlap")
+
+
 CONF_SCENE_ID = "scene_id"
+CONF_SNAPSHOT = "snapshot_entities"
 
 STATES_SCHEMA = vol.All(dict, _convert_states)
 
@@ -75,8 +90,16 @@ PLATFORM_SCHEMA = vol.Schema(
     extra=vol.ALLOW_EXTRA,
 )
 
-CREATE_SCENE_SCHEMA = vol.Schema(
-    {vol.Required(CONF_SCENE_ID): cv.slug, vol.Required(CONF_ENTITIES): STATES_SCHEMA}
+CREATE_SCENE_SCHEMA = vol.All(
+    cv.has_at_least_one_key(CONF_ENTITIES, CONF_SNAPSHOT),
+    _ensure_no_intersection,
+    vol.Schema(
+        {
+            vol.Required(CONF_SCENE_ID): cv.slug,
+            vol.Optional(CONF_ENTITIES, default={}): STATES_SCHEMA,
+            vol.Optional(CONF_SNAPSHOT, default=[]): cv.entity_ids,
+        }
+    ),
 )
 
 SERVICE_APPLY = "apply"
@@ -139,7 +162,24 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info=
 
     async def create_service(call):
         """Create a scene."""
-        scene_config = SCENECONFIG(call.data[CONF_SCENE_ID], call.data[CONF_ENTITIES])
+        snapshot = call.data[CONF_SNAPSHOT]
+        entities = call.data[CONF_ENTITIES]
+
+        for entity_id in snapshot:
+            state = hass.states.get(entity_id)
+            if state is None:
+                _LOGGER.warning(
+                    "Entity %s does not exist and therefore cannot be snapshotted",
+                    entity_id,
+                )
+                continue
+            entities[entity_id] = State(entity_id, state.state, state.attributes)
+
+        if not entities:
+            _LOGGER.warning("Empty scenes are not allowed")
+            return
+
+        scene_config = SCENECONFIG(call.data[CONF_SCENE_ID], entities)
         entity_id = f"{SCENE_DOMAIN}.{scene_config.name}"
         old = platform.entities.get(entity_id)
         if old is not None:
diff --git a/homeassistant/components/scene/services.yaml b/homeassistant/components/scene/services.yaml
index 9cf1b9010a8..0c261ed60b5 100644
--- a/homeassistant/components/scene/services.yaml
+++ b/homeassistant/components/scene/services.yaml
@@ -34,3 +34,8 @@ create:
         light.ceiling:
           state: "on"
           brightness: 200
+    snapshot_entities:
+      description: The entities of which a snapshot is to be taken
+      example:
+        - light.ceiling
+        - light.kitchen
diff --git a/tests/components/homeassistant/test_scene.py b/tests/components/homeassistant/test_scene.py
index 25ce6088a51..d3bbac44df8 100644
--- a/tests/components/homeassistant/test_scene.py
+++ b/tests/components/homeassistant/test_scene.py
@@ -1,8 +1,13 @@
 """Test Home Assistant scenes."""
 from unittest.mock import patch
 
+import pytest
+import voluptuous as vol
+
 from homeassistant.setup import async_setup_component
 
+from tests.common import async_mock_service
+
 
 async def test_reload_config_service(hass):
     """Test the reload config service."""
@@ -63,6 +68,16 @@ async def test_create_service(hass, caplog):
     assert hass.states.get("scene.hallo") is None
     assert hass.states.get("scene.hallo_2") is not None
 
+    assert await hass.services.async_call(
+        "scene",
+        "create",
+        {"scene_id": "hallo", "entities": {}, "snapshot_entities": []},
+        blocking=True,
+    )
+    await hass.async_block_till_done()
+    assert "Empty scenes are not allowed" in caplog.text
+    assert hass.states.get("scene.hallo") is None
+
     assert await hass.services.async_call(
         "scene",
         "create",
@@ -117,3 +132,80 @@ async def test_create_service(hass, caplog):
     assert scene.name == "hallo_2"
     assert scene.state == "scening"
     assert scene.attributes.get("entity_id") == ["light.kitchen"]
+
+
+async def test_snapshot_service(hass, caplog):
+    """Test the snapshot option."""
+    assert await async_setup_component(hass, "scene", {"scene": {}})
+    hass.states.async_set("light.my_light", "on", {"hs_color": (345, 75)})
+    assert hass.states.get("scene.hallo") is None
+
+    assert await hass.services.async_call(
+        "scene",
+        "create",
+        {"scene_id": "hallo", "snapshot_entities": ["light.my_light"]},
+        blocking=True,
+    )
+    await hass.async_block_till_done()
+    scene = hass.states.get("scene.hallo")
+    assert scene is not None
+    assert scene.attributes.get("entity_id") == ["light.my_light"]
+
+    hass.states.async_set("light.my_light", "off", {"hs_color": (123, 45)})
+    turn_on_calls = async_mock_service(hass, "light", "turn_on")
+    assert await hass.services.async_call(
+        "scene", "turn_on", {"entity_id": "scene.hallo"}, blocking=True
+    )
+    await hass.async_block_till_done()
+    assert len(turn_on_calls) == 1
+    assert turn_on_calls[0].data.get("entity_id") == "light.my_light"
+    assert turn_on_calls[0].data.get("hs_color") == (345, 75)
+
+    assert await hass.services.async_call(
+        "scene",
+        "create",
+        {"scene_id": "hallo_2", "snapshot_entities": ["light.not_existent"]},
+        blocking=True,
+    )
+    await hass.async_block_till_done()
+    assert hass.states.get("scene.hallo_2") is None
+    assert (
+        "Entity light.not_existent does not exist and therefore cannot be snapshotted"
+        in caplog.text
+    )
+
+    assert await hass.services.async_call(
+        "scene",
+        "create",
+        {
+            "scene_id": "hallo_3",
+            "entities": {"light.bed_light": {"state": "on", "brightness": 50}},
+            "snapshot_entities": ["light.my_light"],
+        },
+        blocking=True,
+    )
+    await hass.async_block_till_done()
+    scene = hass.states.get("scene.hallo_3")
+    assert scene is not None
+    assert "light.my_light" in scene.attributes.get("entity_id")
+    assert "light.bed_light" in scene.attributes.get("entity_id")
+
+
+async def test_ensure_no_intersection(hass):
+    """Test that entities and snapshot_entities do not overlap."""
+    assert await async_setup_component(hass, "scene", {"scene": {}})
+
+    with pytest.raises(vol.MultipleInvalid) as ex:
+        assert await hass.services.async_call(
+            "scene",
+            "create",
+            {
+                "scene_id": "hallo",
+                "entities": {"light.my_light": {"state": "on", "brightness": 50}},
+                "snapshot_entities": ["light.my_light"],
+            },
+            blocking=True,
+        )
+        await hass.async_block_till_done()
+    assert "entities and snapshot_entities must not overlap" in str(ex.value)
+    assert hass.states.get("scene.hallo") is None
-- 
GitLab