From 6d0ce814e79b72b13e8f2eff371e926ddba155ee Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" <nick@koston.org> Date: Wed, 18 Aug 2021 12:33:26 -0500 Subject: [PATCH] Add new network apis to reduce code duplication (#54832) --- homeassistant/components/network/__init__.py | 32 +++++++++++++++- homeassistant/components/ssdp/__init__.py | 30 ++++----------- homeassistant/components/zeroconf/__init__.py | 37 +++++-------------- 3 files changed, 48 insertions(+), 51 deletions(-) diff --git a/homeassistant/components/network/__init__.py b/homeassistant/components/network/__init__.py index 48903d145e7..a7dffad7084 100644 --- a/homeassistant/components/network/__init__.py +++ b/homeassistant/components/network/__init__.py @@ -1,13 +1,14 @@ """The Network Configuration integration.""" from __future__ import annotations +from ipaddress import IPv4Address, IPv6Address import logging import voluptuous as vol from homeassistant.components import websocket_api from homeassistant.components.websocket_api.connection import ActiveConnection -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.typing import ConfigType from homeassistant.loader import bind_hass @@ -45,6 +46,35 @@ async def async_get_source_ip(hass: HomeAssistant, target_ip: str) -> str: return source_ip if source_ip in all_ipv4s else all_ipv4s[0] +@bind_hass +async def async_get_enabled_source_ips( + hass: HomeAssistant, +) -> list[IPv4Address | IPv6Address]: + """Build the list of enabled source ips.""" + adapters = await async_get_adapters(hass) + sources: list[IPv4Address | IPv6Address] = [] + for adapter in adapters: + if not adapter["enabled"]: + continue + if adapter["ipv4"]: + sources.extend(IPv4Address(ipv4["address"]) for ipv4 in adapter["ipv4"]) + if adapter["ipv6"]: + # With python 3.9 add scope_ids can be + # added by enumerating adapter["ipv6"]s + # IPv6Address(f"::%{ipv6['scope_id']}") + sources.extend(IPv6Address(ipv6["address"]) for ipv6 in adapter["ipv6"]) + + return sources + + +@callback +def async_only_default_interface_enabled(adapters: list[Adapter]) -> bool: + """Check to see if any non-default adapter is enabled.""" + return not any( + adapter["enabled"] and not adapter["default"] for adapter in adapters + ) + + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up network for Home Assistant.""" diff --git a/homeassistant/components/ssdp/__init__.py b/homeassistant/components/ssdp/__init__.py index 4d21fdb6aab..1fd2bba77cc 100644 --- a/homeassistant/components/ssdp/__init__.py +++ b/homeassistant/components/ssdp/__init__.py @@ -116,14 +116,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return True -@core_callback -def _async_use_default_interface(adapters: list[network.Adapter]) -> bool: - for adapter in adapters: - if adapter["enabled"] and not adapter["default"]: - return False - return True - - @core_callback def _async_process_callbacks( callbacks: list[Callable[[dict], None]], discovery_info: dict[str, str] @@ -204,24 +196,16 @@ class Scanner: """Build the list of ssdp sources.""" adapters = await network.async_get_adapters(self.hass) sources: set[IPv4Address | IPv6Address] = set() - if _async_use_default_interface(adapters): + if network.async_only_default_interface_enabled(adapters): sources.add(IPv4Address("0.0.0.0")) return sources - for adapter in adapters: - if not adapter["enabled"]: - continue - if adapter["ipv4"]: - ipv4 = adapter["ipv4"][0] - sources.add(IPv4Address(ipv4["address"])) - if adapter["ipv6"]: - ipv6 = adapter["ipv6"][0] - # With python 3.9 add scope_ids can be - # added by enumerating adapter["ipv6"]s - # IPv6Address(f"::%{ipv6['scope_id']}") - sources.add(IPv6Address(ipv6["address"])) - - return sources + return { + source_ip + for source_ip in await network.async_get_enabled_source_ips(self.hass) + if not source_ip.is_loopback + and not (isinstance(source_ip, IPv6Address) and source_ip.is_global) + } async def async_scan(self, *_: Any) -> None: """Scan for new entries using ssdp default and broadcast target.""" diff --git a/homeassistant/components/zeroconf/__init__.py b/homeassistant/components/zeroconf/__init__.py index 6829c9c5e17..8b1f482e05e 100644 --- a/homeassistant/components/zeroconf/__init__.py +++ b/homeassistant/components/zeroconf/__init__.py @@ -5,7 +5,7 @@ import asyncio from collections.abc import Coroutine from contextlib import suppress import fnmatch -import ipaddress +from ipaddress import IPv6Address, ip_address import logging import socket from typing import Any, TypedDict, cast @@ -131,13 +131,6 @@ async def _async_get_instance(hass: HomeAssistant, **zcargs: Any) -> HaAsyncZero return aio_zc -def _async_use_default_interface(adapters: list[Adapter]) -> bool: - for adapter in adapters: - if adapter["enabled"] and not adapter["default"]: - return False - return True - - async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up Zeroconf and make Home Assistant discoverable.""" zc_args: dict = {} @@ -151,25 +144,15 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: else: zc_args["ip_version"] = IPVersion.All - if not ipv6 and _async_use_default_interface(adapters): + if not ipv6 and network.async_only_default_interface_enabled(adapters): zc_args["interfaces"] = InterfaceChoice.Default else: - interfaces = zc_args["interfaces"] = [] - for adapter in adapters: - if not adapter["enabled"]: - continue - if ipv4s := adapter["ipv4"]: - interfaces.extend( - ipv4["address"] - for ipv4 in ipv4s - if not ipaddress.IPv4Address(ipv4["address"]).is_loopback - ) - if ipv6s := adapter["ipv6"]: - for ipv6_addr in ipv6s: - address = ipv6_addr["address"] - v6_ip_address = ipaddress.IPv6Address(address) - if not v6_ip_address.is_global and not v6_ip_address.is_loopback: - interfaces.append(ipv6_addr["address"]) + zc_args["interfaces"] = [ + str(source_ip) + for source_ip in await network.async_get_enabled_source_ips(hass) + if not source_ip.is_loopback + and not (isinstance(source_ip, IPv6Address) and source_ip.is_global) + ] aio_zc = await _async_get_instance(hass, **zc_args) zeroconf = cast(HaZeroconf, aio_zc.zeroconf) @@ -213,7 +196,7 @@ def _get_announced_addresses( addresses = { addr.packed for addr in [ - ipaddress.ip_address(ip["address"]) + ip_address(ip["address"]) for adapter in adapters if adapter["enabled"] for ip in cast(list, adapter["ipv6"]) + cast(list, adapter["ipv4"]) @@ -530,7 +513,7 @@ def info_from_service(service: AsyncServiceInfo) -> HaServiceInfo | None: address = service.addresses[0] return { - "host": str(ipaddress.ip_address(address)), + "host": str(ip_address(address)), "port": service.port, "hostname": service.server, "type": service.type, -- GitLab