diff --git a/.coveragerc b/.coveragerc index b5427435636577ca3b47432706d96cb63aefbe43..3bf3fa10947f785f6c1b9fb390370fb0cfdf94c7 100644 --- a/.coveragerc +++ b/.coveragerc @@ -67,6 +67,8 @@ omit = homeassistant/components/arwn/sensor.py homeassistant/components/asterisk_cdr/mailbox.py homeassistant/components/asterisk_mbox/* + homeassistant/components/asuswrt/__init__.py + homeassistant/components/asuswrt/router.py homeassistant/components/aten_pe/* homeassistant/components/atome/* homeassistant/components/aurora/__init__.py diff --git a/homeassistant/components/asuswrt/__init__.py b/homeassistant/components/asuswrt/__init__.py index 9cd47d803dede3aecef3173718b2a0660c3b0cdf..d2eb47fa2d297790b4e89228968b22cf72e433e3 100644 --- a/homeassistant/components/asuswrt/__init__.py +++ b/homeassistant/components/asuswrt/__init__.py @@ -1,9 +1,10 @@ """Support for ASUSWRT devices.""" +import asyncio import logging -from aioasuswrt.asuswrt import AsusWrt import voluptuous as vol +from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry from homeassistant.const import ( CONF_HOST, CONF_MODE, @@ -12,108 +13,165 @@ from homeassistant.const import ( CONF_PROTOCOL, CONF_SENSORS, CONF_USERNAME, + EVENT_HOMEASSISTANT_STOP, ) from homeassistant.helpers import config_validation as cv -from homeassistant.helpers.discovery import async_load_platform -from homeassistant.helpers.event import async_call_later +from homeassistant.helpers.typing import HomeAssistantType + +from .const import ( + CONF_DNSMASQ, + CONF_INTERFACE, + CONF_REQUIRE_IP, + CONF_SSH_KEY, + DATA_ASUSWRT, + DEFAULT_DNSMASQ, + DEFAULT_INTERFACE, + DEFAULT_SSH_PORT, + DOMAIN, + MODE_AP, + MODE_ROUTER, + PROTOCOL_SSH, + PROTOCOL_TELNET, + SENSOR_TYPES, +) +from .router import AsusWrtRouter -_LOGGER = logging.getLogger(__name__) +PLATFORMS = ["device_tracker", "sensor"] -CONF_DNSMASQ = "dnsmasq" -CONF_INTERFACE = "interface" CONF_PUB_KEY = "pub_key" -CONF_REQUIRE_IP = "require_ip" -CONF_SSH_KEY = "ssh_key" - -DOMAIN = "asuswrt" -DATA_ASUSWRT = DOMAIN - -DEFAULT_SSH_PORT = 22 -DEFAULT_INTERFACE = "eth0" -DEFAULT_DNSMASQ = "/var/lib/misc" - -FIRST_RETRY_TIME = 60 -MAX_RETRY_TIME = 900 - SECRET_GROUP = "Password or SSH Key" -SENSOR_TYPES = ["devices", "upload_speed", "download_speed", "download", "upload"] + +_LOGGER = logging.getLogger(__name__) CONFIG_SCHEMA = vol.Schema( - { - DOMAIN: vol.Schema( - { - vol.Required(CONF_HOST): cv.string, - vol.Required(CONF_USERNAME): cv.string, - vol.Optional(CONF_PROTOCOL, default="ssh"): vol.In(["ssh", "telnet"]), - vol.Optional(CONF_MODE, default="router"): vol.In(["router", "ap"]), - vol.Optional(CONF_PORT, default=DEFAULT_SSH_PORT): cv.port, - vol.Optional(CONF_REQUIRE_IP, default=True): cv.boolean, - vol.Exclusive(CONF_PASSWORD, SECRET_GROUP): cv.string, - vol.Exclusive(CONF_SSH_KEY, SECRET_GROUP): cv.isfile, - vol.Exclusive(CONF_PUB_KEY, SECRET_GROUP): cv.isfile, - vol.Optional(CONF_SENSORS): vol.All( - cv.ensure_list, [vol.In(SENSOR_TYPES)] - ), - vol.Optional(CONF_INTERFACE, default=DEFAULT_INTERFACE): cv.string, - vol.Optional(CONF_DNSMASQ, default=DEFAULT_DNSMASQ): cv.string, - } - ) - }, + vol.All( + cv.deprecated(DOMAIN), + { + DOMAIN: vol.Schema( + { + vol.Required(CONF_HOST): cv.string, + vol.Required(CONF_USERNAME): cv.string, + vol.Optional(CONF_PROTOCOL, default=PROTOCOL_SSH): vol.In( + [PROTOCOL_SSH, PROTOCOL_TELNET] + ), + vol.Optional(CONF_MODE, default=MODE_ROUTER): vol.In( + [MODE_ROUTER, MODE_AP] + ), + vol.Optional(CONF_PORT, default=DEFAULT_SSH_PORT): cv.port, + vol.Optional(CONF_REQUIRE_IP, default=True): cv.boolean, + vol.Exclusive(CONF_PASSWORD, SECRET_GROUP): cv.string, + vol.Exclusive(CONF_SSH_KEY, SECRET_GROUP): cv.isfile, + vol.Exclusive(CONF_PUB_KEY, SECRET_GROUP): cv.isfile, + vol.Optional(CONF_SENSORS): vol.All( + cv.ensure_list, [vol.In(SENSOR_TYPES)] + ), + vol.Optional(CONF_INTERFACE, default=DEFAULT_INTERFACE): cv.string, + vol.Optional(CONF_DNSMASQ, default=DEFAULT_DNSMASQ): cv.string, + } + ) + }, + ), extra=vol.ALLOW_EXTRA, ) -async def async_setup(hass, config, retry_delay=FIRST_RETRY_TIME): - """Set up the asuswrt component.""" - conf = config[DOMAIN] - - api = AsusWrt( - conf[CONF_HOST], - conf[CONF_PORT], - conf[CONF_PROTOCOL] == "telnet", - conf[CONF_USERNAME], - conf.get(CONF_PASSWORD, ""), - conf.get("ssh_key", conf.get("pub_key", "")), - conf[CONF_MODE], - conf[CONF_REQUIRE_IP], - interface=conf[CONF_INTERFACE], - dnsmasq=conf[CONF_DNSMASQ], - ) +async def async_setup(hass, config): + """Set up the AsusWrt integration.""" + conf = config.get(DOMAIN) + if conf is None: + return True + + # save the options from config yaml + options = {} + mode = conf.get(CONF_MODE, MODE_ROUTER) + for name, value in conf.items(): + if name in ([CONF_DNSMASQ, CONF_INTERFACE, CONF_REQUIRE_IP]): + if name == CONF_REQUIRE_IP and mode != MODE_AP: + continue + options[name] = value + hass.data[DOMAIN] = {"yaml_options": options} + + # check if already configured + domains_list = hass.config_entries.async_domains() + if DOMAIN in domains_list: + return True + + # remove not required config keys + pub_key = conf.pop(CONF_PUB_KEY, "") + if pub_key: + conf[CONF_SSH_KEY] = pub_key - try: - await api.connection.async_connect() - except OSError as ex: - _LOGGER.warning( - "Error [%s] connecting %s to %s. Will retry in %s seconds...", - str(ex), - DOMAIN, - conf[CONF_HOST], - retry_delay, + conf.pop(CONF_REQUIRE_IP, True) + conf.pop(CONF_SENSORS, {}) + conf.pop(CONF_INTERFACE, "") + conf.pop(CONF_DNSMASQ, "") + + hass.async_create_task( + hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_IMPORT}, data=conf ) + ) - async def retry_setup(now): - """Retry setup if a error happens on asuswrt API.""" - await async_setup( - hass, config, retry_delay=min(2 * retry_delay, MAX_RETRY_TIME) - ) + return True - async_call_later(hass, retry_delay, retry_setup) - return True +async def async_setup_entry(hass: HomeAssistantType, entry: ConfigEntry): + """Set up AsusWrt platform.""" - if not api.is_connected: - _LOGGER.error("Error connecting %s to %s", DOMAIN, conf[CONF_HOST]) - return False + # import options from yaml if empty + yaml_options = hass.data.get(DOMAIN, {}).pop("yaml_options", {}) + if not entry.options and yaml_options: + hass.config_entries.async_update_entry(entry, options=yaml_options) - hass.data[DATA_ASUSWRT] = api + router = AsusWrtRouter(hass, entry) + await router.setup() - hass.async_create_task( - async_load_platform( - hass, "sensor", DOMAIN, config[DOMAIN].get(CONF_SENSORS), config + router.async_on_close(entry.add_update_listener(update_listener)) + + for platform in PLATFORMS: + hass.async_create_task( + hass.config_entries.async_forward_entry_setup(entry, platform) ) + + async def async_close_connection(event): + """Close AsusWrt connection on HA Stop.""" + await router.close() + + stop_listener = hass.bus.async_listen_once( + EVENT_HOMEASSISTANT_STOP, async_close_connection ) - hass.async_create_task( - async_load_platform(hass, "device_tracker", DOMAIN, {}, config) - ) + + hass.data.setdefault(DOMAIN, {})[entry.entry_id] = { + DATA_ASUSWRT: router, + "stop_listener": stop_listener, + } return True + + +async def async_unload_entry(hass: HomeAssistantType, entry: ConfigEntry): + """Unload a config entry.""" + unload_ok = all( + await asyncio.gather( + *[ + hass.config_entries.async_forward_entry_unload(entry, platform) + for platform in PLATFORMS + ] + ) + ) + if unload_ok: + hass.data[DOMAIN][entry.entry_id]["stop_listener"]() + router = hass.data[DOMAIN][entry.entry_id][DATA_ASUSWRT] + await router.close() + + hass.data[DOMAIN].pop(entry.entry_id) + + return unload_ok + + +async def update_listener(hass: HomeAssistantType, entry: ConfigEntry): + """Update when config_entry options update.""" + router = hass.data[DOMAIN][entry.entry_id][DATA_ASUSWRT] + + if router.update_options(entry.options): + await hass.config_entries.async_reload(entry.entry_id) diff --git a/homeassistant/components/asuswrt/config_flow.py b/homeassistant/components/asuswrt/config_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..303b3cc3822ce964e50c4c66ae9d7e5aa97b0ad2 --- /dev/null +++ b/homeassistant/components/asuswrt/config_flow.py @@ -0,0 +1,238 @@ +"""Config flow to configure the AsusWrt integration.""" +import logging +import os +import socket + +import voluptuous as vol + +from homeassistant import config_entries +from homeassistant.components.device_tracker.const import ( + CONF_CONSIDER_HOME, + DEFAULT_CONSIDER_HOME, +) +from homeassistant.const import ( + CONF_HOST, + CONF_MODE, + CONF_PASSWORD, + CONF_PORT, + CONF_PROTOCOL, + CONF_USERNAME, +) +from homeassistant.core import callback +from homeassistant.helpers import config_validation as cv + +# pylint:disable=unused-import +from .const import ( + CONF_DNSMASQ, + CONF_INTERFACE, + CONF_REQUIRE_IP, + CONF_SSH_KEY, + CONF_TRACK_UNKNOWN, + DEFAULT_DNSMASQ, + DEFAULT_INTERFACE, + DEFAULT_SSH_PORT, + DEFAULT_TRACK_UNKNOWN, + DOMAIN, + MODE_AP, + MODE_ROUTER, + PROTOCOL_SSH, + PROTOCOL_TELNET, +) +from .router import get_api + +RESULT_CONN_ERROR = "cannot_connect" +RESULT_UNKNOWN = "unknown" +RESULT_SUCCESS = "success" + +_LOGGER = logging.getLogger(__name__) + + +def _is_file(value) -> bool: + """Validate that the value is an existing file.""" + file_in = os.path.expanduser(str(value)) + + if not os.path.isfile(file_in): + return False + if not os.access(file_in, os.R_OK): + return False + return True + + +def _get_ip(host): + """Get the ip address from the host name.""" + try: + return socket.gethostbyname(host) + except socket.gaierror: + return None + + +class AsusWrtFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): + """Handle a config flow.""" + + VERSION = 1 + CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_POLL + + def __init__(self): + """Initialize AsusWrt config flow.""" + self._host = None + + @callback + def _show_setup_form(self, user_input=None, errors=None): + """Show the setup form to the user.""" + + if user_input is None: + user_input = {} + + return self.async_show_form( + step_id="user", + data_schema=vol.Schema( + { + vol.Required(CONF_HOST, default=user_input.get(CONF_HOST, "")): str, + vol.Required( + CONF_USERNAME, default=user_input.get(CONF_USERNAME, "") + ): str, + vol.Optional(CONF_PASSWORD): str, + vol.Optional(CONF_SSH_KEY): str, + vol.Required(CONF_PROTOCOL, default=PROTOCOL_SSH): vol.In( + {PROTOCOL_SSH: "SSH", PROTOCOL_TELNET: "Telnet"} + ), + vol.Required(CONF_PORT, default=DEFAULT_SSH_PORT): cv.port, + vol.Required(CONF_MODE, default=MODE_ROUTER): vol.In( + {MODE_ROUTER: "Router", MODE_AP: "Access Point"} + ), + } + ), + errors=errors or {}, + ) + + async def _async_check_connection(self, user_input): + """Attempt to connect the AsusWrt router.""" + + api = get_api(user_input) + try: + await api.connection.async_connect() + + except OSError: + _LOGGER.error("Error connecting to the AsusWrt router at %s", self._host) + return RESULT_CONN_ERROR + + except Exception: # pylint: disable=broad-except + _LOGGER.exception( + "Unknown error connecting with AsusWrt router at %s", self._host + ) + return RESULT_UNKNOWN + + if not api.is_connected: + _LOGGER.error("Error connecting to the AsusWrt router at %s", self._host) + return RESULT_CONN_ERROR + + conf_protocol = user_input[CONF_PROTOCOL] + if conf_protocol == PROTOCOL_TELNET: + await api.connection.disconnect() + return RESULT_SUCCESS + + async def async_step_user(self, user_input=None): + """Handle a flow initiated by the user.""" + if self._async_current_entries(): + return self.async_abort(reason="single_instance_allowed") + + if user_input is None: + return self._show_setup_form(user_input) + + errors = {} + self._host = user_input[CONF_HOST] + pwd = user_input.get(CONF_PASSWORD) + ssh = user_input.get(CONF_SSH_KEY) + + if not (pwd or ssh): + errors["base"] = "pwd_or_ssh" + elif ssh: + if pwd: + errors["base"] = "pwd_and_ssh" + else: + isfile = await self.hass.async_add_executor_job(_is_file, ssh) + if not isfile: + errors["base"] = "ssh_not_file" + + if not errors: + ip_address = await self.hass.async_add_executor_job(_get_ip, self._host) + if not ip_address: + errors["base"] = "invalid_host" + + if not errors: + result = await self._async_check_connection(user_input) + if result != RESULT_SUCCESS: + errors["base"] = result + + if errors: + return self._show_setup_form(user_input, errors) + + return self.async_create_entry( + title=self._host, + data=user_input, + ) + + async def async_step_import(self, user_input=None): + """Import a config entry.""" + return await self.async_step_user(user_input) + + @staticmethod + @callback + def async_get_options_flow(config_entry): + """Get the options flow for this handler.""" + return OptionsFlowHandler(config_entry) + + +class OptionsFlowHandler(config_entries.OptionsFlow): + """Handle a option flow for AsusWrt.""" + + def __init__(self, config_entry: config_entries.ConfigEntry): + """Initialize options flow.""" + self.config_entry = config_entry + + async def async_step_init(self, user_input=None): + """Handle options flow.""" + if user_input is not None: + return self.async_create_entry(title="", data=user_input) + + data_schema = vol.Schema( + { + vol.Optional( + CONF_CONSIDER_HOME, + default=self.config_entry.options.get( + CONF_CONSIDER_HOME, DEFAULT_CONSIDER_HOME.total_seconds() + ), + ): vol.All(vol.Coerce(int), vol.Clamp(min=0, max=900)), + vol.Optional( + CONF_TRACK_UNKNOWN, + default=self.config_entry.options.get( + CONF_TRACK_UNKNOWN, DEFAULT_TRACK_UNKNOWN + ), + ): bool, + vol.Required( + CONF_INTERFACE, + default=self.config_entry.options.get( + CONF_INTERFACE, DEFAULT_INTERFACE + ), + ): str, + vol.Required( + CONF_DNSMASQ, + default=self.config_entry.options.get( + CONF_DNSMASQ, DEFAULT_DNSMASQ + ), + ): str, + } + ) + + conf_mode = self.config_entry.data[CONF_MODE] + if conf_mode == MODE_AP: + data_schema = data_schema.extend( + { + vol.Optional( + CONF_REQUIRE_IP, + default=self.config_entry.options.get(CONF_REQUIRE_IP, True), + ): bool, + } + ) + + return self.async_show_form(step_id="init", data_schema=data_schema) diff --git a/homeassistant/components/asuswrt/const.py b/homeassistant/components/asuswrt/const.py new file mode 100644 index 0000000000000000000000000000000000000000..40752e81a08fd5d2a8907e360e7a9c1506bf9051 --- /dev/null +++ b/homeassistant/components/asuswrt/const.py @@ -0,0 +1,24 @@ +"""AsusWrt component constants.""" +DOMAIN = "asuswrt" + +CONF_DNSMASQ = "dnsmasq" +CONF_INTERFACE = "interface" +CONF_REQUIRE_IP = "require_ip" +CONF_SSH_KEY = "ssh_key" +CONF_TRACK_UNKNOWN = "track_unknown" + +DATA_ASUSWRT = DOMAIN + +DEFAULT_DNSMASQ = "/var/lib/misc" +DEFAULT_INTERFACE = "eth0" +DEFAULT_SSH_PORT = 22 +DEFAULT_TRACK_UNKNOWN = False + +MODE_AP = "ap" +MODE_ROUTER = "router" + +PROTOCOL_SSH = "ssh" +PROTOCOL_TELNET = "telnet" + +# Sensor +SENSOR_TYPES = ["devices", "upload_speed", "download_speed", "download", "upload"] diff --git a/homeassistant/components/asuswrt/device_tracker.py b/homeassistant/components/asuswrt/device_tracker.py index a3545183d2e597a9c60ae6fdc7928ab543b1ed4e..85553674dbaf35e73ed8f9762d082ce4a0696573 100644 --- a/homeassistant/components/asuswrt/device_tracker.py +++ b/homeassistant/components/asuswrt/device_tracker.py @@ -1,64 +1,143 @@ """Support for ASUSWRT routers.""" import logging +from typing import Dict -from homeassistant.components.device_tracker import DeviceScanner +from homeassistant.components.device_tracker import SOURCE_TYPE_ROUTER +from homeassistant.components.device_tracker.config_entry import ScannerEntity +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import callback +from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC +from homeassistant.helpers.dispatcher import async_dispatcher_connect +from homeassistant.helpers.typing import HomeAssistantType -from . import DATA_ASUSWRT +from .const import DATA_ASUSWRT, DOMAIN +from .router import AsusWrtRouter + +DEFAULT_DEVICE_NAME = "Unknown device" _LOGGER = logging.getLogger(__name__) -async def async_get_scanner(hass, config): - """Validate the configuration and return an ASUS-WRT scanner.""" - scanner = AsusWrtDeviceScanner(hass.data[DATA_ASUSWRT]) - await scanner.async_connect() - return scanner if scanner.success_init else None - - -class AsusWrtDeviceScanner(DeviceScanner): - """This class queries a router running ASUSWRT firmware.""" - - # Eighth attribute needed for mode (AP mode vs router mode) - def __init__(self, api): - """Initialize the scanner.""" - self.last_results = {} - self.success_init = False - self.connection = api - self._connect_error = False - - async def async_connect(self): - """Initialize connection to the router.""" - # Test the router is accessible. - data = await self.connection.async_get_connected_devices() - self.success_init = data is not None - - async def async_scan_devices(self): - """Scan for new devices and return a list with found device IDs.""" - await self.async_update_info() - return list(self.last_results) - - async def async_get_device_name(self, device): - """Return the name of the given device or None if we don't know.""" - if device not in self.last_results: - return None - return self.last_results[device].name - - async def async_update_info(self): - """Ensure the information from the ASUSWRT router is up to date. - - Return boolean if scanning successful. - """ - _LOGGER.debug("Checking Devices") - - try: - self.last_results = await self.connection.async_get_connected_devices() - if self._connect_error: - self._connect_error = False - _LOGGER.info("Reconnected to ASUS router for device update") - - except OSError as err: - if not self._connect_error: - self._connect_error = True - _LOGGER.error( - "Error connecting to ASUS router for device update: %s", err - ) +async def async_setup_entry( + hass: HomeAssistantType, entry: ConfigEntry, async_add_entities +) -> None: + """Set up device tracker for AsusWrt component.""" + router = hass.data[DOMAIN][entry.entry_id][DATA_ASUSWRT] + tracked = set() + + @callback + def update_router(): + """Update the values of the router.""" + add_entities(router, async_add_entities, tracked) + + router.async_on_close( + async_dispatcher_connect(hass, router.signal_device_new, update_router) + ) + + update_router() + + +@callback +def add_entities(router, async_add_entities, tracked): + """Add new tracker entities from the router.""" + new_tracked = [] + + for mac, device in router.devices.items(): + if mac in tracked: + continue + + new_tracked.append(AsusWrtDevice(router, device)) + tracked.add(mac) + + if new_tracked: + async_add_entities(new_tracked) + + +class AsusWrtDevice(ScannerEntity): + """Representation of a AsusWrt device.""" + + def __init__(self, router: AsusWrtRouter, device) -> None: + """Initialize a AsusWrt device.""" + self._router = router + self._mac = device.mac + self._name = device.name or DEFAULT_DEVICE_NAME + self._active = False + self._icon = None + self._attrs = {} + + @callback + def async_update_state(self) -> None: + """Update the AsusWrt device.""" + device = self._router.devices[self._mac] + self._active = device.is_connected + + self._attrs = { + "mac": device.mac, + "ip_address": device.ip_address, + } + if device.last_activity: + self._attrs["last_time_reachable"] = device.last_activity.isoformat( + timespec="seconds" + ) + + @property + def unique_id(self) -> str: + """Return a unique ID.""" + return self._mac + + @property + def name(self) -> str: + """Return the name.""" + return self._name + + @property + def is_connected(self): + """Return true if the device is connected to the network.""" + return self._active + + @property + def source_type(self) -> str: + """Return the source type.""" + return SOURCE_TYPE_ROUTER + + @property + def icon(self) -> str: + """Return the icon.""" + return self._icon + + @property + def device_state_attributes(self) -> Dict[str, any]: + """Return the attributes.""" + return self._attrs + + @property + def device_info(self) -> Dict[str, any]: + """Return the device information.""" + return { + "connections": {(CONNECTION_NETWORK_MAC, self._mac)}, + "identifiers": {(DOMAIN, self.unique_id)}, + "name": self.name, + "manufacturer": "AsusWRT Tracked device", + } + + @property + def should_poll(self) -> bool: + """No polling needed.""" + return False + + @callback + def async_on_demand_update(self): + """Update state.""" + self.async_update_state() + self.async_write_ha_state() + + async def async_added_to_hass(self): + """Register state update callback.""" + self.async_update_state() + self.async_on_remove( + async_dispatcher_connect( + self.hass, + self._router.signal_device_update, + self.async_on_demand_update, + ) + ) diff --git a/homeassistant/components/asuswrt/manifest.json b/homeassistant/components/asuswrt/manifest.json index 9afb7849f8cebddd0b5a15ec0f4da5c9425ba3f9..744a05b9728ec47a6ff2d336baa6be8e10c21868 100644 --- a/homeassistant/components/asuswrt/manifest.json +++ b/homeassistant/components/asuswrt/manifest.json @@ -1,6 +1,7 @@ { "domain": "asuswrt", "name": "ASUSWRT", + "config_flow": true, "documentation": "https://www.home-assistant.io/integrations/asuswrt", "requirements": ["aioasuswrt==1.3.1"], "codeowners": ["@kennedyshead"] diff --git a/homeassistant/components/asuswrt/router.py b/homeassistant/components/asuswrt/router.py new file mode 100644 index 0000000000000000000000000000000000000000..11545919b43abb021333a92dd9bbbe488664ccf5 --- /dev/null +++ b/homeassistant/components/asuswrt/router.py @@ -0,0 +1,274 @@ +"""Represent the AsusWrt router.""" +from datetime import datetime, timedelta +import logging +from typing import Any, Dict, Optional + +from aioasuswrt.asuswrt import AsusWrt + +from homeassistant.components.device_tracker.const import ( + CONF_CONSIDER_HOME, + DEFAULT_CONSIDER_HOME, + DOMAIN as TRACKER_DOMAIN, +) +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import ( + CONF_HOST, + CONF_MODE, + CONF_PASSWORD, + CONF_PORT, + CONF_PROTOCOL, + CONF_USERNAME, +) +from homeassistant.core import CALLBACK_TYPE, callback +from homeassistant.exceptions import ConfigEntryNotReady +from homeassistant.helpers.dispatcher import async_dispatcher_send +from homeassistant.helpers.event import async_track_time_interval +from homeassistant.helpers.typing import HomeAssistantType +from homeassistant.util import dt as dt_util + +from .const import ( + CONF_DNSMASQ, + CONF_INTERFACE, + CONF_REQUIRE_IP, + CONF_SSH_KEY, + CONF_TRACK_UNKNOWN, + DEFAULT_DNSMASQ, + DEFAULT_INTERFACE, + DEFAULT_TRACK_UNKNOWN, + DOMAIN, + PROTOCOL_TELNET, +) + +CONF_REQ_RELOAD = [CONF_DNSMASQ, CONF_INTERFACE, CONF_REQUIRE_IP] +SCAN_INTERVAL = timedelta(seconds=30) + +_LOGGER = logging.getLogger(__name__) + + +class AsusWrtDevInfo: + """Representation of a AsusWrt device info.""" + + def __init__(self, mac, name=None): + """Initialize a AsusWrt device info.""" + self._mac = mac + self._name = name + self._ip_address = None + self._last_activity = None + self._connected = False + + def update(self, dev_info=None, consider_home=0): + """Update AsusWrt device info.""" + utc_point_in_time = dt_util.utcnow() + if dev_info: + if not self._name: + self._name = dev_info.name or self._mac.replace(":", "_") + self._ip_address = dev_info.ip + self._last_activity = utc_point_in_time + self._connected = True + + elif self._connected: + self._connected = ( + utc_point_in_time - self._last_activity + ).total_seconds() < consider_home + self._ip_address = None + + @property + def is_connected(self): + """Return connected status.""" + return self._connected + + @property + def mac(self): + """Return device mac address.""" + return self._mac + + @property + def name(self): + """Return device name.""" + return self._name + + @property + def ip_address(self): + """Return device ip address.""" + return self._ip_address + + @property + def last_activity(self): + """Return device last activity.""" + return self._last_activity + + +class AsusWrtRouter: + """Representation of a AsusWrt router.""" + + def __init__(self, hass: HomeAssistantType, entry: ConfigEntry) -> None: + """Initialize a AsusWrt router.""" + self.hass = hass + self._entry = entry + + self._api: AsusWrt = None + self._protocol = entry.data[CONF_PROTOCOL] + self._host = entry.data[CONF_HOST] + + self._devices: Dict[str, Any] = {} + self._connect_error = False + + self._on_close = [] + + self._options = { + CONF_DNSMASQ: DEFAULT_DNSMASQ, + CONF_INTERFACE: DEFAULT_INTERFACE, + CONF_REQUIRE_IP: True, + } + self._options.update(entry.options) + + async def setup(self) -> None: + """Set up a AsusWrt router.""" + self._api = get_api(self._entry.data, self._options) + + try: + await self._api.connection.async_connect() + except OSError as exp: + raise ConfigEntryNotReady from exp + + if not self._api.is_connected: + raise ConfigEntryNotReady + + # Load tracked entities from registry + entity_registry = await self.hass.helpers.entity_registry.async_get_registry() + track_entries = ( + self.hass.helpers.entity_registry.async_entries_for_config_entry( + entity_registry, self._entry.entry_id + ) + ) + for entry in track_entries: + if entry.domain == TRACKER_DOMAIN: + self._devices[entry.unique_id] = AsusWrtDevInfo( + entry.unique_id, entry.original_name + ) + + # Update devices + await self.update_devices() + + self.async_on_close( + async_track_time_interval(self.hass, self.update_all, SCAN_INTERVAL) + ) + + async def update_all(self, now: Optional[datetime] = None) -> None: + """Update all AsusWrt platforms.""" + await self.update_devices() + + async def update_devices(self) -> None: + """Update AsusWrt devices tracker.""" + new_device = False + _LOGGER.debug("Checking devices for ASUS router %s", self._host) + try: + wrt_devices = await self._api.async_get_connected_devices() + except OSError as exc: + if not self._connect_error: + self._connect_error = True + _LOGGER.error( + "Error connecting to ASUS router %s for device update: %s", + self._host, + exc, + ) + return + + if self._connect_error: + self._connect_error = False + _LOGGER.info("Reconnected to ASUS router %s", self._host) + + consider_home = self._options.get( + CONF_CONSIDER_HOME, DEFAULT_CONSIDER_HOME.total_seconds() + ) + track_unknown = self._options.get(CONF_TRACK_UNKNOWN, DEFAULT_TRACK_UNKNOWN) + + for device_mac in self._devices: + dev_info = wrt_devices.get(device_mac) + self._devices[device_mac].update(dev_info, consider_home) + + for device_mac, dev_info in wrt_devices.items(): + if device_mac in self._devices: + continue + if not track_unknown and not dev_info.name: + continue + new_device = True + device = AsusWrtDevInfo(device_mac) + device.update(dev_info) + self._devices[device_mac] = device + + async_dispatcher_send(self.hass, self.signal_device_update) + if new_device: + async_dispatcher_send(self.hass, self.signal_device_new) + + async def close(self) -> None: + """Close the connection.""" + if self._api is not None: + if self._protocol == PROTOCOL_TELNET: + await self._api.connection.disconnect() + self._api = None + + for func in self._on_close: + func() + self._on_close.clear() + + @callback + def async_on_close(self, func: CALLBACK_TYPE) -> None: + """Add a function to call when router is closed.""" + self._on_close.append(func) + + def update_options(self, new_options: Dict) -> bool: + """Update router options.""" + req_reload = False + for name, new_opt in new_options.items(): + if name in (CONF_REQ_RELOAD): + old_opt = self._options.get(name) + if not old_opt or old_opt != new_opt: + req_reload = True + break + + self._options.update(new_options) + return req_reload + + @property + def signal_device_new(self) -> str: + """Event specific per AsusWrt entry to signal new device.""" + return f"{DOMAIN}-device-new" + + @property + def signal_device_update(self) -> str: + """Event specific per AsusWrt entry to signal updates in devices.""" + return f"{DOMAIN}-device-update" + + @property + def host(self) -> str: + """Return router hostname.""" + return self._host + + @property + def devices(self) -> Dict[str, Any]: + """Return devices.""" + return self._devices + + @property + def api(self) -> AsusWrt: + """Return router API.""" + return self._api + + +def get_api(conf: Dict, options: Optional[Dict] = None) -> AsusWrt: + """Get the AsusWrt API.""" + opt = options or {} + + return AsusWrt( + conf[CONF_HOST], + conf[CONF_PORT], + conf[CONF_PROTOCOL] == PROTOCOL_TELNET, + conf[CONF_USERNAME], + conf.get(CONF_PASSWORD, ""), + conf.get(CONF_SSH_KEY, ""), + conf[CONF_MODE], + opt.get(CONF_REQUIRE_IP, True), + interface=opt.get(CONF_INTERFACE, DEFAULT_INTERFACE), + dnsmasq=opt.get(CONF_DNSMASQ, DEFAULT_DNSMASQ), + ) diff --git a/homeassistant/components/asuswrt/sensor.py b/homeassistant/components/asuswrt/sensor.py index aa13bee81d0946ec6c52df237de3ca792eb4b58d..2a39d339f069017d49818df090949bbafa448efa 100644 --- a/homeassistant/components/asuswrt/sensor.py +++ b/homeassistant/components/asuswrt/sensor.py @@ -6,13 +6,15 @@ from typing import Any, Dict, List, Optional from aioasuswrt.asuswrt import AsusWrt -from homeassistant.const import DATA_GIGABYTES, DATA_RATE_MEGABITS_PER_SECOND +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import CONF_NAME, DATA_GIGABYTES, DATA_RATE_MEGABITS_PER_SECOND +from homeassistant.helpers.typing import HomeAssistantType from homeassistant.helpers.update_coordinator import ( CoordinatorEntity, DataUpdateCoordinator, ) -from . import DATA_ASUSWRT +from .const import DATA_ASUSWRT, DOMAIN, SENSOR_TYPES UPLOAD_ICON = "mdi:upload-network" DOWNLOAD_ICON = "mdi:download-network" @@ -35,6 +37,8 @@ class _SensorTypes(enum.Enum): return DATA_GIGABYTES if self in (_SensorTypes.UPLOAD_SPEED, _SensorTypes.DOWNLOAD_SPEED): return DATA_RATE_MEGABITS_PER_SECOND + if self == _SensorTypes.DEVICES: + return "devices" return None @property @@ -72,15 +76,26 @@ class _SensorTypes(enum.Enum): return self in (_SensorTypes.UPLOAD, _SensorTypes.DOWNLOAD) -async def async_setup_platform(hass, config, async_add_entities, discovery_info=None): +class _SensorInfo: + """Class handling sensor information.""" + + def __init__(self, sensor_type: _SensorTypes): + """Initialize the handler class.""" + self.type = sensor_type + self.enabled = False + + +async def async_setup_entry( + hass: HomeAssistantType, entry: ConfigEntry, async_add_entities +) -> None: """Set up the asuswrt sensors.""" - if discovery_info is None: - return - api: AsusWrt = hass.data[DATA_ASUSWRT] + router = hass.data[DOMAIN][entry.entry_id][DATA_ASUSWRT] + api: AsusWrt = router.api + device_name = entry.data.get(CONF_NAME, "AsusWRT") # Let's discover the valid sensor types. - sensors = [_SensorTypes(x) for x in discovery_info] + sensors = [_SensorInfo(_SensorTypes(x)) for x in SENSOR_TYPES] data_handler = AsuswrtDataHandler(sensors, api) coordinator = DataUpdateCoordinator( @@ -93,34 +108,50 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info= ) await coordinator.async_refresh() - async_add_entities([AsuswrtSensor(coordinator, x) for x in sensors]) + async_add_entities( + [AsuswrtSensor(coordinator, data_handler, device_name, x.type) for x in sensors] + ) class AsuswrtDataHandler: """Class handling the API updates.""" - def __init__(self, sensors: List[_SensorTypes], api: AsusWrt): + def __init__(self, sensors: List[_SensorInfo], api: AsusWrt): """Initialize the handler class.""" self._api = api self._sensors = sensors self._connected = True + def enable_sensor(self, sensor_type: _SensorTypes): + """Enable a specific sensor type.""" + for index, sensor in enumerate(self._sensors): + if sensor.type == sensor_type: + self._sensors[index].enabled = True + return + + def disable_sensor(self, sensor_type: _SensorTypes): + """Disable a specific sensor type.""" + for index, sensor in enumerate(self._sensors): + if sensor.type == sensor_type: + self._sensors[index].enabled = False + return + async def update_data(self) -> Dict[_SensorTypes, Any]: """Fetch the relevant data from the router.""" ret_dict: Dict[_SensorTypes, Any] = {} try: - if _SensorTypes.DEVICES in self._sensors: + if _SensorTypes.DEVICES in [x.type for x in self._sensors if x.enabled]: # Let's check the nr of devices. devices = await self._api.async_get_connected_devices() ret_dict[_SensorTypes.DEVICES] = len(devices) - if any(x.is_speed for x in self._sensors): + if any(x.type.is_speed for x in self._sensors if x.enabled): # Let's check the upload and download speed speed = await self._api.async_get_current_transfer_rates() ret_dict[_SensorTypes.DOWNLOAD_SPEED] = round(speed[0] / 125000, 2) ret_dict[_SensorTypes.UPLOAD_SPEED] = round(speed[1] / 125000, 2) - if any(x.is_size for x in self._sensors): + if any(x.type.is_size for x in self._sensors if x.enabled): rates = await self._api.async_get_bytes_total() ret_dict[_SensorTypes.DOWNLOAD] = round(rates[0] / 1000000000, 1) ret_dict[_SensorTypes.UPLOAD] = round(rates[1] / 1000000000, 1) @@ -142,9 +173,17 @@ class AsuswrtDataHandler: class AsuswrtSensor(CoordinatorEntity): """The asuswrt specific sensor class.""" - def __init__(self, coordinator: DataUpdateCoordinator, sensor_type: _SensorTypes): + def __init__( + self, + coordinator: DataUpdateCoordinator, + data_handler: AsuswrtDataHandler, + device_name: str, + sensor_type: _SensorTypes, + ): """Initialize the sensor class.""" super().__init__(coordinator) + self._handler = data_handler + self._device_name = device_name self._type = sensor_type @property @@ -164,5 +203,34 @@ class AsuswrtSensor(CoordinatorEntity): @property def unit_of_measurement(self) -> Optional[str]: - """Return the unit of measurement of this entity, if any.""" + """Return the unit.""" return self._type.unit_of_measurement + + @property + def unique_id(self) -> str: + """Return the unique_id of the sensor.""" + return f"{DOMAIN} {self._type.sensor_name}" + + @property + def device_info(self) -> Dict[str, any]: + """Return the device information.""" + return { + "identifiers": {(DOMAIN, "AsusWRT")}, + "name": self._device_name, + "model": "Asus Router", + "manufacturer": "Asus", + } + + @property + def entity_registry_enabled_default(self) -> bool: + """Return if the entity should be enabled when first added to the entity registry.""" + return False + + async def async_added_to_hass(self) -> None: + """When entity is added to hass.""" + self._handler.enable_sensor(self._type) + await super().async_added_to_hass() + + async def async_will_remove_from_hass(self): + """Call when entity is removed from hass.""" + self._handler.disable_sensor(self._type) diff --git a/homeassistant/components/asuswrt/strings.json b/homeassistant/components/asuswrt/strings.json new file mode 100644 index 0000000000000000000000000000000000000000..079ee35bf95a9882ea2e68b933e3f71147aa4479 --- /dev/null +++ b/homeassistant/components/asuswrt/strings.json @@ -0,0 +1,45 @@ +{ + "config": { + "step": { + "user": { + "title": "AsusWRT", + "description": "Set required parameter to connect to your router", + "data": { + "host": "[%key:common::config_flow::data::host%]", + "name": "[%key:common::config_flow::data::name%]", + "username": "[%key:common::config_flow::data::username%]", + "password": "[%key:common::config_flow::data::password%]", + "ssh_key": "Path to your SSH key file (instead of password)", + "protocol": "Communication protocol to use", + "port": "[%key:common::config_flow::data::port%]", + "mode": "[%key:common::config_flow::data::mode%]" + } + } + }, + "error": { + "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", + "invalid_host": "[%key:common::config_flow::error::invalid_host%]", + "pwd_and_ssh": "Only provide password or SSH key file", + "pwd_or_ssh": "Please provide password or SSH key file", + "ssh_not_file": "SSH key file not found", + "unknown": "[%key:common::config_flow::error::unknown%]" + }, + "abort": { + "single_instance_allowed": "[%key:common::config_flow::abort::single_instance_allowed%]" + } + }, + "options": { + "step": { + "init": { + "title": "AsusWRT Options", + "data": { + "consider_home": "Seconds to wait before considering a device away", + "track_unknown": "Track unknown / unamed devices", + "interface": "The interface that you want statistics from (e.g. eth0,eth1 etc)", + "dnsmasq": "The location in the router of the dnsmasq.leases files", + "require_ip": "Devices must have IP (for access point mode)" + } + } + } + } +} diff --git a/homeassistant/components/asuswrt/translations/en.json b/homeassistant/components/asuswrt/translations/en.json new file mode 100644 index 0000000000000000000000000000000000000000..5ac87e277f4e660f4c480bdbfada19606027a60d --- /dev/null +++ b/homeassistant/components/asuswrt/translations/en.json @@ -0,0 +1,45 @@ +{ + "config": { + "abort": { + "single_instance_allowed": "Already configured. Only a single configuration possible." + }, + "error": { + "cannot_connect": "Failed to connect", + "invalid_host": "Invalid hostname or IP address", + "pwd_and_ssh": "Only provide password or SSH key file", + "pwd_or_ssh": "Please provide password or SSH key file", + "ssh_not_file": "SSH key file not found", + "unknown": "Unexpected error" + }, + "step": { + "user": { + "data": { + "host": "Host", + "mode": "Mode", + "name": "Name", + "password": "Password", + "port": "Port", + "protocol": "Communication protocol to use", + "ssh_key": "Path to your SSH key file (instead of password)", + "username": "Username" + }, + "description": "Set required parameter to connect to your router", + "title": "AsusWRT" + } + } + }, + "options": { + "step": { + "init": { + "data": { + "consider_home": "Seconds to wait before considering a device away", + "dnsmasq": "The location in the router of the dnsmasq.leases files", + "interface": "The interface that you want statistics from (e.g. eth0,eth1 etc)", + "require_ip": "Devices must have IP (for access point mode)", + "track_unknown": "Track unknown / unamed devices" + }, + "title": "AsusWRT Options" + } + } + } +} \ No newline at end of file diff --git a/homeassistant/generated/config_flows.py b/homeassistant/generated/config_flows.py index 06e2516633e5ab68adf407418eac831d4061112e..f5f550b3073e38a1f4bc4531c95a6711deba2054 100644 --- a/homeassistant/generated/config_flows.py +++ b/homeassistant/generated/config_flows.py @@ -21,6 +21,7 @@ FLOWS = [ "ambient_station", "apple_tv", "arcam_fmj", + "asuswrt", "atag", "august", "aurora", diff --git a/tests/components/asuswrt/test_config_flow.py b/tests/components/asuswrt/test_config_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..7faec5d336cce91fb380365737c41dbcaef986d9 --- /dev/null +++ b/tests/components/asuswrt/test_config_flow.py @@ -0,0 +1,296 @@ +"""Tests for the AsusWrt config flow.""" +from socket import gaierror +from unittest.mock import AsyncMock, patch + +import pytest + +from homeassistant import data_entry_flow +from homeassistant.components.asuswrt.const import ( + CONF_DNSMASQ, + CONF_INTERFACE, + CONF_REQUIRE_IP, + CONF_SSH_KEY, + CONF_TRACK_UNKNOWN, + DOMAIN, +) +from homeassistant.components.device_tracker.const import CONF_CONSIDER_HOME +from homeassistant.config_entries import SOURCE_IMPORT, SOURCE_USER +from homeassistant.const import ( + CONF_HOST, + CONF_MODE, + CONF_PASSWORD, + CONF_PORT, + CONF_PROTOCOL, + CONF_USERNAME, +) + +from tests.common import MockConfigEntry + +HOST = "myrouter.asuswrt.com" +IP_ADDRESS = "192.168.1.1" +SSH_KEY = "1234" + +CONFIG_DATA = { + CONF_HOST: HOST, + CONF_PORT: 22, + CONF_PROTOCOL: "telnet", + CONF_USERNAME: "user", + CONF_PASSWORD: "pwd", + CONF_MODE: "ap", +} + + +@pytest.fixture(name="connect") +def mock_controller_connect(): + """Mock a successful connection.""" + with patch("homeassistant.components.asuswrt.router.AsusWrt") as service_mock: + service_mock.return_value.connection.async_connect = AsyncMock() + service_mock.return_value.is_connected = True + service_mock.return_value.connection.disconnect = AsyncMock() + yield service_mock + + +async def test_user(hass, connect): + """Test user config.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_USER} + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["step_id"] == "user" + + # test with all provided + with patch( + "homeassistant.components.asuswrt.async_setup_entry", + return_value=True, + ) as mock_setup_entry, patch( + "homeassistant.components.asuswrt.config_flow.socket.gethostbyname", + return_value=IP_ADDRESS, + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_USER}, + data=CONFIG_DATA, + ) + await hass.async_block_till_done() + + assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + assert result["title"] == HOST + assert result["data"] == CONFIG_DATA + + assert len(mock_setup_entry.mock_calls) == 1 + + +async def test_import(hass, connect): + """Test import step.""" + with patch( + "homeassistant.components.asuswrt.async_setup_entry", + return_value=True, + ) as mock_setup_entry, patch( + "homeassistant.components.asuswrt.config_flow.socket.gethostbyname", + return_value=IP_ADDRESS, + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_IMPORT}, + data=CONFIG_DATA, + ) + await hass.async_block_till_done() + + assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + assert result["title"] == HOST + assert result["data"] == CONFIG_DATA + + assert len(mock_setup_entry.mock_calls) == 1 + + +async def test_import_ssh(hass, connect): + """Test import step with ssh file.""" + config_data = CONFIG_DATA.copy() + config_data.pop(CONF_PASSWORD) + config_data[CONF_SSH_KEY] = SSH_KEY + + with patch( + "homeassistant.components.asuswrt.async_setup_entry", + return_value=True, + ) as mock_setup_entry, patch( + "homeassistant.components.asuswrt.config_flow.socket.gethostbyname", + return_value=IP_ADDRESS, + ), patch( + "homeassistant.components.asuswrt.config_flow.os.path.isfile", + return_value=True, + ), patch( + "homeassistant.components.asuswrt.config_flow.os.access", + return_value=True, + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_IMPORT}, + data=config_data, + ) + await hass.async_block_till_done() + + assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + assert result["title"] == HOST + assert result["data"] == config_data + + assert len(mock_setup_entry.mock_calls) == 1 + + +async def test_error_no_password_ssh(hass): + """Test we abort if component is already setup.""" + config_data = CONFIG_DATA.copy() + config_data.pop(CONF_PASSWORD) + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_USER}, + data=config_data, + ) + + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["errors"] == {"base": "pwd_or_ssh"} + + +async def test_error_both_password_ssh(hass): + """Test we abort if component is already setup.""" + config_data = CONFIG_DATA.copy() + config_data[CONF_SSH_KEY] = SSH_KEY + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_USER}, + data=config_data, + ) + + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["errors"] == {"base": "pwd_and_ssh"} + + +async def test_error_invalid_ssh(hass): + """Test we abort if component is already setup.""" + config_data = CONFIG_DATA.copy() + config_data.pop(CONF_PASSWORD) + config_data[CONF_SSH_KEY] = SSH_KEY + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_USER}, + data=config_data, + ) + + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["errors"] == {"base": "ssh_not_file"} + + +async def test_error_invalid_host(hass): + """Test we abort if host name is invalid.""" + with patch( + "homeassistant.components.asuswrt.config_flow.socket.gethostbyname", + side_effect=gaierror, + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_USER}, + data=CONFIG_DATA, + ) + + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["errors"] == {"base": "invalid_host"} + + +async def test_abort_if_already_setup(hass): + """Test we abort if component is already setup.""" + MockConfigEntry( + domain=DOMAIN, + data=CONFIG_DATA, + ).add_to_hass(hass) + + with patch( + "homeassistant.components.asuswrt.config_flow.socket.gethostbyname", + return_value=IP_ADDRESS, + ): + # Should fail, same HOST (flow) + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_USER}, + data=CONFIG_DATA, + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT + assert result["reason"] == "single_instance_allowed" + + # Should fail, same HOST (import) + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_IMPORT}, + data=CONFIG_DATA, + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT + assert result["reason"] == "single_instance_allowed" + + +async def test_on_connect_failed(hass): + """Test when we have errors connecting the router.""" + flow_result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_USER}, + ) + + with patch("homeassistant.components.asuswrt.router.AsusWrt") as asus_wrt: + asus_wrt.return_value.connection.async_connect = AsyncMock() + asus_wrt.return_value.is_connected = False + result = await hass.config_entries.flow.async_configure( + flow_result["flow_id"], user_input=CONFIG_DATA + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["errors"] == {"base": "cannot_connect"} + + with patch("homeassistant.components.asuswrt.router.AsusWrt") as asus_wrt: + asus_wrt.return_value.connection.async_connect = AsyncMock(side_effect=OSError) + result = await hass.config_entries.flow.async_configure( + flow_result["flow_id"], user_input=CONFIG_DATA + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["errors"] == {"base": "cannot_connect"} + + with patch("homeassistant.components.asuswrt.router.AsusWrt") as asus_wrt: + asus_wrt.return_value.connection.async_connect = AsyncMock( + side_effect=TypeError + ) + result = await hass.config_entries.flow.async_configure( + flow_result["flow_id"], user_input=CONFIG_DATA + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["errors"] == {"base": "unknown"} + + +async def test_options_flow(hass): + """Test config flow options.""" + config_entry = MockConfigEntry( + domain=DOMAIN, + data=CONFIG_DATA, + options={CONF_REQUIRE_IP: True}, + ) + config_entry.add_to_hass(hass) + + with patch("homeassistant.components.asuswrt.async_setup_entry", return_value=True): + await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + result = await hass.config_entries.options.async_init(config_entry.entry_id) + + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["step_id"] == "init" + + result = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input={ + CONF_CONSIDER_HOME: 20, + CONF_TRACK_UNKNOWN: True, + CONF_INTERFACE: "aaa", + CONF_DNSMASQ: "bbb", + CONF_REQUIRE_IP: False, + }, + ) + + assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + assert config_entry.options[CONF_CONSIDER_HOME] == 20 + assert config_entry.options[CONF_TRACK_UNKNOWN] is True + assert config_entry.options[CONF_INTERFACE] == "aaa" + assert config_entry.options[CONF_DNSMASQ] == "bbb" + assert config_entry.options[CONF_REQUIRE_IP] is False diff --git a/tests/components/asuswrt/test_device_tracker.py b/tests/components/asuswrt/test_device_tracker.py deleted file mode 100644 index 941b0c340d6177bd44eb73193bc4c116230bc611..0000000000000000000000000000000000000000 --- a/tests/components/asuswrt/test_device_tracker.py +++ /dev/null @@ -1,119 +0,0 @@ -"""The tests for the ASUSWRT device tracker platform.""" - -from unittest.mock import AsyncMock, patch - -from homeassistant.components.asuswrt import ( - CONF_DNSMASQ, - CONF_INTERFACE, - DATA_ASUSWRT, - DOMAIN, -) -from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME -from homeassistant.setup import async_setup_component - - -async def test_password_or_pub_key_required(hass): - """Test creating an AsusWRT scanner without a pass or pubkey.""" - with patch("homeassistant.components.asuswrt.AsusWrt") as AsusWrt: - AsusWrt().connection.async_connect = AsyncMock() - AsusWrt().is_connected = False - result = await async_setup_component( - hass, DOMAIN, {DOMAIN: {CONF_HOST: "fake_host", CONF_USERNAME: "fake_user"}} - ) - assert not result - - -async def test_network_unreachable(hass): - """Test creating an AsusWRT scanner without a pass or pubkey.""" - with patch("homeassistant.components.asuswrt.AsusWrt") as AsusWrt: - AsusWrt().connection.async_connect = AsyncMock(side_effect=OSError) - AsusWrt().is_connected = False - result = await async_setup_component( - hass, DOMAIN, {DOMAIN: {CONF_HOST: "fake_host", CONF_USERNAME: "fake_user"}} - ) - assert result - assert hass.data.get(DATA_ASUSWRT) is None - - -async def test_get_scanner_with_password_no_pubkey(hass): - """Test creating an AsusWRT scanner with a password and no pubkey.""" - with patch("homeassistant.components.asuswrt.AsusWrt") as AsusWrt: - AsusWrt().connection.async_connect = AsyncMock() - AsusWrt().connection.async_get_connected_devices = AsyncMock(return_value={}) - result = await async_setup_component( - hass, - DOMAIN, - { - DOMAIN: { - CONF_HOST: "fake_host", - CONF_USERNAME: "fake_user", - CONF_PASSWORD: "4321", - CONF_DNSMASQ: "/", - } - }, - ) - assert result - assert hass.data[DATA_ASUSWRT] is not None - - -async def test_specify_non_directory_path_for_dnsmasq(hass): - """Test creating an AsusWRT scanner with a dnsmasq location which is not a valid directory.""" - with patch("homeassistant.components.asuswrt.AsusWrt") as AsusWrt: - AsusWrt().connection.async_connect = AsyncMock() - AsusWrt().is_connected = False - result = await async_setup_component( - hass, - DOMAIN, - { - DOMAIN: { - CONF_HOST: "fake_host", - CONF_USERNAME: "fake_user", - CONF_PASSWORD: "4321", - CONF_DNSMASQ: 1234, - } - }, - ) - assert not result - - -async def test_interface(hass): - """Test creating an AsusWRT scanner using interface eth1.""" - with patch("homeassistant.components.asuswrt.AsusWrt") as AsusWrt: - AsusWrt().connection.async_connect = AsyncMock() - AsusWrt().connection.async_get_connected_devices = AsyncMock(return_value={}) - result = await async_setup_component( - hass, - DOMAIN, - { - DOMAIN: { - CONF_HOST: "fake_host", - CONF_USERNAME: "fake_user", - CONF_PASSWORD: "4321", - CONF_DNSMASQ: "/", - CONF_INTERFACE: "eth1", - } - }, - ) - assert result - assert hass.data[DATA_ASUSWRT] is not None - - -async def test_no_interface(hass): - """Test creating an AsusWRT scanner using no interface.""" - with patch("homeassistant.components.asuswrt.AsusWrt") as AsusWrt: - AsusWrt().connection.async_connect = AsyncMock() - AsusWrt().is_connected = False - result = await async_setup_component( - hass, - DOMAIN, - { - DOMAIN: { - CONF_HOST: "fake_host", - CONF_USERNAME: "fake_user", - CONF_PASSWORD: "4321", - CONF_DNSMASQ: "/", - CONF_INTERFACE: None, - } - }, - ) - assert not result diff --git a/tests/components/asuswrt/test_sensor.py b/tests/components/asuswrt/test_sensor.py index 69c70c409d5d41d50964ed0e10f769ee9ab0587f..994111370fda886f2503a200c8e65ee45df08464 100644 --- a/tests/components/asuswrt/test_sensor.py +++ b/tests/components/asuswrt/test_sensor.py @@ -1,71 +1,150 @@ -"""The tests for the AsusWrt sensor platform.""" - +"""Tests for the AsusWrt sensor.""" +from datetime import timedelta from unittest.mock import AsyncMock, patch from aioasuswrt.asuswrt import Device +import pytest -from homeassistant.components import sensor -from homeassistant.components.asuswrt import ( - CONF_DNSMASQ, - CONF_INTERFACE, +from homeassistant.components import device_tracker, sensor +from homeassistant.components.asuswrt.const import DOMAIN +from homeassistant.components.asuswrt.sensor import _SensorTypes +from homeassistant.components.device_tracker.const import CONF_CONSIDER_HOME +from homeassistant.const import ( + CONF_HOST, CONF_MODE, + CONF_PASSWORD, CONF_PORT, CONF_PROTOCOL, - CONF_SENSORS, - DOMAIN, + CONF_USERNAME, + STATE_HOME, + STATE_NOT_HOME, ) -from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME -from homeassistant.core import HomeAssistant -from homeassistant.setup import async_setup_component - -VALID_CONFIG_ROUTER_SSH = { - DOMAIN: { - CONF_DNSMASQ: "/", - CONF_HOST: "fake_host", - CONF_INTERFACE: "eth0", - CONF_MODE: "router", - CONF_PORT: "22", - CONF_PROTOCOL: "ssh", - CONF_USERNAME: "fake_user", - CONF_PASSWORD: "fake_pass", - CONF_SENSORS: [ - "devices", - "download_speed", - "download", - "upload_speed", - "upload", - ], - } +from homeassistant.util.dt import utcnow + +from tests.common import MockConfigEntry, async_fire_time_changed + +HOST = "myrouter.asuswrt.com" +IP_ADDRESS = "192.168.1.1" + +CONFIG_DATA = { + CONF_HOST: HOST, + CONF_PORT: 22, + CONF_PROTOCOL: "ssh", + CONF_USERNAME: "user", + CONF_PASSWORD: "pwd", + CONF_MODE: "router", } MOCK_DEVICES = { "a1:b1:c1:d1:e1:f1": Device("a1:b1:c1:d1:e1:f1", "192.168.1.2", "Test"), "a2:b2:c2:d2:e2:f2": Device("a2:b2:c2:d2:e2:f2", "192.168.1.3", "TestTwo"), - "a3:b3:c3:d3:e3:f3": Device("a3:b3:c3:d3:e3:f3", "192.168.1.4", "TestThree"), } MOCK_BYTES_TOTAL = [60000000000, 50000000000] MOCK_CURRENT_TRANSFER_RATES = [20000000, 10000000] -async def test_sensors(hass: HomeAssistant, mock_device_tracker_conf): - """Test creating an AsusWRT sensor.""" - with patch("homeassistant.components.asuswrt.AsusWrt") as AsusWrt: - AsusWrt().connection.async_connect = AsyncMock() - AsusWrt().async_get_connected_devices = AsyncMock(return_value=MOCK_DEVICES) - AsusWrt().async_get_bytes_total = AsyncMock(return_value=MOCK_BYTES_TOTAL) - AsusWrt().async_get_current_transfer_rates = AsyncMock( +@pytest.fixture(name="connect") +def mock_controller_connect(): + """Mock a successful connection.""" + with patch("homeassistant.components.asuswrt.router.AsusWrt") as service_mock: + service_mock.return_value.connection.async_connect = AsyncMock() + service_mock.return_value.is_connected = True + service_mock.return_value.connection.disconnect = AsyncMock() + service_mock.return_value.async_get_connected_devices = AsyncMock( + return_value=MOCK_DEVICES + ) + service_mock.return_value.async_get_bytes_total = AsyncMock( + return_value=MOCK_BYTES_TOTAL + ) + service_mock.return_value.async_get_current_transfer_rates = AsyncMock( return_value=MOCK_CURRENT_TRANSFER_RATES ) + yield service_mock - assert await async_setup_component(hass, DOMAIN, VALID_CONFIG_ROUTER_SSH) - await hass.async_block_till_done() - assert ( - hass.states.get(f"{sensor.DOMAIN}.asuswrt_devices_connected").state == "3" - ) - assert ( - hass.states.get(f"{sensor.DOMAIN}.asuswrt_download_speed").state == "160.0" - ) - assert hass.states.get(f"{sensor.DOMAIN}.asuswrt_download").state == "60.0" - assert hass.states.get(f"{sensor.DOMAIN}.asuswrt_upload_speed").state == "80.0" - assert hass.states.get(f"{sensor.DOMAIN}.asuswrt_upload").state == "50.0" +async def test_sensors(hass, connect): + """Test creating an AsusWRT sensor.""" + entity_reg = await hass.helpers.entity_registry.async_get_registry() + + # Pre-enable the status sensor + entity_reg.async_get_or_create( + sensor.DOMAIN, + DOMAIN, + f"{DOMAIN} {_SensorTypes(_SensorTypes.DEVICES).sensor_name}", + suggested_object_id="asuswrt_connected_devices", + disabled_by=None, + ) + entity_reg.async_get_or_create( + sensor.DOMAIN, + DOMAIN, + f"{DOMAIN} {_SensorTypes(_SensorTypes.DOWNLOAD_SPEED).sensor_name}", + suggested_object_id="asuswrt_download_speed", + disabled_by=None, + ) + entity_reg.async_get_or_create( + sensor.DOMAIN, + DOMAIN, + f"{DOMAIN} {_SensorTypes(_SensorTypes.DOWNLOAD).sensor_name}", + suggested_object_id="asuswrt_download", + disabled_by=None, + ) + entity_reg.async_get_or_create( + sensor.DOMAIN, + DOMAIN, + f"{DOMAIN} {_SensorTypes(_SensorTypes.UPLOAD_SPEED).sensor_name}", + suggested_object_id="asuswrt_upload_speed", + disabled_by=None, + ) + entity_reg.async_get_or_create( + sensor.DOMAIN, + DOMAIN, + f"{DOMAIN} {_SensorTypes(_SensorTypes.UPLOAD).sensor_name}", + suggested_object_id="asuswrt_upload", + disabled_by=None, + ) + + # init config entry + config_entry = MockConfigEntry( + domain=DOMAIN, + data=CONFIG_DATA, + options={CONF_CONSIDER_HOME: 60}, + ) + config_entry.add_to_hass(hass) + + # initial devices setup + assert await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + async_fire_time_changed(hass, utcnow() + timedelta(seconds=30)) + await hass.async_block_till_done() + + assert hass.states.get(f"{device_tracker.DOMAIN}.test").state == STATE_HOME + assert hass.states.get(f"{device_tracker.DOMAIN}.testtwo").state == STATE_HOME + assert hass.states.get(f"{sensor.DOMAIN}.asuswrt_connected_devices").state == "2" + assert hass.states.get(f"{sensor.DOMAIN}.asuswrt_download_speed").state == "160.0" + assert hass.states.get(f"{sensor.DOMAIN}.asuswrt_download").state == "60.0" + assert hass.states.get(f"{sensor.DOMAIN}.asuswrt_upload_speed").state == "80.0" + assert hass.states.get(f"{sensor.DOMAIN}.asuswrt_upload").state == "50.0" + + # add one device and remove another + MOCK_DEVICES.pop("a1:b1:c1:d1:e1:f1") + MOCK_DEVICES["a3:b3:c3:d3:e3:f3"] = Device( + "a3:b3:c3:d3:e3:f3", "192.168.1.4", "TestThree" + ) + async_fire_time_changed(hass, utcnow() + timedelta(seconds=30)) + await hass.async_block_till_done() + + # consider home option set, all devices still home + assert hass.states.get(f"{device_tracker.DOMAIN}.test").state == STATE_HOME + assert hass.states.get(f"{device_tracker.DOMAIN}.testtwo").state == STATE_HOME + assert hass.states.get(f"{device_tracker.DOMAIN}.testthree").state == STATE_HOME + assert hass.states.get(f"{sensor.DOMAIN}.asuswrt_connected_devices").state == "2" + + hass.config_entries.async_update_entry( + config_entry, options={CONF_CONSIDER_HOME: 0} + ) + await hass.async_block_till_done() + async_fire_time_changed(hass, utcnow() + timedelta(seconds=30)) + await hass.async_block_till_done() + + # consider home option not set, device "test" not home + assert hass.states.get(f"{device_tracker.DOMAIN}.test").state == STATE_NOT_HOME