Skip to content
Snippets Groups Projects
Unverified Commit f1f4d80f authored by cdce8p's avatar cdce8p Committed by GitHub
Browse files

Homekit Bugfixes (#14689)

* Fix async bug
* Fix debounce bug
parent e746b92e
No related branches found
No related tags found
No related merge requests found
"""Extend the basic Accessory and Bridge functions.""" """Extend the basic Accessory and Bridge functions."""
from datetime import timedelta from datetime import timedelta
from functools import wraps from functools import partial, wraps
from inspect import getmodule from inspect import getmodule
import logging import logging
...@@ -27,35 +27,25 @@ _LOGGER = logging.getLogger(__name__) ...@@ -27,35 +27,25 @@ _LOGGER = logging.getLogger(__name__)
def debounce(func): def debounce(func):
"""Decorator function. Debounce callbacks form HomeKit.""" """Decorator function. Debounce callbacks form HomeKit."""
@ha_callback @ha_callback
def call_later_listener(*args): def call_later_listener(self, *args):
"""Callback listener called from call_later.""" """Callback listener called from call_later."""
# pylint: disable=unsubscriptable-object debounce_params = self.debounce.pop(func.__name__, None)
nonlocal lastargs, remove_listener if debounce_params:
hass = lastargs['hass'] self.hass.async_add_job(func, self, *debounce_params[1:])
hass.async_add_job(func, *lastargs['args'])
lastargs = remove_listener = None
@wraps(func) @wraps(func)
def wrapper(*args): def wrapper(self, *args):
"""Wrapper starts async timer. """Wrapper starts async timer."""
debounce_params = self.debounce.pop(func.__name__, None)
The accessory must have 'self.hass' and 'self.entity_id' as attributes. if debounce_params:
""" debounce_params[0]() # remove listener
# pylint: disable=not-callable
hass = args[0].hass
nonlocal lastargs, remove_listener
if remove_listener:
remove_listener()
lastargs = remove_listener = None
lastargs = {'hass': hass, 'args': [*args]}
remove_listener = track_point_in_utc_time( remove_listener = track_point_in_utc_time(
hass, call_later_listener, self.hass, partial(call_later_listener, self),
dt_util.utcnow() + timedelta(seconds=DEBOUNCE_TIMEOUT)) dt_util.utcnow() + timedelta(seconds=DEBOUNCE_TIMEOUT))
logger.debug('%s: Start %s timeout', args[0].entity_id, self.debounce[func.__name__] = (remove_listener, *args)
logger.debug('%s: Start %s timeout', self.entity_id,
func.__name__.replace('set_', '')) func.__name__.replace('set_', ''))
remove_listener = None
lastargs = None
name = getmodule(func).__name__ name = getmodule(func).__name__
logger = logging.getLogger(name) logger = logging.getLogger(name)
return wrapper return wrapper
...@@ -76,11 +66,15 @@ class HomeAccessory(Accessory): ...@@ -76,11 +66,15 @@ class HomeAccessory(Accessory):
self.config = config self.config = config
self.entity_id = entity_id self.entity_id = entity_id
self.hass = hass self.hass = hass
self.debounce = {}
def run(self): async def run(self):
"""Method called by accessory after driver is started.""" """Method called by accessory after driver is started.
Run inside the HAP-python event loop.
"""
state = self.hass.states.get(self.entity_id) state = self.hass.states.get(self.entity_id)
self.update_state_callback(new_state=state) self.hass.add_job(self.update_state_callback, None, None, state)
async_track_state_change( async_track_state_change(
self.hass, self.entity_id, self.update_state_callback) self.hass, self.entity_id, self.update_state_callback)
...@@ -127,10 +121,10 @@ class HomeDriver(AccessoryDriver): ...@@ -127,10 +121,10 @@ class HomeDriver(AccessoryDriver):
def pair(self, client_uuid, client_public): def pair(self, client_uuid, client_public):
"""Override super function to dismiss setup message if paired.""" """Override super function to dismiss setup message if paired."""
value = super().pair(client_uuid, client_public) success = super().pair(client_uuid, client_public)
if value: if success:
dismiss_setup_message(self.hass) dismiss_setup_message(self.hass)
return value return success
def unpair(self, client_uuid): def unpair(self, client_uuid):
"""Override super function to show setup message if unpaired.""" """Override super function to show setup message if unpaired."""
......
...@@ -26,7 +26,7 @@ async def test_debounce(hass): ...@@ -26,7 +26,7 @@ async def test_debounce(hass):
arguments = None arguments = None
counter = 0 counter = 0
mock = Mock(hass=hass) mock = Mock(hass=hass, debounce={})
debounce_demo = debounce(demo_func) debounce_demo = debounce(demo_func)
assert debounce_demo.__name__ == 'demo_func' assert debounce_demo.__name__ == 'demo_func'
...@@ -76,6 +76,7 @@ async def test_home_accessory(hass, hk_driver): ...@@ -76,6 +76,7 @@ async def test_home_accessory(hass, hk_driver):
with patch('homeassistant.components.homekit.accessories.' with patch('homeassistant.components.homekit.accessories.'
'HomeAccessory.update_state') as mock_update_state: 'HomeAccessory.update_state') as mock_update_state:
await hass.async_add_job(acc.run) await hass.async_add_job(acc.run)
await hass.async_block_till_done()
state = hass.states.get(entity_id) state = hass.states.get(entity_id)
mock_update_state.assert_called_with(state) mock_update_state.assert_called_with(state)
......
...@@ -35,6 +35,7 @@ async def test_default_thermostat(hass, hk_driver, cls): ...@@ -35,6 +35,7 @@ async def test_default_thermostat(hass, hk_driver, cls):
await hass.async_block_till_done() await hass.async_block_till_done()
acc = cls.thermostat(hass, hk_driver, 'Climate', entity_id, 2, None) acc = cls.thermostat(hass, hk_driver, 'Climate', entity_id, 2, None)
await hass.async_add_job(acc.run) await hass.async_add_job(acc.run)
await hass.async_block_till_done()
assert acc.aid == 2 assert acc.aid == 2
assert acc.category == 9 # Thermostat assert acc.category == 9 # Thermostat
...@@ -175,6 +176,7 @@ async def test_auto_thermostat(hass, hk_driver, cls): ...@@ -175,6 +176,7 @@ async def test_auto_thermostat(hass, hk_driver, cls):
await hass.async_block_till_done() await hass.async_block_till_done()
acc = cls.thermostat(hass, hk_driver, 'Climate', entity_id, 2, None) acc = cls.thermostat(hass, hk_driver, 'Climate', entity_id, 2, None)
await hass.async_add_job(acc.run) await hass.async_add_job(acc.run)
await hass.async_block_till_done()
assert acc.char_cooling_thresh_temp.value == 23.0 assert acc.char_cooling_thresh_temp.value == 23.0
assert acc.char_heating_thresh_temp.value == 19.0 assert acc.char_heating_thresh_temp.value == 19.0
...@@ -254,6 +256,7 @@ async def test_power_state(hass, hk_driver, cls): ...@@ -254,6 +256,7 @@ async def test_power_state(hass, hk_driver, cls):
await hass.async_block_till_done() await hass.async_block_till_done()
acc = cls.thermostat(hass, hk_driver, 'Climate', entity_id, 2, None) acc = cls.thermostat(hass, hk_driver, 'Climate', entity_id, 2, None)
await hass.async_add_job(acc.run) await hass.async_add_job(acc.run)
await hass.async_block_till_done()
assert acc.support_power_state is True assert acc.support_power_state is True
assert acc.char_current_heat_cool.value == 1 assert acc.char_current_heat_cool.value == 1
...@@ -306,6 +309,7 @@ async def test_thermostat_fahrenheit(hass, hk_driver, cls): ...@@ -306,6 +309,7 @@ async def test_thermostat_fahrenheit(hass, hk_driver, cls):
await hass.async_block_till_done() await hass.async_block_till_done()
acc = cls.thermostat(hass, hk_driver, 'Climate', entity_id, 2, None) acc = cls.thermostat(hass, hk_driver, 'Climate', entity_id, 2, None)
await hass.async_add_job(acc.run) await hass.async_add_job(acc.run)
await hass.async_block_till_done()
hass.states.async_set(entity_id, STATE_AUTO, hass.states.async_set(entity_id, STATE_AUTO,
{ATTR_OPERATION_MODE: STATE_AUTO, {ATTR_OPERATION_MODE: STATE_AUTO,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment