From 24289d5dbb158872f15e7b23ba9251c5b2e939d6 Mon Sep 17 00:00:00 2001 From: mdegat01 <michael.degatano@gmail.com> Date: Tue, 30 Jun 2020 14:02:25 -0400 Subject: [PATCH] Refactor Influx logic to reduce V1 vs V2 code paths (#37232) * refactoring to share logic and sensor startup error test * Added handling for V1 InfluxDBServerError to start-up and runtime and test for it * Added InfluxDBServerError test to sensor setup tests * Raising PlatformNotReady exception from sensor for setup failure * Proper testing of PlatformNotReady error --- homeassistant/components/influxdb/__init__.py | 370 +++++++++++------- homeassistant/components/influxdb/const.py | 32 +- homeassistant/components/influxdb/sensor.py | 308 ++++++--------- tests/components/influxdb/test_init.py | 7 + tests/components/influxdb/test_sensor.py | 88 ++++- 5 files changed, 460 insertions(+), 345 deletions(-) diff --git a/homeassistant/components/influxdb/__init__.py b/homeassistant/components/influxdb/__init__.py index c6e7cd19da0..0989ce098fe 100644 --- a/homeassistant/components/influxdb/__init__.py +++ b/homeassistant/components/influxdb/__init__.py @@ -1,10 +1,11 @@ """Support for sending data to an Influx database.""" +from dataclasses import dataclass import logging import math import queue import threading import time -from typing import Dict +from typing import Any, Callable, Dict, List from influxdb import InfluxDBClient, exceptions from influxdb_client import InfluxDBClient as InfluxDBClientV2 @@ -15,6 +16,10 @@ import urllib3.exceptions import voluptuous as vol from homeassistant.const import ( + CONF_DOMAIN, + CONF_ENTITY_ID, + CONF_TIMEOUT, + CONF_UNIT_OF_MEASUREMENT, CONF_URL, EVENT_HOMEASSISTANT_STOP, EVENT_STATE_CHANGED, @@ -33,8 +38,10 @@ from .const import ( API_VERSION_2, BATCH_BUFFER_SIZE, BATCH_TIMEOUT, - CLIENT_ERROR_V1_WITH_RETRY, - CLIENT_ERROR_V2_WITH_RETRY, + CATCHING_UP_MESSAGE, + CLIENT_ERROR_V1, + CLIENT_ERROR_V2, + CODE_INVALID_INPUTS, COMPONENT_CONFIG_SCHEMA_CONNECTION, CONF_API_VERSION, CONF_BUCKET, @@ -56,18 +63,32 @@ from .const import ( CONF_TOKEN, CONF_USERNAME, CONF_VERIFY_SSL, - CONNECTION_ERROR_WITH_RETRY, + CONNECTION_ERROR, DEFAULT_API_VERSION, DEFAULT_HOST_V2, DEFAULT_SSL_V2, DOMAIN, + EVENT_NEW_STATE, + INFLUX_CONF_FIELDS, + INFLUX_CONF_MEASUREMENT, + INFLUX_CONF_ORG, + INFLUX_CONF_STATE, + INFLUX_CONF_TAGS, + INFLUX_CONF_TIME, + INFLUX_CONF_VALUE, + QUERY_ERROR, QUEUE_BACKLOG_SECONDS, RE_DECIMAL, RE_DIGIT_TAIL, + RESUMED_MESSAGE, RETRY_DELAY, RETRY_INTERVAL, + RETRY_MESSAGE, + TEST_QUERY_V1, + TEST_QUERY_V2, TIMEOUT, WRITE_ERROR, + WROTE_MESSAGE, ) _LOGGER = logging.getLogger(__name__) @@ -120,9 +141,11 @@ def validate_version_specific_config(conf: Dict) -> Dict: return conf -_CONFIG_SCHEMA_ENTRY = vol.Schema({vol.Optional(CONF_OVERRIDE_MEASUREMENT): cv.string}) +_CUSTOMIZE_ENTITY_SCHEMA = vol.Schema( + {vol.Optional(CONF_OVERRIDE_MEASUREMENT): cv.string} +) -_CONFIG_SCHEMA = INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA.extend( +_INFLUX_BASE_SCHEMA = INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA.extend( { vol.Optional(CONF_RETRY_COUNT, default=0): cv.positive_int, vol.Optional(CONF_DEFAULT_MEASUREMENT): cv.string, @@ -132,89 +155,28 @@ _CONFIG_SCHEMA = INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA.extend( cv.ensure_list, [cv.string] ), vol.Optional(CONF_COMPONENT_CONFIG, default={}): vol.Schema( - {cv.entity_id: _CONFIG_SCHEMA_ENTRY} + {cv.entity_id: _CUSTOMIZE_ENTITY_SCHEMA} ), vol.Optional(CONF_COMPONENT_CONFIG_GLOB, default={}): vol.Schema( - {cv.string: _CONFIG_SCHEMA_ENTRY} + {cv.string: _CUSTOMIZE_ENTITY_SCHEMA} ), vol.Optional(CONF_COMPONENT_CONFIG_DOMAIN, default={}): vol.Schema( - {cv.string: _CONFIG_SCHEMA_ENTRY} + {cv.string: _CUSTOMIZE_ENTITY_SCHEMA} ), } ) -CONFIG_SCHEMA = vol.Schema( - { - DOMAIN: vol.All( - _CONFIG_SCHEMA.extend(COMPONENT_CONFIG_SCHEMA_CONNECTION), - validate_version_specific_config, - create_influx_url, - ), - }, - extra=vol.ALLOW_EXTRA, +INFLUX_SCHEMA = vol.All( + _INFLUX_BASE_SCHEMA.extend(COMPONENT_CONFIG_SCHEMA_CONNECTION), + validate_version_specific_config, + create_influx_url, ) +CONFIG_SCHEMA = vol.Schema({DOMAIN: INFLUX_SCHEMA}, extra=vol.ALLOW_EXTRA,) -def get_influx_connection(client_kwargs, bucket): - """Create and check the correct influx connection for the API version.""" - if bucket is not None: - # Test connection by synchronously writing nothing. - # If config is valid this will generate a `Bad Request` exception but not make anything. - # If config is invalid we will output an error. - # Hopefully a better way to test connection is added in the future. - try: - influx = InfluxDBClientV2(**client_kwargs) - influx.write_api(write_options=SYNCHRONOUS).write(bucket=bucket) - - except ApiException as exc: - # 400 is the success state since it means we can write we just gave a bad point. - if exc.status != 400: - raise exc - - else: - influx = InfluxDBClient(**client_kwargs) - influx.write_points([]) - - return influx - - -def setup(hass, config): - """Set up the InfluxDB component.""" - conf = config[DOMAIN] - use_v2_api = conf[CONF_API_VERSION] == API_VERSION_2 - bucket = None - kwargs = { - "timeout": TIMEOUT, - } - - if use_v2_api: - kwargs["url"] = conf[CONF_URL] - kwargs["token"] = conf[CONF_TOKEN] - kwargs["org"] = conf[CONF_ORG] - bucket = conf[CONF_BUCKET] - - else: - kwargs["database"] = conf[CONF_DB_NAME] - kwargs["verify_ssl"] = conf[CONF_VERIFY_SSL] - - if CONF_USERNAME in conf: - kwargs["username"] = conf[CONF_USERNAME] - - if CONF_PASSWORD in conf: - kwargs["password"] = conf[CONF_PASSWORD] - - if CONF_HOST in conf: - kwargs["host"] = conf[CONF_HOST] - - if CONF_PATH in conf: - kwargs["path"] = conf[CONF_PATH] - - if CONF_PORT in conf: - kwargs["port"] = conf[CONF_PORT] - - if CONF_SSL in conf: - kwargs["ssl"] = conf[CONF_SSL] +def _generate_event_to_json(conf: Dict) -> Callable[[Dict], str]: + """Build event to json converter and add to config.""" entity_filter = convert_include_exclude_filter(conf) tags = conf.get(CONF_TAGS) tags_attributes = conf.get(CONF_TAGS_ATTRIBUTES) @@ -225,32 +187,10 @@ def setup(hass, config): conf[CONF_COMPONENT_CONFIG_DOMAIN], conf[CONF_COMPONENT_CONFIG_GLOB], ) - max_tries = conf.get(CONF_RETRY_COUNT) - try: - influx = get_influx_connection(kwargs, bucket) - if use_v2_api: - write_api = influx.write_api(write_options=ASYNCHRONOUS) - except ( - OSError, - requests.exceptions.ConnectionError, - urllib3.exceptions.HTTPError, - ) as exc: - _LOGGER.error(CONNECTION_ERROR_WITH_RETRY, exc) - event_helper.call_later(hass, RETRY_INTERVAL, lambda _: setup(hass, config)) - return True - except exceptions.InfluxDBClientError as exc: - _LOGGER.error(CLIENT_ERROR_V1_WITH_RETRY, exc) - event_helper.call_later(hass, RETRY_INTERVAL, lambda _: setup(hass, config)) - return True - except ApiException as exc: - _LOGGER.error(CLIENT_ERROR_V2_WITH_RETRY, exc) - event_helper.call_later(hass, RETRY_INTERVAL, lambda _: setup(hass, config)) - return True - - def event_to_json(event): - """Add an event to the outgoing Influx list.""" - state = event.data.get("new_state") + def event_to_json(event: Dict) -> str: + """Convert event into json in format Influx expects.""" + state = event.data.get(EVENT_NEW_STATE) if ( state is None or state.state in (STATE_UNKNOWN, "", STATE_UNAVAILABLE) @@ -278,7 +218,7 @@ def setup(hass, config): if override_measurement: measurement = override_measurement else: - measurement = state.attributes.get("unit_of_measurement") + measurement = state.attributes.get(CONF_UNIT_OF_MEASUREMENT) if measurement in (None, ""): if default_measurement: measurement = default_measurement @@ -288,57 +228,206 @@ def setup(hass, config): include_uom = False json = { - "measurement": measurement, - "tags": {"domain": state.domain, "entity_id": state.object_id}, - "time": event.time_fired, - "fields": {}, + INFLUX_CONF_MEASUREMENT: measurement, + INFLUX_CONF_TAGS: { + CONF_DOMAIN: state.domain, + CONF_ENTITY_ID: state.object_id, + }, + INFLUX_CONF_TIME: event.time_fired, + INFLUX_CONF_FIELDS: {}, } if _include_state: - json["fields"]["state"] = state.state + json[INFLUX_CONF_FIELDS][INFLUX_CONF_STATE] = state.state if _include_value: - json["fields"]["value"] = _state_as_value + json[INFLUX_CONF_FIELDS][INFLUX_CONF_VALUE] = _state_as_value for key, value in state.attributes.items(): if key in tags_attributes: - json["tags"][key] = value - elif key != "unit_of_measurement" or include_uom: + json[INFLUX_CONF_TAGS][key] = value + elif key != CONF_UNIT_OF_MEASUREMENT or include_uom: # If the key is already in fields - if key in json["fields"]: + if key in json[INFLUX_CONF_FIELDS]: key = f"{key}_" # Prevent column data errors in influxDB. # For each value we try to cast it as float # But if we can not do it we store the value # as string add "_str" postfix to the field key try: - json["fields"][key] = float(value) + json[INFLUX_CONF_FIELDS][key] = float(value) except (ValueError, TypeError): new_key = f"{key}_str" new_value = str(value) - json["fields"][new_key] = new_value + json[INFLUX_CONF_FIELDS][new_key] = new_value if RE_DIGIT_TAIL.match(new_value): - json["fields"][key] = float(RE_DECIMAL.sub("", new_value)) + json[INFLUX_CONF_FIELDS][key] = float( + RE_DECIMAL.sub("", new_value) + ) # Infinity and NaN are not valid floats in InfluxDB try: - if not math.isfinite(json["fields"][key]): - del json["fields"][key] + if not math.isfinite(json[INFLUX_CONF_FIELDS][key]): + del json[INFLUX_CONF_FIELDS][key] except (KeyError, TypeError): pass - json["tags"].update(tags) + json[INFLUX_CONF_TAGS].update(tags) return json - if use_v2_api: - instance = hass.data[DOMAIN] = InfluxThread( - hass, None, bucket, write_api, event_to_json, max_tries - ) - else: - instance = hass.data[DOMAIN] = InfluxThread( - hass, influx, None, None, event_to_json, max_tries - ) + return event_to_json + + +@dataclass +class InfluxClient: + """An InfluxDB client wrapper for V1 or V2.""" + write: Callable[[str], None] + query: Callable[[str, str], List[Any]] + close: Callable[[], None] + + +def get_influx_connection(conf, test_write=False, test_read=False): + """Create the correct influx connection for the API version.""" + kwargs = { + CONF_TIMEOUT: TIMEOUT, + } + + if conf[CONF_API_VERSION] == API_VERSION_2: + kwargs[CONF_URL] = conf[CONF_URL] + kwargs[CONF_TOKEN] = conf[CONF_TOKEN] + kwargs[INFLUX_CONF_ORG] = conf[CONF_ORG] + bucket = conf.get(CONF_BUCKET) + + influx = InfluxDBClientV2(**kwargs) + query_api = influx.query_api() + initial_write_mode = SYNCHRONOUS if test_write else ASYNCHRONOUS + write_api = influx.write_api(write_options=initial_write_mode) + + def write_v2(json): + """Write data to V2 influx.""" + try: + write_api.write(bucket=bucket, record=json) + except (urllib3.exceptions.HTTPError, OSError) as exc: + raise ConnectionError(CONNECTION_ERROR % exc) + except ApiException as exc: + if exc.status == CODE_INVALID_INPUTS: + raise ValueError(WRITE_ERROR % (json, exc)) + raise ConnectionError(CLIENT_ERROR_V2 % exc) + + def query_v2(query, _=None): + """Query V2 influx.""" + try: + return query_api.query(query) + except (urllib3.exceptions.HTTPError, OSError) as exc: + raise ConnectionError(CONNECTION_ERROR % exc) + except ApiException as exc: + if exc.status == CODE_INVALID_INPUTS: + raise ValueError(QUERY_ERROR % (query, exc)) + raise ConnectionError(CLIENT_ERROR_V2 % exc) + + def close_v2(): + """Close V2 influx client.""" + influx.close() + + influx_client = InfluxClient(write_v2, query_v2, close_v2) + if test_write: + # Try to write [] to influx. If we can connect and creds are valid + # Then invalid inputs is returned. Anything else is a broken config + try: + influx_client.write([]) + except ValueError: + pass + write_api = influx.write_api(write_options=ASYNCHRONOUS) + + if test_read: + influx_client.query(TEST_QUERY_V2) + + return influx_client + + # Else it's a V1 client + kwargs[CONF_VERIFY_SSL] = conf[CONF_VERIFY_SSL] + + if CONF_DB_NAME in conf: + kwargs[CONF_DB_NAME] = conf[CONF_DB_NAME] + + if CONF_USERNAME in conf: + kwargs[CONF_USERNAME] = conf[CONF_USERNAME] + + if CONF_PASSWORD in conf: + kwargs[CONF_PASSWORD] = conf[CONF_PASSWORD] + + if CONF_HOST in conf: + kwargs[CONF_HOST] = conf[CONF_HOST] + + if CONF_PATH in conf: + kwargs[CONF_PATH] = conf[CONF_PATH] + + if CONF_PORT in conf: + kwargs[CONF_PORT] = conf[CONF_PORT] + + if CONF_SSL in conf: + kwargs[CONF_SSL] = conf[CONF_SSL] + + influx = InfluxDBClient(**kwargs) + + def write_v1(json): + """Write data to V1 influx.""" + try: + influx.write_points(json) + except ( + requests.exceptions.RequestException, + exceptions.InfluxDBServerError, + OSError, + ) as exc: + raise ConnectionError(CONNECTION_ERROR % exc) + except exceptions.InfluxDBClientError as exc: + if exc.code == CODE_INVALID_INPUTS: + raise ValueError(WRITE_ERROR % (json, exc)) + raise ConnectionError(CLIENT_ERROR_V1 % exc) + + def query_v1(query, database=None): + """Query V1 influx.""" + try: + return list(influx.query(query, database=database).get_points()) + except ( + requests.exceptions.RequestException, + exceptions.InfluxDBServerError, + OSError, + ) as exc: + raise ConnectionError(CONNECTION_ERROR % exc) + except exceptions.InfluxDBClientError as exc: + if exc.code == CODE_INVALID_INPUTS: + raise ValueError(QUERY_ERROR % (query, exc)) + raise ConnectionError(CLIENT_ERROR_V1 % exc) + + def close_v1(): + """Close the V1 Influx client.""" + influx.close() + + influx_client = InfluxClient(write_v1, query_v1, close_v1) + if test_write: + influx_client.write([]) + + if test_read: + influx_client.query(TEST_QUERY_V1) + + return influx_client + + +def setup(hass, config): + """Set up the InfluxDB component.""" + conf = config[DOMAIN] + try: + influx = get_influx_connection(conf, test_write=True) + except ConnectionError as exc: + _LOGGER.error(RETRY_MESSAGE, exc) + event_helper.call_later(hass, RETRY_INTERVAL, lambda _: setup(hass, config)) + return True + + event_to_json = _generate_event_to_json(conf) + max_tries = conf.get(CONF_RETRY_COUNT) + instance = hass.data[DOMAIN] = InfluxThread(hass, influx, event_to_json, max_tries) instance.start() def shutdown(event): @@ -355,13 +444,11 @@ def setup(hass, config): class InfluxThread(threading.Thread): """A threaded event handler class.""" - def __init__(self, hass, influx, bucket, write_api, event_to_json, max_tries): + def __init__(self, hass, influx, event_to_json, max_tries): """Initialize the listener.""" - threading.Thread.__init__(self, name="InfluxDB") + threading.Thread.__init__(self, name=DOMAIN) self.queue = queue.Queue() self.influx = influx - self.bucket = bucket - self.write_api = write_api self.event_to_json = event_to_json self.max_tries = max_tries self.write_errors = 0 @@ -410,7 +497,7 @@ class InfluxThread(threading.Thread): pass if dropped: - _LOGGER.warning("Catching up, dropped %d old events", dropped) + _LOGGER.warning(CATCHING_UP_MESSAGE, dropped) return count, json @@ -418,28 +505,23 @@ class InfluxThread(threading.Thread): """Write preprocessed events to influxdb, with retry.""" for retry in range(self.max_tries + 1): try: - if self.write_api is not None: - self.write_api.write(bucket=self.bucket, record=json) - else: - self.influx.write_points(json) + self.influx.write(json) if self.write_errors: - _LOGGER.error("Resumed, lost %d events", self.write_errors) + _LOGGER.error(RESUMED_MESSAGE, self.write_errors) self.write_errors = 0 - _LOGGER.debug("Wrote %d events", len(json)) + _LOGGER.debug(WROTE_MESSAGE, len(json)) + break + except ValueError as err: + _LOGGER.error(err) break - except ( - exceptions.InfluxDBClientError, - exceptions.InfluxDBServerError, - OSError, - ApiException, - ) as err: + except ConnectionError as err: if retry < self.max_tries: time.sleep(RETRY_DELAY) else: if not self.write_errors: - _LOGGER.error(WRITE_ERROR, json, err) + _LOGGER.error(err) self.write_errors += len(json) def run(self): diff --git a/homeassistant/components/influxdb/const.py b/homeassistant/components/influxdb/const.py index b59ead3a849..eff6fbb1e44 100644 --- a/homeassistant/components/influxdb/const.py +++ b/homeassistant/components/influxdb/const.py @@ -53,7 +53,18 @@ DEFAULT_GROUP_FUNCTION = "mean" DEFAULT_FIELD = "value" DEFAULT_RANGE_START = "-15m" DEFAULT_RANGE_STOP = "now()" +DEFAULT_FUNCTION_FLUX = "|> limit(n: 1)" +INFLUX_CONF_MEASUREMENT = "measurement" +INFLUX_CONF_TAGS = "tags" +INFLUX_CONF_TIME = "time" +INFLUX_CONF_FIELDS = "fields" +INFLUX_CONF_STATE = "state" +INFLUX_CONF_VALUE = "value" +INFLUX_CONF_VALUE_V2 = "_value" +INFLUX_CONF_ORG = "org" + +EVENT_NEW_STATE = "new_state" DOMAIN = "influxdb" API_VERSION_2 = "2" TIMEOUT = 5 @@ -65,7 +76,8 @@ BATCH_BUFFER_SIZE = 100 LANGUAGE_INFLUXQL = "influxQL" LANGUAGE_FLUX = "flux" TEST_QUERY_V1 = "SHOW SERIES LIMIT 1;" -TEST_QUERY_V2 = "buckets() |> limit(n:1)" +TEST_QUERY_V2 = f"buckets() {DEFAULT_FUNCTION_FLUX}" +CODE_INVALID_INPUTS = 400 MIN_TIME_BETWEEN_UPDATES = timedelta(seconds=60) @@ -91,11 +103,19 @@ WRITE_ERROR = "Could not write '%s' to influx due to '%s'." QUERY_ERROR = ( "Could not execute query '%s' due to '%s'. Check the syntax of your query." ) -RETRY_MESSAGE = f"Retrying again in {RETRY_INTERVAL} seconds." -CONNECTION_ERROR_WITH_RETRY = f"{CONNECTION_ERROR} {RETRY_MESSAGE}" -CLIENT_ERROR_V1_WITH_RETRY = f"{CLIENT_ERROR_V1} {RETRY_MESSAGE}" -CLIENT_ERROR_V2_WITH_RETRY = f"{CLIENT_ERROR_V2} {RETRY_MESSAGE}" - +RETRY_MESSAGE = f"%s Retrying in {RETRY_INTERVAL} seconds." +CATCHING_UP_MESSAGE = "Catching up, dropped %d old events." +RESUMED_MESSAGE = "Resumed, lost %d events." +WROTE_MESSAGE = "Wrote %d events." +RUNNING_QUERY_MESSAGE = "Running query: %s." +QUERY_NO_RESULTS_MESSAGE = "Query returned no results, sensor state set to UNKNOWN: %s." +QUERY_MULTIPLE_RESULTS_MESSAGE = ( + "Query returned multiple results, only value from first one is shown: %s." +) +RENDERING_QUERY_MESSAGE = "Rendering query: %s." +RENDERING_QUERY_ERROR_MESSAGE = "Could not render query template: %s." +RENDERING_WHERE_MESSAGE = "Rendering where: %s." +RENDERING_WHERE_ERROR_MESSAGE = "Could not render where template: %s." COMPONENT_CONFIG_SCHEMA_CONNECTION = { # Connection config for V1 and V2 APIs. diff --git a/homeassistant/components/influxdb/sensor.py b/homeassistant/components/influxdb/sensor.py index 302bcde2373..eb52179126b 100644 --- a/homeassistant/components/influxdb/sensor.py +++ b/homeassistant/components/influxdb/sensor.py @@ -2,34 +2,23 @@ import logging from typing import Dict -from influxdb import InfluxDBClient, exceptions -from influxdb_client import InfluxDBClient as InfluxDBClientV2 -from influxdb_client.rest import ApiException import voluptuous as vol -from homeassistant.components.sensor import PLATFORM_SCHEMA +from homeassistant.components.sensor import PLATFORM_SCHEMA as SENSOR_PLATFORM_SCHEMA from homeassistant.const import ( CONF_API_VERSION, - CONF_HOST, CONF_NAME, - CONF_PASSWORD, - CONF_PATH, - CONF_PORT, - CONF_SSL, - CONF_TOKEN, CONF_UNIT_OF_MEASUREMENT, - CONF_URL, - CONF_USERNAME, CONF_VALUE_TEMPLATE, - CONF_VERIFY_SSL, + EVENT_HOMEASSISTANT_STOP, STATE_UNKNOWN, ) -from homeassistant.exceptions import TemplateError +from homeassistant.exceptions import PlatformNotReady, TemplateError import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity import Entity from homeassistant.util import Throttle -from . import create_influx_url, validate_version_specific_config +from . import create_influx_url, get_influx_connection, validate_version_specific_config from .const import ( API_VERSION_2, COMPONENT_CONFIG_SCHEMA_CONNECTION, @@ -38,8 +27,8 @@ from .const import ( CONF_FIELD, CONF_GROUP_FUNCTION, CONF_IMPORTS, + CONF_LANGUAGE, CONF_MEASUREMENT_NAME, - CONF_ORG, CONF_QUERIES, CONF_QUERIES_FLUX, CONF_QUERY, @@ -48,16 +37,63 @@ from .const import ( CONF_WHERE, DEFAULT_API_VERSION, DEFAULT_FIELD, + DEFAULT_FUNCTION_FLUX, DEFAULT_GROUP_FUNCTION, DEFAULT_RANGE_START, DEFAULT_RANGE_STOP, + INFLUX_CONF_VALUE, + INFLUX_CONF_VALUE_V2, + LANGUAGE_FLUX, + LANGUAGE_INFLUXQL, MIN_TIME_BETWEEN_UPDATES, - TEST_QUERY_V1, - TEST_QUERY_V2, + QUERY_MULTIPLE_RESULTS_MESSAGE, + QUERY_NO_RESULTS_MESSAGE, + RENDERING_QUERY_ERROR_MESSAGE, + RENDERING_QUERY_MESSAGE, + RENDERING_WHERE_ERROR_MESSAGE, + RENDERING_WHERE_MESSAGE, + RUNNING_QUERY_MESSAGE, ) _LOGGER = logging.getLogger(__name__) + +def _merge_connection_config_into_query(conf, query): + """Merge connection details into each configured query.""" + for key in conf: + if key not in query and key not in [CONF_QUERIES, CONF_QUERIES_FLUX]: + query[key] = conf[key] + + +def validate_query_format_for_version(conf: Dict) -> Dict: + """Ensure queries are provided in correct format based on API version.""" + if conf[CONF_API_VERSION] == API_VERSION_2: + if CONF_QUERIES_FLUX not in conf: + raise vol.Invalid( + f"{CONF_QUERIES_FLUX} is required when {CONF_API_VERSION} is {API_VERSION_2}" + ) + + for query in conf[CONF_QUERIES_FLUX]: + _merge_connection_config_into_query(conf, query) + query[CONF_LANGUAGE] = LANGUAGE_FLUX + + del conf[CONF_BUCKET] + + else: + if CONF_QUERIES not in conf: + raise vol.Invalid( + f"{CONF_QUERIES} is required when {CONF_API_VERSION} is {DEFAULT_API_VERSION}" + ) + + for query in conf[CONF_QUERIES]: + _merge_connection_config_into_query(conf, query) + query[CONF_LANGUAGE] = LANGUAGE_INFLUXQL + + del conf[CONF_DB_NAME] + + return conf + + _QUERY_SENSOR_SCHEMA = vol.Schema( { vol.Required(CONF_NAME): cv.string, @@ -67,7 +103,7 @@ _QUERY_SENSOR_SCHEMA = vol.Schema( ) _QUERY_SCHEMA = { - "InfluxQL": _QUERY_SENSOR_SCHEMA.extend( + LANGUAGE_INFLUXQL: _QUERY_SENSOR_SCHEMA.extend( { vol.Optional(CONF_DB_NAME): cv.string, vol.Required(CONF_MEASUREMENT_NAME): cv.string, @@ -78,7 +114,7 @@ _QUERY_SCHEMA = { vol.Required(CONF_WHERE): cv.template, } ), - "Flux": _QUERY_SENSOR_SCHEMA.extend( + LANGUAGE_FLUX: _QUERY_SENSOR_SCHEMA.extend( { vol.Optional(CONF_BUCKET): cv.string, vol.Optional(CONF_RANGE_START, default=DEFAULT_RANGE_START): cv.string, @@ -90,29 +126,11 @@ _QUERY_SCHEMA = { ), } - -def validate_query_format_for_version(conf: Dict) -> Dict: - """Ensure queries are provided in correct format based on API version.""" - if conf[CONF_API_VERSION] == API_VERSION_2: - if CONF_QUERIES_FLUX not in conf: - raise vol.Invalid( - f"{CONF_QUERIES_FLUX} is required when {CONF_API_VERSION} is {API_VERSION_2}" - ) - - else: - if CONF_QUERIES not in conf: - raise vol.Invalid( - f"{CONF_QUERIES} is required when {CONF_API_VERSION} is {DEFAULT_API_VERSION}" - ) - - return conf - - PLATFORM_SCHEMA = vol.All( - PLATFORM_SCHEMA.extend(COMPONENT_CONFIG_SCHEMA_CONNECTION).extend( + SENSOR_PLATFORM_SCHEMA.extend(COMPONENT_CONFIG_SCHEMA_CONNECTION).extend( { - vol.Exclusive(CONF_QUERIES, "queries"): [_QUERY_SCHEMA["InfluxQL"]], - vol.Exclusive(CONF_QUERIES_FLUX, "queries"): [_QUERY_SCHEMA["Flux"]], + vol.Exclusive(CONF_QUERIES, "queries"): [_QUERY_SCHEMA[LANGUAGE_INFLUXQL]], + vol.Exclusive(CONF_QUERIES_FLUX, "queries"): [_QUERY_SCHEMA[LANGUAGE_FLUX]], } ), validate_version_specific_config, @@ -123,61 +141,23 @@ PLATFORM_SCHEMA = vol.All( def setup_platform(hass, config, add_entities, discovery_info=None): """Set up the InfluxDB component.""" - use_v2_api = config[CONF_API_VERSION] == API_VERSION_2 - queries = None - - if use_v2_api: - influx_conf = { - "url": config[CONF_URL], - "token": config[CONF_TOKEN], - "org": config[CONF_ORG], - } - bucket = config[CONF_BUCKET] - queries = config[CONF_QUERIES_FLUX] - - for v2_query in queries: - if CONF_BUCKET not in v2_query: - v2_query[CONF_BUCKET] = bucket - - else: - influx_conf = { - "database": config[CONF_DB_NAME], - "verify_ssl": config[CONF_VERIFY_SSL], - } - - if CONF_USERNAME in config: - influx_conf["username"] = config[CONF_USERNAME] + try: + influx = get_influx_connection(config, test_read=True) + except ConnectionError as exc: + _LOGGER.error(exc) + raise PlatformNotReady() - if CONF_PASSWORD in config: - influx_conf["password"] = config[CONF_PASSWORD] + queries = config[CONF_QUERIES_FLUX if CONF_QUERIES_FLUX in config else CONF_QUERIES] + entities = [InfluxSensor(hass, influx, query) for query in queries] + add_entities(entities, update_before_add=True) - if CONF_HOST in config: - influx_conf["host"] = config[CONF_HOST] - - if CONF_PATH in config: - influx_conf["path"] = config[CONF_PATH] - - if CONF_PORT in config: - influx_conf["port"] = config[CONF_PORT] - - if CONF_SSL in config: - influx_conf["ssl"] = config[CONF_SSL] - - queries = config[CONF_QUERIES] - - entities = [] - for query in queries: - sensor = InfluxSensor(hass, influx_conf, query, use_v2_api) - if sensor.connected: - entities.append(sensor) - - add_entities(entities, True) + hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, lambda _: influx.close()) class InfluxSensor(Entity): """Implementation of a Influxdb sensor.""" - def __init__(self, hass, influx_conf, query, use_v2_api): + def __init__(self, hass, influx, query): """Initialize the sensor.""" self._name = query.get(CONF_NAME) self._unit_of_measurement = query.get(CONF_UNIT_OF_MEASUREMENT) @@ -190,66 +170,30 @@ class InfluxSensor(Entity): self._state = None self._hass = hass - if use_v2_api: - influx = InfluxDBClientV2(**influx_conf) - query_api = influx.query_api() + if query[CONF_LANGUAGE] == LANGUAGE_FLUX: query_clause = query.get(CONF_QUERY) query_clause.hass = hass - bucket = query[CONF_BUCKET] + self.data = InfluxFluxSensorData( + influx, + query.get(CONF_BUCKET), + query.get(CONF_RANGE_START), + query.get(CONF_RANGE_STOP), + query_clause, + query.get(CONF_IMPORTS), + query.get(CONF_GROUP_FUNCTION), + ) else: - if CONF_DB_NAME in query: - kwargs = influx_conf.copy() - kwargs[CONF_DB_NAME] = query[CONF_DB_NAME] - else: - kwargs = influx_conf - - influx = InfluxDBClient(**kwargs) where_clause = query.get(CONF_WHERE) where_clause.hass = hass - query_api = None - - try: - if query_api is not None: - query_api.query(TEST_QUERY_V2) - self.connected = True - self.data = InfluxSensorDataV2( - query_api, - bucket, - query.get(CONF_RANGE_START), - query.get(CONF_RANGE_STOP), - query_clause, - query.get(CONF_IMPORTS), - query.get(CONF_GROUP_FUNCTION), - ) - - else: - influx.query(TEST_QUERY_V1) - self.connected = True - self.data = InfluxSensorDataV1( - influx, - query.get(CONF_GROUP_FUNCTION), - query.get(CONF_FIELD), - query.get(CONF_MEASUREMENT_NAME), - where_clause, - ) - except exceptions.InfluxDBClientError as exc: - _LOGGER.error( - "Database host is not accessible due to '%s', please" - " check your entries in the configuration file and" - " that the database exists and is READ/WRITE", - exc, + self.data = InfluxQLSensorData( + influx, + query.get(CONF_DB_NAME), + query.get(CONF_GROUP_FUNCTION), + query.get(CONF_FIELD), + query.get(CONF_MEASUREMENT_NAME), + where_clause, ) - self.connected = False - except ApiException as exc: - _LOGGER.error( - "Bucket is not accessible due to '%s', please " - "check your entries in the configuration file (url, org, " - "bucket, etc.) and verify that the org and bucket exist and the " - "provided token has READ access.", - exc, - ) - self.connected = False @property def name(self): @@ -285,14 +229,12 @@ class InfluxSensor(Entity): self._state = value -class InfluxSensorDataV2: - """Class for handling the data retrieval with v2 API.""" +class InfluxFluxSensorData: + """Class for handling the data retrieval from Influx with Flux query.""" - def __init__( - self, query_api, bucket, range_start, range_stop, query, imports, group - ): + def __init__(self, influx, bucket, range_start, range_stop, query, imports, group): """Initialize the data object.""" - self.query_api = query_api + self.influx = influx self.bucket = bucket self.range_start = range_start self.range_stop = range_stop @@ -308,57 +250,47 @@ class InfluxSensorDataV2: self.query_prefix = f'import "{i}" {self.query_prefix}' if group is None: - self.query_postfix = "|> limit(n: 1)" + self.query_postfix = DEFAULT_FUNCTION_FLUX else: - self.query_postfix = f'|> {group}(column: "_value")' + self.query_postfix = f'|> {group}(column: "{INFLUX_CONF_VALUE_V2}")' @Throttle(MIN_TIME_BETWEEN_UPDATES) def update(self): """Get the latest data by querying influx.""" - _LOGGER.debug("Rendering query: %s", self.query) + _LOGGER.debug(RENDERING_QUERY_MESSAGE, self.query) try: rendered_query = self.query.render() except TemplateError as ex: - _LOGGER.error("Could not render query template: %s", ex) + _LOGGER.error(RENDERING_QUERY_ERROR_MESSAGE, ex) return self.full_query = f"{self.query_prefix} {rendered_query} {self.query_postfix}" - _LOGGER.info("Running query: %s", self.full_query) + _LOGGER.debug(RUNNING_QUERY_MESSAGE, self.full_query) try: - tables = self.query_api.query(self.full_query) - except (OSError, ApiException) as exc: - _LOGGER.error( - "Could not execute query '%s' due to '%s', " - "Check the syntax of your query", - self.full_query, - exc, - ) + tables = self.influx.query(self.full_query) + except (ConnectionError, ValueError) as exc: + _LOGGER.error(exc) self.value = None return if not tables: - _LOGGER.warning( - "Query returned no results, sensor state set to UNKNOWN: %s", - self.full_query, - ) + _LOGGER.warning(QUERY_NO_RESULTS_MESSAGE, self.full_query) self.value = None else: - if len(tables) > 1: - _LOGGER.warning( - "Query returned multiple tables, only value from first one is shown: %s", - self.full_query, - ) - self.value = tables[0].records[0].values["_value"] + if len(tables) > 1 or len(tables[0].records) > 1: + _LOGGER.warning(QUERY_MULTIPLE_RESULTS_MESSAGE, self.full_query) + self.value = tables[0].records[0].values[INFLUX_CONF_VALUE_V2] -class InfluxSensorDataV1: +class InfluxQLSensorData: """Class for handling the data retrieval with v1 API.""" - def __init__(self, influx, group, field, measurement, where): + def __init__(self, influx, db_name, group, field, measurement, where): """Initialize the data object.""" self.influx = influx + self.db_name = db_name self.group = group self.field = field self.measurement = measurement @@ -369,38 +301,28 @@ class InfluxSensorDataV1: @Throttle(MIN_TIME_BETWEEN_UPDATES) def update(self): """Get the latest data with a shell command.""" - _LOGGER.info("Rendering where: %s", self.where) + _LOGGER.debug(RENDERING_WHERE_MESSAGE, self.where) try: where_clause = self.where.render() except TemplateError as ex: - _LOGGER.error("Could not render where clause template: %s", ex) + _LOGGER.error(RENDERING_WHERE_ERROR_MESSAGE, ex) return - self.query = f"select {self.group}({self.field}) as value from {self.measurement} where {where_clause}" + self.query = f"select {self.group}({self.field}) as {INFLUX_CONF_VALUE} from {self.measurement} where {where_clause}" - _LOGGER.info("Running query: %s", self.query) + _LOGGER.debug(RUNNING_QUERY_MESSAGE, self.query) try: - points = list(self.influx.query(self.query).get_points()) - except (OSError, exceptions.InfluxDBClientError) as exc: - _LOGGER.error( - "Could not execute query '%s' due to '%s', " - "Check the syntax of your query", - self.query, - exc, - ) + points = self.influx.query(self.query, self.db_name) + except (ConnectionError, ValueError) as exc: + _LOGGER.error(exc) self.value = None return if not points: - _LOGGER.warning( - "Query returned no points, sensor state set to UNKNOWN: %s", self.query - ) + _LOGGER.warning(QUERY_NO_RESULTS_MESSAGE, self.query) self.value = None else: if len(points) > 1: - _LOGGER.warning( - "Query returned multiple points, only first one shown: %s", - self.query, - ) - self.value = points[0].get("value") + _LOGGER.warning(QUERY_MULTIPLE_RESULTS_MESSAGE, self.query) + self.value = points[0].get(INFLUX_CONF_VALUE) diff --git a/tests/components/influxdb/test_init.py b/tests/components/influxdb/test_init.py index 04486f8f9b3..d9b41ed967c 100644 --- a/tests/components/influxdb/test_init.py +++ b/tests/components/influxdb/test_init.py @@ -1226,6 +1226,13 @@ async def test_event_listener_attribute_name_conflict( influxdb.DEFAULT_API_VERSION, influxdb.exceptions.InfluxDBClientError("fail"), ), + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + influxdb.exceptions.InfluxDBServerError("fail"), + ), ( influxdb.API_VERSION_2, BASE_V2_CONFIG, diff --git a/tests/components/influxdb/test_sensor.py b/tests/components/influxdb/test_sensor.py index 4ef42d0fa3a..c75c7d35578 100644 --- a/tests/components/influxdb/test_sensor.py +++ b/tests/components/influxdb/test_sensor.py @@ -1,8 +1,9 @@ """The tests for the InfluxDB sensor.""" from dataclasses import dataclass +from datetime import timedelta from typing import Dict, List, Type -from influxdb.exceptions import InfluxDBClientError +from influxdb.exceptions import InfluxDBClientError, InfluxDBServerError from influxdb_client.rest import ApiException import pytest from voluptuous import Invalid @@ -18,12 +19,15 @@ from homeassistant.components.influxdb.sensor import PLATFORM_SCHEMA import homeassistant.components.sensor as sensor from homeassistant.const import STATE_UNKNOWN from homeassistant.setup import async_setup_component +from homeassistant.util import dt as dt_util from tests.async_mock import MagicMock, patch +from tests.common import async_fire_time_changed INFLUXDB_PATH = "homeassistant.components.influxdb" -INFLUXDB_CLIENT_PATH = f"{INFLUXDB_PATH}.sensor.InfluxDBClient" +INFLUXDB_CLIENT_PATH = f"{INFLUXDB_PATH}.InfluxDBClient" INFLUXDB_SENSOR_PATH = f"{INFLUXDB_PATH}.sensor" +PLATFORM_NOT_READY_BASE_WAIT_TIME = 30 BASE_V1_CONFIG = {} BASE_V2_CONFIG = { @@ -137,6 +141,8 @@ def _set_query_mock_v1(mock_influx_client, return_value=None, side_effect=None): query_api.side_effect = get_return_value + return query_api + def _set_query_mock_v2(mock_influx_client, return_value=None, side_effect=None): """Set return value or side effect for the V2 client.""" @@ -149,6 +155,8 @@ def _set_query_mock_v2(mock_influx_client, return_value=None, side_effect=None): query_api.return_value = return_value + return query_api + async def _setup(hass, config_ext, queries, expected_sensors): """Create client and test expected sensors.""" @@ -451,3 +459,79 @@ async def test_error_rendering_template( assert ( len([record for record in caplog.records if record.levelname == "ERROR"]) == 1 ) + + +@pytest.mark.parametrize( + "mock_client, config_ext, queries, set_query_mock, test_exception, make_resultset", + [ + ( + DEFAULT_API_VERSION, + BASE_V1_CONFIG, + BASE_V1_QUERY, + _set_query_mock_v1, + OSError("fail"), + _make_v1_resultset, + ), + ( + DEFAULT_API_VERSION, + BASE_V1_CONFIG, + BASE_V1_QUERY, + _set_query_mock_v1, + InfluxDBClientError("fail"), + _make_v1_resultset, + ), + ( + DEFAULT_API_VERSION, + BASE_V1_CONFIG, + BASE_V1_QUERY, + _set_query_mock_v1, + InfluxDBServerError("fail"), + _make_v1_resultset, + ), + ( + API_VERSION_2, + BASE_V2_CONFIG, + BASE_V2_QUERY, + _set_query_mock_v2, + OSError("fail"), + _make_v2_resultset, + ), + ( + API_VERSION_2, + BASE_V2_CONFIG, + BASE_V2_QUERY, + _set_query_mock_v2, + ApiException(), + _make_v2_resultset, + ), + ], + indirect=["mock_client"], +) +async def test_connection_error_at_startup( + hass, + caplog, + mock_client, + config_ext, + queries, + set_query_mock, + test_exception, + make_resultset, +): + """Test behavior of sensor when influx returns error.""" + query_api = set_query_mock(mock_client, side_effect=test_exception) + expected_sensor = "sensor.test" + + # Test sensor is not setup first time due to connection error + await _setup(hass, config_ext, queries, []) + assert hass.states.get(expected_sensor) is None + assert ( + len([record for record in caplog.records if record.levelname == "ERROR"]) == 1 + ) + + # Stop throwing exception and advance time to test setup succeeds + query_api.reset_mock(side_effect=True) + set_query_mock(mock_client, return_value=make_resultset(42)) + new_time = dt_util.utcnow() + timedelta(seconds=PLATFORM_NOT_READY_BASE_WAIT_TIME) + async_fire_time_changed(hass, new_time) + await hass.async_block_till_done() + assert hass.states.get(expected_sensor) is not None -- GitLab