From 868eb3c735c2136cf81a5f18a19cd08a7d0522a0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ville=20Skytt=C3=A4?= <ville.skytta@iki.fi>
Date: Sun, 22 Dec 2019 20:51:39 +0200
Subject: [PATCH] More helpers type improvements (#30145)

---
 homeassistant/helpers/check_config.py      |  26 +++--
 homeassistant/helpers/config_validation.py | 115 ++++++++++++---------
 homeassistant/helpers/device_registry.py   |   8 +-
 homeassistant/helpers/entity_registry.py   |  42 ++++----
 homeassistant/helpers/logging.py           |  23 +++--
 5 files changed, 124 insertions(+), 90 deletions(-)

diff --git a/homeassistant/helpers/check_config.py b/homeassistant/helpers/check_config.py
index 81e654247b7..1b1e136ed89 100644
--- a/homeassistant/helpers/check_config.py
+++ b/homeassistant/helpers/check_config.py
@@ -1,6 +1,6 @@
 """Helper to check the configuration file."""
-from collections import OrderedDict, namedtuple
-from typing import List
+from collections import OrderedDict
+from typing import List, NamedTuple, Optional
 
 import attr
 import voluptuous as vol
@@ -19,15 +19,20 @@ from homeassistant.config import (
 )
 from homeassistant.core import HomeAssistant
 from homeassistant.exceptions import HomeAssistantError
+from homeassistant.helpers.typing import ConfigType
 from homeassistant.requirements import (
     RequirementsNotFound,
     async_get_integration_with_requirements,
 )
 import homeassistant.util.yaml.loader as yaml_loader
 
-# mypy: allow-untyped-calls, allow-untyped-defs, no-warn-return-any
 
-CheckConfigError = namedtuple("CheckConfigError", "message domain config")
+class CheckConfigError(NamedTuple):
+    """Configuration check error."""
+
+    message: str
+    domain: Optional[str]
+    config: Optional[ConfigType]
 
 
 @attr.s
@@ -36,7 +41,12 @@ class HomeAssistantConfig(OrderedDict):
 
     errors: List[CheckConfigError] = attr.ib(default=attr.Factory(list))
 
-    def add_error(self, message, domain=None, config=None):
+    def add_error(
+        self,
+        message: str,
+        domain: Optional[str] = None,
+        config: Optional[ConfigType] = None,
+    ) -> "HomeAssistantConfig":
         """Add a single error."""
         self.errors.append(CheckConfigError(str(message), domain, config))
         return self
@@ -55,7 +65,9 @@ async def async_check_ha_config_file(hass: HomeAssistant) -> HomeAssistantConfig
     config_dir = hass.config.config_dir
     result = HomeAssistantConfig()
 
-    def _pack_error(package, component, config, message):
+    def _pack_error(
+        package: str, component: str, config: ConfigType, message: str
+    ) -> None:
         """Handle errors from packages: _log_pkg_error."""
         message = "Package {} setup failed. Component {} {}".format(
             package, component, message
@@ -64,7 +76,7 @@ async def async_check_ha_config_file(hass: HomeAssistant) -> HomeAssistantConfig
         pack_config = core_config[CONF_PACKAGES].get(package, config)
         result.add_error(message, domain, pack_config)
 
-    def _comp_error(ex, domain, config):
+    def _comp_error(ex: Exception, domain: str, config: ConfigType) -> None:
         """Handle errors from components: async_log_exception."""
         result.add_error(_format_config_error(ex, domain, config), domain, config)
 
diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py
index 5787db65102..035e1f678bf 100644
--- a/homeassistant/helpers/config_validation.py
+++ b/homeassistant/helpers/config_validation.py
@@ -5,13 +5,26 @@ from datetime import (
     time as time_sys,
     timedelta,
 )
+from enum import Enum
 import inspect
 import logging
 from numbers import Number
 import os
 import re
 from socket import _GLOBAL_DEFAULT_TIMEOUT  # type: ignore # private, not in typeshed
-from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Hashable,
+    List,
+    Optional,
+    Pattern,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)
 from urllib.parse import urlparse
 from uuid import UUID
 
@@ -48,12 +61,11 @@ from homeassistant.const import (
 )
 from homeassistant.core import split_entity_id, valid_entity_id
 from homeassistant.exceptions import TemplateError
+from homeassistant.helpers import template as template_helper
 from homeassistant.helpers.logging import KeywordStyleAdapter
 from homeassistant.util import slugify as util_slugify
 import homeassistant.util.dt as dt_util
 
-# mypy: allow-untyped-calls, allow-untyped-defs
-# mypy: no-check-untyped-defs, no-warn-return-any
 # pylint: disable=invalid-name
 
 TIME_PERIOD_ERROR = "offset {} should be format 'HH:MM' or 'HH:MM:SS'"
@@ -126,7 +138,7 @@ def boolean(value: Any) -> bool:
     raise vol.Invalid("invalid boolean value {}".format(value))
 
 
-def isdevice(value):
+def isdevice(value: Any) -> str:
     """Validate that value is a real device."""
     try:
         os.stat(value)
@@ -135,19 +147,19 @@ def isdevice(value):
         raise vol.Invalid("No device at {} found".format(value))
 
 
-def matches_regex(regex):
+def matches_regex(regex: str) -> Callable[[Any], str]:
     """Validate that the value is a string that matches a regex."""
-    regex = re.compile(regex)
+    compiled = re.compile(regex)
 
     def validator(value: Any) -> str:
         """Validate that value matches the given regex."""
         if not isinstance(value, str):
             raise vol.Invalid("not a string value: {}".format(value))
 
-        if not regex.match(value):
+        if not compiled.match(value):
             raise vol.Invalid(
                 "value {} does not match regular expression {}".format(
-                    value, regex.pattern
+                    value, compiled.pattern
                 )
             )
 
@@ -156,14 +168,14 @@ def matches_regex(regex):
     return validator
 
 
-def is_regex(value):
+def is_regex(value: Any) -> Pattern[Any]:
     """Validate that a string is a valid regular expression."""
     try:
         r = re.compile(value)
         return r
     except TypeError:
         raise vol.Invalid(
-            "value {} is of the wrong type for a regular " "expression".format(value)
+            "value {} is of the wrong type for a regular expression".format(value)
         )
     except re.error:
         raise vol.Invalid("value {} is not a valid regular expression".format(value))
@@ -204,9 +216,9 @@ def ensure_list(value: Union[T, List[T], None]) -> List[T]:
 
 def entity_id(value: Any) -> str:
     """Validate Entity ID."""
-    value = string(value).lower()
-    if valid_entity_id(value):
-        return value
+    str_value = string(value).lower()
+    if valid_entity_id(str_value):
+        return str_value
 
     raise vol.Invalid("Entity ID {} is an invalid entity id".format(value))
 
@@ -253,17 +265,17 @@ def entities_domain(domain: str) -> Callable[[Union[str, List]], List[str]]:
     return validate
 
 
-def enum(enumClass):
+def enum(enumClass: Type[Enum]) -> vol.All:
     """Create validator for specified enum."""
     return vol.All(vol.In(enumClass.__members__), enumClass.__getitem__)
 
 
-def icon(value):
+def icon(value: Any) -> str:
     """Validate icon."""
-    value = str(value)
+    str_value = str(value)
 
-    if ":" in value:
-        return value
+    if ":" in str_value:
+        return str_value
 
     raise vol.Invalid('Icons should be specified in the form "prefix:name"')
 
@@ -362,7 +374,7 @@ def time_period_seconds(value: Union[int, str]) -> timedelta:
 time_period = vol.Any(time_period_str, time_period_seconds, timedelta, time_period_dict)
 
 
-def match_all(value):
+def match_all(value: T) -> T:
     """Validate that matches all values."""
     return value
 
@@ -382,12 +394,12 @@ def remove_falsy(value: List[T]) -> List[T]:
     return [v for v in value if v]
 
 
-def service(value):
+def service(value: Any) -> str:
     """Validate service."""
     # Services use same format as entities so we can use same helper.
-    value = string(value).lower()
-    if valid_entity_id(value):
-        return value
+    str_value = string(value).lower()
+    if valid_entity_id(str_value):
+        return str_value
     raise vol.Invalid("Service {} does not match format <domain>.<name>".format(value))
 
 
@@ -407,7 +419,7 @@ def schema_with_slug_keys(value_schema: Union[T, Callable]) -> Callable:
         for key in value.keys():
             slug(key)
 
-        return schema(value)
+        return cast(Dict, schema(value))
 
     return verify
 
@@ -416,10 +428,10 @@ def slug(value: Any) -> str:
     """Validate value is a valid slug."""
     if value is None:
         raise vol.Invalid("Slug should not be None")
-    value = str(value)
-    slg = util_slugify(value)
-    if value == slg:
-        return value
+    str_value = str(value)
+    slg = util_slugify(str_value)
+    if str_value == slg:
+        return str_value
     raise vol.Invalid("invalid slug {} (try {})".format(value, slg))
 
 
@@ -458,42 +470,41 @@ unit_system = vol.All(
 )
 
 
-def template(value):
+def template(value: Optional[Any]) -> template_helper.Template:
     """Validate a jinja2 template."""
-    from homeassistant.helpers import template as template_helper
 
     if value is None:
         raise vol.Invalid("template value is None")
     if isinstance(value, (list, dict, template_helper.Template)):
         raise vol.Invalid("template value should be a string")
 
-    value = template_helper.Template(str(value))
+    template_value = template_helper.Template(str(value))  # type: ignore
 
     try:
-        value.ensure_valid()
-        return value
+        template_value.ensure_valid()
+        return cast(template_helper.Template, template_value)
     except TemplateError as ex:
         raise vol.Invalid("invalid template ({})".format(ex))
 
 
-def template_complex(value):
+def template_complex(value: Any) -> Any:
     """Validate a complex jinja2 template."""
     if isinstance(value, list):
-        return_value = value.copy()
-        for idx, element in enumerate(return_value):
-            return_value[idx] = template_complex(element)
-        return return_value
+        return_list = value.copy()
+        for idx, element in enumerate(return_list):
+            return_list[idx] = template_complex(element)
+        return return_list
     if isinstance(value, dict):
-        return_value = value.copy()
-        for key, element in return_value.items():
-            return_value[key] = template_complex(element)
-        return return_value
+        return_dict = value.copy()
+        for key, element in return_dict.items():
+            return_dict[key] = template_complex(element)
+        return return_dict
     if isinstance(value, str):
         return template(value)
     return value
 
 
-def datetime(value):
+def datetime(value: Any) -> datetime_sys:
     """Validate datetime."""
     if isinstance(value, datetime_sys):
         return value
@@ -509,7 +520,7 @@ def datetime(value):
     return date_val
 
 
-def time_zone(value):
+def time_zone(value: str) -> str:
     """Validate timezone."""
     if dt_util.get_time_zone(value) is not None:
         return value
@@ -522,7 +533,7 @@ def time_zone(value):
 weekdays = vol.All(ensure_list, [vol.In(WEEKDAYS)])
 
 
-def socket_timeout(value):
+def socket_timeout(value: Optional[Any]) -> object:
     """Validate timeout float > 0.0.
 
     None coerced to socket._GLOBAL_DEFAULT_TIMEOUT bare object.
@@ -544,12 +555,12 @@ def url(value: Any) -> str:
     url_in = str(value)
 
     if urlparse(url_in).scheme in ["http", "https"]:
-        return vol.Schema(vol.Url())(url_in)
+        return cast(str, vol.Schema(vol.Url())(url_in))
 
     raise vol.Invalid("invalid url")
 
 
-def x10_address(value):
+def x10_address(value: str) -> str:
     """Validate an x10 address."""
     regex = re.compile(r"([A-Pa-p]{1})(?:[2-9]|1[0-6]?)$")
     if not regex.match(value):
@@ -557,7 +568,7 @@ def x10_address(value):
     return str(value).lower()
 
 
-def uuid4_hex(value):
+def uuid4_hex(value: Any) -> str:
     """Validate a v4 UUID in hex format."""
     try:
         result = UUID(value, version=4)
@@ -678,10 +689,12 @@ def deprecated(
 # Validator helpers
 
 
-def key_dependency(key, dependency):
+def key_dependency(
+    key: Hashable, dependency: Hashable
+) -> Callable[[Dict[Hashable, Any]], Dict[Hashable, Any]]:
     """Validate that all dependencies exist for key."""
 
-    def validator(value):
+    def validator(value: Dict[Hashable, Any]) -> Dict[Hashable, Any]:
         """Test dependencies."""
         if not isinstance(value, dict):
             raise vol.Invalid("key dependencies require a dict")
@@ -696,7 +709,7 @@ def key_dependency(key, dependency):
     return validator
 
 
-def custom_serializer(schema):
+def custom_serializer(schema: Any) -> Any:
     """Serialize additional types for voluptuous_serialize."""
     if schema is positive_time_period_dict:
         return {"type": "positive_time_period_dict"}
diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py
index 4818de83cb9..512334c8d3c 100644
--- a/homeassistant/helpers/device_registry.py
+++ b/homeassistant/helpers/device_registry.py
@@ -12,8 +12,7 @@ from homeassistant.loader import bind_hass
 
 from .typing import HomeAssistantType
 
-# mypy: allow-untyped-calls, allow-untyped-defs
-# mypy: no-check-untyped-defs, no-warn-return-any
+# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
 
 _LOGGER = logging.getLogger(__name__)
 _UNDEF = object()
@@ -71,10 +70,11 @@ def format_mac(mac: str) -> str:
 class DeviceRegistry:
     """Class to hold a registry of devices."""
 
-    def __init__(self, hass):
+    devices: Dict[str, DeviceEntry]
+
+    def __init__(self, hass: HomeAssistantType) -> None:
         """Initialize the device registry."""
         self.hass = hass
-        self.devices = None
         self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
 
     @callback
diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py
index a5bd62d973c..5eb79965880 100644
--- a/homeassistant/helpers/entity_registry.py
+++ b/homeassistant/helpers/entity_registry.py
@@ -11,7 +11,7 @@ import asyncio
 from collections import OrderedDict
 from itertools import chain
 import logging
-from typing import Any, Dict, Iterable, List, Optional, cast
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, cast
 
 import attr
 
@@ -23,6 +23,9 @@ from homeassistant.util.yaml import load_yaml
 
 from .typing import HomeAssistantType
 
+if TYPE_CHECKING:
+    from homeassistant.config_entries import ConfigEntry  # noqa: F401
+
 # mypy: allow-untyped-defs, no-check-untyped-defs
 
 PATH_REGISTRY = "entity_registry.yaml"
@@ -48,7 +51,7 @@ class RegistryEntry:
     unique_id = attr.ib(type=str)
     platform = attr.ib(type=str)
     name = attr.ib(type=str, default=None)
-    device_id = attr.ib(type=str, default=None)
+    device_id: Optional[str] = attr.ib(default=None)
     config_entry_id: Optional[str] = attr.ib(default=None)
     disabled_by = attr.ib(
         type=Optional[str],
@@ -135,16 +138,16 @@ class EntityRegistry:
     @callback
     def async_get_or_create(
         self,
-        domain,
-        platform,
-        unique_id,
+        domain: str,
+        platform: str,
+        unique_id: str,
         *,
-        suggested_object_id=None,
-        config_entry=None,
-        device_id=None,
-        known_object_ids=None,
-        disabled_by=None,
-    ):
+        suggested_object_id: Optional[str] = None,
+        config_entry: Optional["ConfigEntry"] = None,
+        device_id: Optional[str] = None,
+        known_object_ids: Optional[Iterable[str]] = None,
+        disabled_by: Optional[str] = None,
+    ) -> RegistryEntry:
         """Get entity. Create if it doesn't exist."""
         config_entry_id = None
         if config_entry:
@@ -153,7 +156,7 @@ class EntityRegistry:
         entity_id = self.async_get_entity_id(domain, platform, unique_id)
 
         if entity_id:
-            return self._async_update_entity(
+            return self._async_update_entity(  # type: ignore
                 entity_id,
                 config_entry_id=config_entry_id or _UNDEF,
                 device_id=device_id or _UNDEF,
@@ -228,12 +231,15 @@ class EntityRegistry:
         disabled_by=_UNDEF,
     ):
         """Update properties of an entity."""
-        return self._async_update_entity(
-            entity_id,
-            name=name,
-            new_entity_id=new_entity_id,
-            new_unique_id=new_unique_id,
-            disabled_by=disabled_by,
+        return cast(  # cast until we have _async_update_entity type hinted
+            RegistryEntry,
+            self._async_update_entity(
+                entity_id,
+                name=name,
+                new_entity_id=new_entity_id,
+                new_unique_id=new_unique_id,
+                disabled_by=disabled_by,
+            ),
         )
 
     @callback
diff --git a/homeassistant/helpers/logging.py b/homeassistant/helpers/logging.py
index 7b2507d9e05..0b274458045 100644
--- a/homeassistant/helpers/logging.py
+++ b/homeassistant/helpers/logging.py
@@ -1,8 +1,7 @@
 """Helpers for logging allowing more advanced logging styles to be used."""
 import inspect
 import logging
-
-# mypy: allow-untyped-defs, no-check-untyped-defs
+from typing import Any, Mapping, MutableMapping, Optional, Tuple
 
 
 class KeywordMessage:
@@ -12,13 +11,13 @@ class KeywordMessage:
     Adapted from: https://stackoverflow.com/a/24683360/2267718
     """
 
-    def __init__(self, fmt, args, kwargs):
-        """Initialize a new BraceMessage object."""
+    def __init__(self, fmt: Any, args: Any, kwargs: Mapping[str, Any]) -> None:
+        """Initialize a new KeywordMessage object."""
         self._fmt = fmt
         self._args = args
         self._kwargs = kwargs
 
-    def __str__(self):
+    def __str__(self) -> str:
         """Convert the object to a string for logging."""
         return str(self._fmt).format(*self._args, **self._kwargs)
 
@@ -26,26 +25,30 @@ class KeywordMessage:
 class KeywordStyleAdapter(logging.LoggerAdapter):
     """Represents an adapter wrapping the logger allowing KeywordMessages."""
 
-    def __init__(self, logger, extra=None):
+    def __init__(
+        self, logger: logging.Logger, extra: Optional[Mapping[str, Any]] = None
+    ) -> None:
         """Initialize a new StyleAdapter for the provided logger."""
         super().__init__(logger, extra or {})
 
-    def log(self, level, msg, *args, **kwargs):
+    def log(self, level: int, msg: Any, *args: Any, **kwargs: Any) -> None:
         """Log the message provided at the appropriate level."""
         if self.isEnabledFor(level):
             msg, log_kwargs = self.process(msg, kwargs)
-            self.logger._log(  # pylint: disable=protected-access
+            self.logger._log(  # type: ignore # pylint: disable=protected-access
                 level, KeywordMessage(msg, args, kwargs), (), **log_kwargs
             )
 
-    def process(self, msg, kwargs):
+    def process(
+        self, msg: Any, kwargs: MutableMapping[str, Any]
+    ) -> Tuple[Any, MutableMapping[str, Any]]:
         """Process the keyward args in preparation for logging."""
         return (
             msg,
             {
                 k: kwargs[k]
                 for k in inspect.getfullargspec(
-                    self.logger._log  # pylint: disable=protected-access
+                    self.logger._log  # type: ignore # pylint: disable=protected-access
                 ).args[1:]
                 if k in kwargs
             },
-- 
GitLab