From b64ae55c66cd46138e626d0503bb47db3b31f303 Mon Sep 17 00:00:00 2001
From: "J. Nick Koston" <nick@koston.org>
Date: Sun, 5 Jul 2020 11:03:23 -0500
Subject: [PATCH] Prebake common history queries (#37496)

* Prebake common history queries

The Python overhead of constructing the queries exceeded the database
overhead. We now prebake the queries that get polled frequently. This
reduces the time it takes to update history_stats sensors and can make
quite a difference if there are a lot of them.

When using the mini-graph-card card, all the entities being graphed on
the card are queried every few seconds for new states. Previously this
would tie up the database if there were a lot of these graphs in the UI.

* Update homeassistant/components/history/__init__.py

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>

* Update homeassistant/components/history/__init__.py

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>

* cache entity_filter in the lambda

* switch to yield

* Revert "switch to yield"

This reverts commit f8386f494002178729b67b54dd299affd406f2f2.

* get_states always returns a list

* query wasn't actually reusable, so revert part of the breakout

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
---
 homeassistant/components/history/__init__.py | 160 +++++++++++++------
 tests/components/history/test_init.py        |  12 +-
 2 files changed, 120 insertions(+), 52 deletions(-)

diff --git a/homeassistant/components/history/__init__.py b/homeassistant/components/history/__init__.py
index fe9a6c82825..4267127e209 100644
--- a/homeassistant/components/history/__init__.py
+++ b/homeassistant/components/history/__init__.py
@@ -8,7 +8,8 @@ import time
 from typing import Optional, cast
 
 from aiohttp import web
-from sqlalchemy import and_, func
+from sqlalchemy import and_, bindparam, func
+from sqlalchemy.ext import baked
 import voluptuous as vol
 
 from homeassistant.components import recorder
@@ -88,6 +89,8 @@ QUERY_STATES = [
     States.last_updated,
 ]
 
+HISTORY_BAKERY = "history_bakery"
+
 
 def get_significant_states(hass, *args, **kwargs):
     """Wrap _get_significant_states with a sql session."""
@@ -115,26 +118,34 @@ def _get_significant_states(
     """
     timer_start = time.perf_counter()
 
+    baked_query = hass.data[HISTORY_BAKERY](
+        lambda session: session.query(*QUERY_STATES)
+    )
+
     if significant_changes_only:
-        query = session.query(*QUERY_STATES).filter(
+        baked_query += lambda q: q.filter(
             (
                 States.domain.in_(SIGNIFICANT_DOMAINS)
                 | (States.last_changed == States.last_updated)
             )
-            & (States.last_updated > start_time)
+            & (States.last_updated > bindparam("start_time"))
         )
     else:
-        query = session.query(*QUERY_STATES).filter(States.last_updated > start_time)
+        baked_query += lambda q: q.filter(States.last_updated > bindparam("start_time"))
 
     if filters:
-        query = filters.apply(query, entity_ids)
+        filters.bake(baked_query, entity_ids)
 
     if end_time is not None:
-        query = query.filter(States.last_updated < end_time)
+        baked_query += lambda q: q.filter(States.last_updated < bindparam("end_time"))
 
-    query = query.order_by(States.entity_id, States.last_updated)
+    baked_query += lambda q: q.order_by(States.entity_id, States.last_updated)
 
-    states = execute(query)
+    states = execute(
+        baked_query(session).params(
+            start_time=start_time, end_time=end_time, entity_ids=entity_ids
+        )
+    )
 
     if _LOGGER.isEnabledFor(logging.DEBUG):
         elapsed = time.perf_counter() - timer_start
@@ -155,20 +166,33 @@ def _get_significant_states(
 
 def state_changes_during_period(hass, start_time, end_time=None, entity_id=None):
     """Return 
states changes during UTC period start_time - end_time.""" with session_scope(hass=hass) as session: - query = session.query(*QUERY_STATES).filter( + baked_query = hass.data[HISTORY_BAKERY]( + lambda session: session.query(*QUERY_STATES) + ) + + baked_query += lambda q: q.filter( (States.last_changed == States.last_updated) - & (States.last_updated > start_time) + & (States.last_updated > bindparam("start_time")) ) if end_time is not None: - query = query.filter(States.last_updated < end_time) + baked_query += lambda q: q.filter( + States.last_updated < bindparam("end_time") + ) if entity_id is not None: - query = query.filter_by(entity_id=entity_id.lower()) + baked_query += lambda q: q.filter_by(entity_id=bindparam("entity_id")) + entity_id = entity_id.lower() - entity_ids = [entity_id] if entity_id is not None else None + baked_query += lambda q: q.order_by(States.entity_id, States.last_updated) - states = execute(query.order_by(States.entity_id, States.last_updated)) + states = execute( + baked_query(session).params( + start_time=start_time, end_time=end_time, entity_id=entity_id + ) + ) + + entity_ids = [entity_id] if entity_id is not None else None return _sorted_states_to_json(hass, session, states, start_time, entity_ids) @@ -178,21 +202,29 @@ def get_last_state_changes(hass, number_of_states, entity_id): start_time = dt_util.utcnow() with session_scope(hass=hass) as session: - query = session.query(*QUERY_STATES).filter( - States.last_changed == States.last_updated + baked_query = hass.data[HISTORY_BAKERY]( + lambda session: session.query(*QUERY_STATES) ) + baked_query += lambda q: q.filter(States.last_changed == States.last_updated) if entity_id is not None: - query = query.filter_by(entity_id=entity_id.lower()) + baked_query += lambda q: q.filter_by(entity_id=bindparam("entity_id")) + entity_id = entity_id.lower() - entity_ids = [entity_id] if entity_id is not None else None + baked_query += lambda q: q.order_by( + States.entity_id, States.last_updated.desc() + ) + + baked_query += lambda q: q.limit(bindparam("number_of_states")) states = execute( - query.order_by(States.entity_id, States.last_updated.desc()).limit( - number_of_states + baked_query(session).params( + number_of_states=number_of_states, entity_id=entity_id ) ) + entity_ids = [entity_id] if entity_id is not None else None + return _sorted_states_to_json( hass, session, @@ -214,28 +246,18 @@ def get_states(hass, utc_point_in_time, entity_ids=None, run=None, filters=None) with session_scope(hass=hass) as session: return _get_states_with_session( - session, utc_point_in_time, entity_ids, run, filters + hass, session, utc_point_in_time, entity_ids, run, filters ) def _get_states_with_session( - session, utc_point_in_time, entity_ids=None, run=None, filters=None + hass, session, utc_point_in_time, entity_ids=None, run=None, filters=None ): """Return the states at a specific point in time.""" - query = session.query(*QUERY_STATES) - if entity_ids and len(entity_ids) == 1: - # Use an entirely different (and extremely fast) query if we only - # have a single entity id - query = ( - query.filter( - States.last_updated < utc_point_in_time, - States.entity_id.in_(entity_ids), - ) - .order_by(States.last_updated.desc()) - .limit(1) + return _get_single_entity_states_with_session( + hass, session, utc_point_in_time, entity_ids[0] ) - return [LazyState(row) for row in execute(query)] if run is None: run = recorder.run_information_with_session(session, utc_point_in_time) @@ -247,6 +269,7 @@ def _get_states_with_session( # We have 
more than one entity to look at (most commonly we want # all entities,) so we need to do a search on all states since the # last recorder run started. + query = session.query(*QUERY_STATES) most_recent_states_by_date = session.query( States.entity_id.label("max_entity_id"), @@ -286,6 +309,26 @@ def _get_states_with_session( return [LazyState(row) for row in execute(query)] +def _get_single_entity_states_with_session(hass, session, utc_point_in_time, entity_id): + # Use an entirely different (and extremely fast) query if we only + # have a single entity id + baked_query = hass.data[HISTORY_BAKERY]( + lambda session: session.query(*QUERY_STATES) + ) + baked_query += lambda q: q.filter( + States.last_updated < bindparam("utc_point_in_time"), + States.entity_id == bindparam("entity_id"), + ) + baked_query += lambda q: q.order_by(States.last_updated.desc()) + baked_query += lambda q: q.limit(1) + + query = baked_query(session).params( + utc_point_in_time=utc_point_in_time, entity_id=entity_id + ) + + return [LazyState(row) for row in execute(query)] + + def _sorted_states_to_json( hass, session, @@ -318,7 +361,7 @@ def _sorted_states_to_json( if include_start_time_state: run = recorder.run_information_from_instance(hass, start_time) for state in _get_states_with_session( - session, start_time, entity_ids, run=run, filters=filters + hass, session, start_time, entity_ids, run=run, filters=filters ): state.last_changed = start_time state.last_updated = start_time @@ -337,16 +380,16 @@ def _sorted_states_to_json( domain = split_entity_id(ent_id)[0] ent_results = result[ent_id] if not minimal_response or domain in NEED_ATTRIBUTE_DOMAINS: - ent_results.extend( - [ - native_state - for native_state in (LazyState(db_state) for db_state in group) - if ( - domain != SCRIPT_DOMAIN - or native_state.attributes.get(ATTR_CAN_CANCEL) - ) - ] - ) + if domain == SCRIPT_DOMAIN: + ent_results.extend( + [ + native_state + for native_state in (LazyState(db_state) for db_state in group) + if native_state.attributes.get(ATTR_CAN_CANCEL) + ] + ) + else: + ent_results.extend(LazyState(db_state) for db_state in group) continue # With minimal response we only provide a native @@ -387,7 +430,7 @@ def _sorted_states_to_json( def get_state(hass, utc_point_in_time, entity_id, run=None): """Return a state at a specific point in time.""" - states = list(get_states(hass, utc_point_in_time, (entity_id,), run)) + states = get_states(hass, utc_point_in_time, (entity_id,), run) return states[0] if states else None @@ -396,6 +439,9 @@ async def async_setup(hass, config): conf = config.get(DOMAIN, {}) filters = sqlalchemy_filter_from_include_exclude_conf(conf) + + hass.data[HISTORY_BAKERY] = baked.bakery() + use_include_order = conf.get(CONF_ORDER) hass.http.register_view(HistoryPeriodView(filters, use_include_order)) @@ -560,6 +606,7 @@ class Filters: # specific entities requested - do not in/exclude anything if entity_ids is not None: return query.filter(States.entity_id.in_(entity_ids)) + query = query.filter(~States.domain.in_(IGNORE_DOMAINS)) entity_filter = self.entity_filter() @@ -568,6 +615,27 @@ class Filters: return query + def bake(self, baked_query, entity_ids=None): + """Update a baked query. + + Works the same as apply on a baked_query. 
+ """ + if entity_ids is not None: + baked_query += lambda q: q.filter( + States.entity_id.in_(bindparam("entity_ids", expanding=True)) + ) + return + + baked_query += lambda q: q.filter(~States.domain.in_(IGNORE_DOMAINS)) + + if ( + self.excluded_entities + or self.excluded_domains + or self.included_entities + or self.included_domains + ): + baked_query += lambda q: q.filter(self.entity_filter()) + def entity_filter(self): """Generate the entity filter query.""" entity_filter = None diff --git a/tests/components/history/test_init.py b/tests/components/history/test_init.py index 34b22481400..56318f3e9fb 100644 --- a/tests/components/history/test_init.py +++ b/tests/components/history/test_init.py @@ -61,7 +61,7 @@ class TestComponentHistory(unittest.TestCase): def test_get_states(self): """Test getting states at a specific point in time.""" - self.init_recorder() + self.test_setup() states = [] now = dt_util.utcnow() @@ -115,7 +115,7 @@ class TestComponentHistory(unittest.TestCase): def test_state_changes_during_period(self): """Test state change during period.""" - self.init_recorder() + self.test_setup() entity_id = "media_player.test" def set_state(state): @@ -156,7 +156,7 @@ class TestComponentHistory(unittest.TestCase): def test_get_last_state_changes(self): """Test number of state changes.""" - self.init_recorder() + self.test_setup() entity_id = "sensor.test" def set_state(state): @@ -195,7 +195,7 @@ class TestComponentHistory(unittest.TestCase): The filter integration uses copy() on states from history. """ - self.init_recorder() + self.test_setup() entity_id = "sensor.test" def set_state(state): @@ -608,7 +608,7 @@ class TestComponentHistory(unittest.TestCase): def test_get_significant_states_only(self): """Test significant states when significant_states_only is set.""" - self.init_recorder() + self.test_setup() entity_id = "sensor.test" def set_state(state, **kwargs): @@ -683,7 +683,7 @@ class TestComponentHistory(unittest.TestCase): We inject a bunch of state updates from media player, zone and thermostat. """ - self.init_recorder() + self.test_setup() mp = "media_player.test" mp2 = "media_player.test2" mp3 = "media_player.test3" -- GitLab