From 1be61df9c0596be0fbdce9bd0b15a6bfe56199e8 Mon Sep 17 00:00:00 2001
From: Paulus Schoutsen <paulus@paulusschoutsen.nl>
Date: Mon, 20 Aug 2018 17:39:53 +0200
Subject: [PATCH] Add recent context (#15989)

* Add recent context

* Add async_set_context to components not using new services
---
 homeassistant/components/alert.py          |  1 +
 homeassistant/components/light/__init__.py |  4 ++-
 homeassistant/helpers/entity.py            | 27 ++++++++++++++++--
 homeassistant/helpers/service.py           |  4 ++-
 tests/components/counter/test_init.py      | 23 ++++++++++++++-
 tests/components/test_input_boolean.py     | 23 ++++++++++++++-
 tests/components/test_input_datetime.py    | 26 ++++++++++++++++-
 tests/components/test_input_number.py      | 26 ++++++++++++++++-
 tests/components/test_input_select.py      | 29 ++++++++++++++++++-
 tests/components/test_input_text.py        | 26 ++++++++++++++++-
 tests/helpers/test_entity.py               | 33 +++++++++++++++++++++-
 11 files changed, 211 insertions(+), 11 deletions(-)

diff --git a/homeassistant/components/alert.py b/homeassistant/components/alert.py
index 80a02b3275d..3ec01fc6ab8 100644
--- a/homeassistant/components/alert.py
+++ b/homeassistant/components/alert.py
@@ -111,6 +111,7 @@ def async_setup(hass, config):
 
         for alert_id in alert_ids:
             alert = all_alerts[alert_id]
+            alert.async_set_context(service_call.context)
             if service_call.service == SERVICE_TURN_ON:
                 yield from alert.async_turn_on()
             elif service_call.service == SERVICE_TOGGLE:
diff --git a/homeassistant/components/light/__init__.py b/homeassistant/components/light/__init__.py
index ddee4108e31..bc7f136322b 100644
--- a/homeassistant/components/light/__init__.py
+++ b/homeassistant/components/light/__init__.py
@@ -345,6 +345,8 @@ async def async_setup(hass, config):
 
         update_tasks = []
         for light in target_lights:
+            light.async_set_context(service.context)
+
             pars = params
             if not pars:
                 pars = params.copy()
@@ -356,7 +358,7 @@ async def async_setup(hass, config):
                 continue
 
             update_tasks.append(
-                light.async_update_ha_state(True, service.context))
+                light.async_update_ha_state(True))
 
         if update_tasks:
             await asyncio.wait(update_tasks, loop=hass.loop)
diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py
index c356c266db6..f27a387b9ee 100644
--- a/homeassistant/helpers/entity.py
+++ b/homeassistant/helpers/entity.py
@@ -1,5 +1,6 @@
 """An abstract class for entities."""
 import asyncio
+from datetime import timedelta
 import logging
 import functools as ft
 from timeit import default_timer as timer
@@ -16,6 +17,7 @@ from homeassistant.config import DATA_CUSTOMIZE
 from homeassistant.exceptions import NoEntitySpecifiedError
 from homeassistant.util import ensure_unique_string, slugify
 from homeassistant.util.async_ import run_callback_threadsafe
+from homeassistant.util import dt as dt_util
 
 _LOGGER = logging.getLogger(__name__)
 SLOW_UPDATE_WARNING = 10
@@ -85,6 +87,10 @@ class Entity:
     # Hold list for functions to call on remove.
     _on_remove = None
 
+    # Context
+    _context = None
+    _context_set = None
+
     @property
     def should_poll(self) -> bool:
         """Return True if entity has to be polled for state.
@@ -173,13 +179,24 @@ class Entity:
         """Flag supported features."""
         return None
 
+    @property
+    def context_recent_time(self):
+        """Time that a context is considered recent."""
+        return timedelta(seconds=5)
+
     # DO NOT OVERWRITE
     # These properties and methods are either managed by Home Assistant or they
     # are used to perform a very specific function. Overwriting these may
     # produce undesirable effects in the entity's operation.
 
+    @callback
+    def async_set_context(self, context):
+        """Set the context the entity currently operates under."""
+        self._context = context
+        self._context_set = dt_util.utcnow()
+
     @asyncio.coroutine
-    def async_update_ha_state(self, force_refresh=False, context=None):
+    def async_update_ha_state(self, force_refresh=False):
         """Update Home Assistant with current state of entity.
 
         If force_refresh == True will update entity before setting state.
@@ -278,8 +295,14 @@ class Entity:
             # Could not convert state to float
             pass
 
+        if (self._context is not None and
+                dt_util.utcnow() - self._context_set >
+                self.context_recent_time):
+            self._context = None
+            self._context_set = None
+
         self.hass.states.async_set(
-            self.entity_id, state, attr, self.force_update, context)
+            self.entity_id, state, attr, self.force_update, self._context)
 
     def schedule_update_ha_state(self, force_refresh=False):
         """Schedule an update ha state change task.
diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py
index acad72a860a..fcdc3cfe856 100644
--- a/homeassistant/helpers/service.py
+++ b/homeassistant/helpers/service.py
@@ -218,13 +218,15 @@ async def _handle_service_platform_call(func, data, entities, context):
         if not entity.available:
             continue
 
+        entity.async_set_context(context)
+
         if isinstance(func, str):
             await getattr(entity, func)(**data)
         else:
             await func(entity, data)
 
         if entity.should_poll:
-            tasks.append(entity.async_update_ha_state(True, context))
+            tasks.append(entity.async_update_ha_state(True))
 
     if tasks:
         await asyncio.wait(tasks)
diff --git a/tests/components/counter/test_init.py b/tests/components/counter/test_init.py
index f4c6ee9c7da..af36c1c8f95 100644
--- a/tests/components/counter/test_init.py
+++ b/tests/components/counter/test_init.py
@@ -4,7 +4,7 @@ import asyncio
 import unittest
 import logging
 
-from homeassistant.core import CoreState, State
+from homeassistant.core import CoreState, State, Context
 from homeassistant.setup import setup_component, async_setup_component
 from homeassistant.components.counter import (
     DOMAIN, decrement, increment, reset, CONF_INITIAL, CONF_STEP, CONF_NAME,
@@ -202,3 +202,24 @@ def test_no_initial_state_and_no_restore_state(hass):
     state = hass.states.get('counter.test1')
     assert state
     assert int(state.state) == 0
+
+
+async def test_counter_context(hass):
+    """Test that counter context works."""
+    assert await async_setup_component(hass, 'counter', {
+        'counter': {
+            'test': {}
+        }
+    })
+
+    state = hass.states.get('counter.test')
+    assert state is not None
+
+    await hass.services.async_call('counter', 'increment', {
+        'entity_id': state.entity_id,
+    }, True, Context(user_id='abcd'))
+
+    state2 = hass.states.get('counter.test')
+    assert state2 is not None
+    assert state.state != state2.state
+    assert state2.context.user_id == 'abcd'
diff --git a/tests/components/test_input_boolean.py b/tests/components/test_input_boolean.py
index 964d1763e4e..999e7ac100f 100644
--- a/tests/components/test_input_boolean.py
+++ b/tests/components/test_input_boolean.py
@@ -4,7 +4,7 @@ import asyncio
 import unittest
 import logging
 
-from homeassistant.core import CoreState, State
+from homeassistant.core import CoreState, State, Context
 from homeassistant.setup import setup_component, async_setup_component
 from homeassistant.components.input_boolean import (
     DOMAIN, is_on, toggle, turn_off, turn_on, CONF_INITIAL)
@@ -158,3 +158,24 @@ def test_initial_state_overrules_restore_state(hass):
     state = hass.states.get('input_boolean.b2')
     assert state
     assert state.state == 'on'
+
+
+async def test_input_boolean_context(hass):
+    """Test that input_boolean context works."""
+    assert await async_setup_component(hass, 'input_boolean', {
+        'input_boolean': {
+            'ac': {CONF_INITIAL: True},
+        }
+    })
+
+    state = hass.states.get('input_boolean.ac')
+    assert state is not None
+
+    await hass.services.async_call('input_boolean', 'turn_off', {
+        'entity_id': state.entity_id,
+    }, True, Context(user_id='abcd'))
+
+    state2 = hass.states.get('input_boolean.ac')
+    assert state2 is not None
+    assert state.state != state2.state
+    assert state2.context.user_id == 'abcd'
diff --git a/tests/components/test_input_datetime.py b/tests/components/test_input_datetime.py
index 0d21061e022..9ced2aaa072 100644
--- a/tests/components/test_input_datetime.py
+++ b/tests/components/test_input_datetime.py
@@ -4,7 +4,7 @@ import asyncio
 import unittest
 import datetime
 
-from homeassistant.core import CoreState, State
+from homeassistant.core import CoreState, State, Context
 from homeassistant.setup import setup_component, async_setup_component
 from homeassistant.components.input_datetime import (
     DOMAIN, ATTR_ENTITY_ID, ATTR_DATE, ATTR_TIME, SERVICE_SET_DATETIME)
@@ -208,3 +208,27 @@ def test_restore_state(hass):
 
     state_bogus = hass.states.get('input_datetime.test_bogus_data')
     assert state_bogus.state == str(initial)
+
+
+async def test_input_datetime_context(hass):
+    """Test that input_datetime context works."""
+    assert await async_setup_component(hass, 'input_datetime', {
+        'input_datetime': {
+            'only_date': {
+                'has_date': True,
+            }
+        }
+    })
+
+    state = hass.states.get('input_datetime.only_date')
+    assert state is not None
+
+    await hass.services.async_call('input_datetime', 'set_datetime', {
+        'entity_id': state.entity_id,
+        'date': '2018-01-02'
+    }, True, Context(user_id='abcd'))
+
+    state2 = hass.states.get('input_datetime.only_date')
+    assert state2 is not None
+    assert state.state != state2.state
+    assert state2.context.user_id == 'abcd'
diff --git a/tests/components/test_input_number.py b/tests/components/test_input_number.py
index d416dcae154..659aaa524d9 100644
--- a/tests/components/test_input_number.py
+++ b/tests/components/test_input_number.py
@@ -3,7 +3,7 @@
 import asyncio
 import unittest
 
-from homeassistant.core import CoreState, State
+from homeassistant.core import CoreState, State, Context
 from homeassistant.setup import setup_component, async_setup_component
 from homeassistant.components.input_number import (
     DOMAIN, set_value, increment, decrement)
@@ -236,3 +236,27 @@ def test_no_initial_state_and_no_restore_state(hass):
     state = hass.states.get('input_number.b1')
     assert state
     assert float(state.state) == 0
+
+
+async def test_input_number_context(hass):
+    """Test that input_number context works."""
+    assert await async_setup_component(hass, 'input_number', {
+        'input_number': {
+            'b1': {
+                'min': 0,
+                'max': 100,
+            },
+        }
+    })
+
+    state = hass.states.get('input_number.b1')
+    assert state is not None
+
+    await hass.services.async_call('input_number', 'increment', {
+        'entity_id': state.entity_id,
+    }, True, Context(user_id='abcd'))
+
+    state2 = hass.states.get('input_number.b1')
+    assert state2 is not None
+    assert state.state != state2.state
+    assert state2.context.user_id == 'abcd'
diff --git a/tests/components/test_input_select.py b/tests/components/test_input_select.py
index 82da80253c5..1c73abfbb94 100644
--- a/tests/components/test_input_select.py
+++ b/tests/components/test_input_select.py
@@ -5,7 +5,7 @@ import unittest
 
 from tests.common import get_test_home_assistant, mock_restore_cache
 
-from homeassistant.core import State
+from homeassistant.core import State, Context
 from homeassistant.setup import setup_component, async_setup_component
 from homeassistant.components.input_select import (
     ATTR_OPTIONS, DOMAIN, SERVICE_SET_OPTIONS,
@@ -276,3 +276,30 @@ def test_initial_state_overrules_restore_state(hass):
     state = hass.states.get('input_select.s2')
     assert state
     assert state.state == 'middle option'
+
+
+async def test_input_select_context(hass):
+    """Test that input_select context works."""
+    assert await async_setup_component(hass, 'input_select', {
+        'input_select': {
+            's1': {
+                'options': [
+                    'first option',
+                    'middle option',
+                    'last option',
+                ],
+            }
+        }
+    })
+
+    state = hass.states.get('input_select.s1')
+    assert state is not None
+
+    await hass.services.async_call('input_select', 'select_next', {
+        'entity_id': state.entity_id,
+    }, True, Context(user_id='abcd'))
+
+    state2 = hass.states.get('input_select.s1')
+    assert state2 is not None
+    assert state.state != state2.state
+    assert state2.context.user_id == 'abcd'
diff --git a/tests/components/test_input_text.py b/tests/components/test_input_text.py
index 405f7de8272..7c8a0e65023 100644
--- a/tests/components/test_input_text.py
+++ b/tests/components/test_input_text.py
@@ -3,7 +3,7 @@
 import asyncio
 import unittest
 
-from homeassistant.core import CoreState, State
+from homeassistant.core import CoreState, State, Context
 from homeassistant.setup import setup_component, async_setup_component
 from homeassistant.components.input_text import (DOMAIN, set_value)
 
@@ -180,3 +180,27 @@ def test_no_initial_state_and_no_restore_state(hass):
     state = hass.states.get('input_text.b1')
     assert state
     assert str(state.state) == 'unknown'
+
+
+async def test_input_text_context(hass):
+    """Test that input_text context works."""
+    assert await async_setup_component(hass, 'input_text', {
+        'input_text': {
+            't1': {
+                'initial': 'bla',
+            }
+        }
+    })
+
+    state = hass.states.get('input_text.t1')
+    assert state is not None
+
+    await hass.services.async_call('input_text', 'set_value', {
+        'entity_id': state.entity_id,
+        'value': 'new_value',
+    }, True, Context(user_id='abcd'))
+
+    state2 = hass.states.get('input_text.t1')
+    assert state2 is not None
+    assert state.state != state2.state
+    assert state2.context.user_id == 'abcd'
diff --git a/tests/helpers/test_entity.py b/tests/helpers/test_entity.py
index f9355f97ba3..a51787225ca 100644
--- a/tests/helpers/test_entity.py
+++ b/tests/helpers/test_entity.py
@@ -1,11 +1,13 @@
 """Test the entity helper."""
 # pylint: disable=protected-access
 import asyncio
-from unittest.mock import MagicMock, patch
+from datetime import timedelta
+from unittest.mock import MagicMock, patch, PropertyMock
 
 import pytest
 
 import homeassistant.helpers.entity as entity
+from homeassistant.core import Context
 from homeassistant.const import ATTR_HIDDEN, ATTR_DEVICE_CLASS
 from homeassistant.config import DATA_CUSTOMIZE
 from homeassistant.helpers.entity_values import EntityValues
@@ -412,3 +414,32 @@ async def test_async_remove_runs_callbacks(hass):
     ent.async_on_remove(lambda: result.append(1))
     await ent.async_remove()
     assert len(result) == 1
+
+
+async def test_set_context(hass):
+    """Test setting context."""
+    context = Context()
+    ent = entity.Entity()
+    ent.hass = hass
+    ent.entity_id = 'hello.world'
+    ent.async_set_context(context)
+    await ent.async_update_ha_state()
+    assert hass.states.get('hello.world').context == context
+
+
+async def test_set_context_expired(hass):
+    """Test setting context."""
+    context = Context()
+
+    with patch.object(entity.Entity, 'context_recent_time',
+                      new_callable=PropertyMock) as recent:
+        recent.return_value = timedelta(seconds=-5)
+        ent = entity.Entity()
+        ent.hass = hass
+        ent.entity_id = 'hello.world'
+        ent.async_set_context(context)
+        await ent.async_update_ha_state()
+
+    assert hass.states.get('hello.world').context != context
+    assert ent._context is None
+    assert ent._context_set is None
-- 
GitLab