diff --git a/homeassistant/__main__.py b/homeassistant/__main__.py index 7a5345a1a73e3adf0276c919dba20d35362f7830..65b1cd2ae1a6164047e8c56d9f7ff00bba879bb1 100644 --- a/homeassistant/__main__.py +++ b/homeassistant/__main__.py @@ -20,7 +20,7 @@ from homeassistant.const import ( ) -def attempt_use_uvloop(): +def attempt_use_uvloop() -> None: """Attempt to use uvloop.""" import asyncio @@ -280,8 +280,8 @@ def setup_and_run_hass(config_dir: str, # Imported here to avoid importing asyncio before monkey patch from homeassistant.util.async_ import run_callback_threadsafe - def open_browser(event): - """Open the webinterface in a browser.""" + def open_browser(_: Any) -> None: + """Open the web interface in a browser.""" if hass.config.api is not None: # type: ignore import webbrowser webbrowser.open(hass.config.api.base_url) # type: ignore diff --git a/homeassistant/bootstrap.py b/homeassistant/bootstrap.py index d832dda47547ee0d8324020b2cf9ecf7553c2baa..43c7168dd2e861a188ffa00d573b2f90967bd5da 100644 --- a/homeassistant/bootstrap.py +++ b/homeassistant/bootstrap.py @@ -221,8 +221,8 @@ async def async_from_config_file(config_path: str, @core.callback def async_enable_logging(hass: core.HomeAssistant, verbose: bool = False, - log_rotate_days=None, - log_file=None, + log_rotate_days: Optional[int] = None, + log_file: Optional[str] = None, log_no_color: bool = False) -> None: """Set up the logging. @@ -291,7 +291,7 @@ def async_enable_logging(hass: core.HomeAssistant, async_handler = AsyncHandler(hass.loop, err_handler) - async def async_stop_async_handler(event): + async def async_stop_async_handler(_: Any) -> None: """Cleanup async handler.""" logging.getLogger('').removeHandler(async_handler) # type: ignore await async_handler.async_close(blocking=True) diff --git a/homeassistant/components/rachio.py b/homeassistant/components/rachio.py index 7162913087d7969edea9b539be9c5b6038d0223d..0e67e15d5c099a1e1efacc0324ba991df587e71c 100644 --- a/homeassistant/components/rachio.py +++ b/homeassistant/components/rachio.py @@ -9,7 +9,7 @@ import logging from aiohttp import web import voluptuous as vol - +from typing import Optional from homeassistant.auth.util import generate_secret from homeassistant.components.http import HomeAssistantView from homeassistant.const import CONF_API_KEY, EVENT_HOMEASSISTANT_STOP, URL_API @@ -241,7 +241,7 @@ class RachioIro: # Only enabled zones return [z for z in self._zones if z[KEY_ENABLED]] - def get_zone(self, zone_id) -> dict or None: + def get_zone(self, zone_id) -> Optional[dict]: """Return the zone with the given ID.""" for zone in self.list_zones(include_disabled=True): if zone[KEY_ID] == zone_id: diff --git a/homeassistant/config.py b/homeassistant/config.py index d458ce44d370bf7258f5da0dc6c79513092854ea..6120a20fd634734cb26831b3d755829005eecc4b 100644 --- a/homeassistant/config.py +++ b/homeassistant/config.py @@ -7,8 +7,9 @@ import os import re import shutil # pylint: disable=unused-import -from typing import Any, Tuple, Optional # noqa: F401 - +from typing import ( # noqa: F401 + Any, Tuple, Optional, Dict, List, Union, Callable) +from types import ModuleType import voluptuous as vol from voluptuous.humanize import humanize_error @@ -21,7 +22,7 @@ from homeassistant.const import ( CONF_UNIT_SYSTEM_IMPERIAL, CONF_TEMPERATURE_UNIT, TEMP_CELSIUS, __version__, CONF_CUSTOMIZE, CONF_CUSTOMIZE_DOMAIN, CONF_CUSTOMIZE_GLOB, CONF_WHITELIST_EXTERNAL_DIRS, CONF_AUTH_PROVIDERS, CONF_TYPE) -from homeassistant.core import callback, DOMAIN as CONF_CORE +from homeassistant.core import callback, DOMAIN as CONF_CORE, HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.loader import get_component, get_platform from homeassistant.util.yaml import load_yaml, SECRET_YAML @@ -193,7 +194,7 @@ def ensure_config_exists(config_dir: str, detect_location: bool = True)\ return config_path -def create_default_config(config_dir: str, detect_location=True)\ +def create_default_config(config_dir: str, detect_location: bool = True)\ -> Optional[str]: """Create a default configuration file in given configuration directory. @@ -276,7 +277,7 @@ def create_default_config(config_dir: str, detect_location=True)\ return None -async def async_hass_config_yaml(hass): +async def async_hass_config_yaml(hass: HomeAssistant) -> Dict: """Load YAML from a Home Assistant configuration file. This function allow a component inside the asyncio loop to reload its @@ -284,23 +285,26 @@ async def async_hass_config_yaml(hass): This method is a coroutine. """ - def _load_hass_yaml_config(): + def _load_hass_yaml_config() -> Dict: path = find_config_file(hass.config.config_dir) - conf = load_yaml_config_file(path) - return conf + if path is None: + raise HomeAssistantError( + "Config file not found in: {}".format(hass.config.config_dir)) + return load_yaml_config_file(path) - conf = await hass.async_add_job(_load_hass_yaml_config) - return conf + return await hass.async_add_executor_job(_load_hass_yaml_config) -def find_config_file(config_dir: str) -> Optional[str]: +def find_config_file(config_dir: Optional[str]) -> Optional[str]: """Look in given directory for supported configuration files.""" + if config_dir is None: + return None config_path = os.path.join(config_dir, YAML_CONFIG_FILE) return config_path if os.path.isfile(config_path) else None -def load_yaml_config_file(config_path): +def load_yaml_config_file(config_path: str) -> Dict[Any, Any]: """Parse a YAML configuration file. This method needs to run in an executor. @@ -323,7 +327,7 @@ def load_yaml_config_file(config_path): return conf_dict -def process_ha_config_upgrade(hass): +def process_ha_config_upgrade(hass: HomeAssistant) -> None: """Upgrade configuration if necessary. This method needs to run in an executor. @@ -360,7 +364,8 @@ def process_ha_config_upgrade(hass): @callback -def async_log_exception(ex, domain, config, hass): +def async_log_exception(ex: vol.Invalid, domain: str, config: Dict, + hass: HomeAssistant) -> None: """Log an error for configuration validation. This method must be run in the event loop. @@ -371,7 +376,7 @@ def async_log_exception(ex, domain, config, hass): @callback -def _format_config_error(ex, domain, config): +def _format_config_error(ex: vol.Invalid, domain: str, config: Dict) -> str: """Generate log exception for configuration validation. This method must be run in the event loop. @@ -396,7 +401,8 @@ def _format_config_error(ex, domain, config): return message -async def async_process_ha_core_config(hass, config): +async def async_process_ha_core_config( + hass: HomeAssistant, config: Dict) -> None: """Process the [homeassistant] section from the configuration. This method is a coroutine. @@ -405,12 +411,12 @@ async def async_process_ha_core_config(hass, config): # Only load auth during startup. if not hasattr(hass, 'auth'): - hass.auth = await auth.auth_manager_from_config( - hass, config.get(CONF_AUTH_PROVIDERS, [])) + setattr(hass, 'auth', await auth.auth_manager_from_config( + hass, config.get(CONF_AUTH_PROVIDERS, []))) hac = hass.config - def set_time_zone(time_zone_str): + def set_time_zone(time_zone_str: Optional[str]) -> None: """Help to set the time zone.""" if time_zone_str is None: return @@ -430,11 +436,10 @@ async def async_process_ha_core_config(hass, config): if key in config: setattr(hac, attr, config[key]) - if CONF_TIME_ZONE in config: - set_time_zone(config.get(CONF_TIME_ZONE)) + set_time_zone(config.get(CONF_TIME_ZONE)) # Init whitelist external dir - hac.whitelist_external_dirs = set((hass.config.path('www'),)) + hac.whitelist_external_dirs = {hass.config.path('www')} if CONF_WHITELIST_EXTERNAL_DIRS in config: hac.whitelist_external_dirs.update( set(config[CONF_WHITELIST_EXTERNAL_DIRS])) @@ -484,12 +489,12 @@ async def async_process_ha_core_config(hass, config): hac.time_zone, hac.elevation): return - discovered = [] + discovered = [] # type: List[Tuple[str, Any]] # If we miss some of the needed values, auto detect them if None in (hac.latitude, hac.longitude, hac.units, hac.time_zone): - info = await hass.async_add_job( + info = await hass.async_add_executor_job( loc_util.detect_location_info) if info is None: @@ -515,7 +520,7 @@ async def async_process_ha_core_config(hass, config): if hac.elevation is None and hac.latitude is not None and \ hac.longitude is not None: - elevation = await hass.async_add_job( + elevation = await hass.async_add_executor_job( loc_util.elevation, hac.latitude, hac.longitude) hac.elevation = elevation discovered.append(('elevation', elevation)) @@ -526,7 +531,8 @@ async def async_process_ha_core_config(hass, config): ", ".join('{}: {}'.format(key, val) for key, val in discovered)) -def _log_pkg_error(package, component, config, message): +def _log_pkg_error( + package: str, component: str, config: Dict, message: str) -> None: """Log an error while merging packages.""" message = "Package {} setup failed. Component {} {}".format( package, component, message) @@ -539,12 +545,13 @@ def _log_pkg_error(package, component, config, message): _LOGGER.error(message) -def _identify_config_schema(module): +def _identify_config_schema(module: ModuleType) -> \ + Tuple[Optional[str], Optional[Dict]]: """Extract the schema and identify list or dict based.""" try: - schema = module.CONFIG_SCHEMA.schema[module.DOMAIN] + schema = module.CONFIG_SCHEMA.schema[module.DOMAIN] # type: ignore except (AttributeError, KeyError): - return (None, None) + return None, None t_schema = str(schema) if t_schema.startswith('{'): return ('dict', schema) @@ -553,9 +560,10 @@ def _identify_config_schema(module): return '', schema -def _recursive_merge(conf, package): +def _recursive_merge( + conf: Dict[str, Any], package: Dict[str, Any]) -> Union[bool, str]: """Merge package into conf, recursively.""" - error = False + error = False # type: Union[bool, str] for key, pack_conf in package.items(): if isinstance(pack_conf, dict): if not pack_conf: @@ -576,8 +584,8 @@ def _recursive_merge(conf, package): return error -def merge_packages_config(hass, config, packages, - _log_pkg_error=_log_pkg_error): +def merge_packages_config(hass: HomeAssistant, config: Dict, packages: Dict, + _log_pkg_error: Callable = _log_pkg_error) -> Dict: """Merge packages into the top-level configuration. Mutate config.""" # pylint: disable=too-many-nested-blocks PACKAGES_CONFIG_SCHEMA(packages) @@ -641,7 +649,8 @@ def merge_packages_config(hass, config, packages, @callback -def async_process_component_config(hass, config, domain): +def async_process_component_config( + hass: HomeAssistant, config: Dict, domain: str) -> Optional[Dict]: """Check component configuration and return processed configuration. Returns None on error. @@ -703,14 +712,14 @@ def async_process_component_config(hass, config, domain): return config -async def async_check_ha_config_file(hass): +async def async_check_ha_config_file(hass: HomeAssistant) -> Optional[str]: """Check if Home Assistant configuration file is valid. This method is a coroutine. """ from homeassistant.scripts.check_config import check_ha_config_file - res = await hass.async_add_job( + res = await hass.async_add_executor_job( check_ha_config_file, hass) if not res.errors: @@ -719,7 +728,9 @@ async def async_check_ha_config_file(hass): @callback -def async_notify_setup_error(hass, component, display_link=False): +def async_notify_setup_error( + hass: HomeAssistant, component: str, + display_link: bool = False) -> None: """Print a persistent notification. This method must be run in the event loop. diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 0fc66174c6685dc858cb835c9cf08a34b70a06f3..8e2bb3fa5df9c4bc21dfb45fc6a64a400fbf9bd4 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -113,10 +113,10 @@ the flow from the config panel. import logging import uuid -from typing import Set # noqa pylint: disable=unused-import +from typing import Set, Optional # noqa pylint: disable=unused-import from homeassistant import data_entry_flow -from homeassistant.core import callback +from homeassistant.core import callback, HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.setup import async_setup_component, async_process_deps_reqs from homeassistant.util.decorator import Registry @@ -164,8 +164,9 @@ class ConfigEntry: __slots__ = ('entry_id', 'version', 'domain', 'title', 'data', 'source', 'state') - def __init__(self, version, domain, title, data, source, entry_id=None, - state=ENTRY_STATE_NOT_LOADED): + def __init__(self, version: str, domain: str, title: str, data: dict, + source: str, entry_id: Optional[str] = None, + state: str = ENTRY_STATE_NOT_LOADED) -> None: """Initialize a config entry.""" # Unique id of the config entry self.entry_id = entry_id or uuid.uuid4().hex @@ -188,7 +189,8 @@ class ConfigEntry: # State of the entry (LOADED, NOT_LOADED) self.state = state - async def async_setup(self, hass, *, component=None): + async def async_setup( + self, hass: HomeAssistant, *, component=None) -> None: """Set up an entry.""" if component is None: component = getattr(hass.components, self.domain) diff --git a/homeassistant/core.py b/homeassistant/core.py index f868c52cfb00bdfba93ebb48535e8abd1b862c3c..a7684d130ae4963874bd93a140849bd6e0cfb79a 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -4,9 +4,9 @@ Core components of Home Assistant. Home Assistant is a Home Automation framework for observing the state of entities and react to changes. """ -# pylint: disable=unused-import import asyncio from concurrent.futures import ThreadPoolExecutor +import datetime import enum import logging import os @@ -17,9 +17,10 @@ import threading from time import monotonic from types import MappingProxyType +# pylint: disable=unused-import from typing import ( # NOQA Optional, Any, Callable, List, TypeVar, Dict, Coroutine, Set, - TYPE_CHECKING) + TYPE_CHECKING, Awaitable, Iterator) from async_timeout import timeout import voluptuous as vol @@ -44,11 +45,13 @@ from homeassistant.util import location from homeassistant.util.unit_system import UnitSystem, METRIC_SYSTEM # NOQA # Typing imports that create a circular dependency -# pylint: disable=using-constant-test,unused-import +# pylint: disable=using-constant-test if TYPE_CHECKING: - from homeassistant.config_entries import ConfigEntries # noqa + from homeassistant.config_entries import ConfigEntries # noqa T = TypeVar('T') +CALLABLE_T = TypeVar('CALLABLE_T', bound=Callable) +CALLBACK_TYPE = Callable[[], None] DOMAIN = 'homeassistant' @@ -79,7 +82,7 @@ def valid_state(state: str) -> bool: return len(state) < 256 -def callback(func: Callable[..., T]) -> Callable[..., T]: +def callback(func: CALLABLE_T) -> CALLABLE_T: """Annotation to mark method as safe to call from within the event loop.""" setattr(func, '_hass_callback', True) return func @@ -91,7 +94,7 @@ def is_callback(func: Callable[..., Any]) -> bool: @callback -def async_loop_exception_handler(loop, context): +def async_loop_exception_handler(_: Any, context: Dict) -> None: """Handle all exception inside the core loop.""" kwargs = {} exception = context.get('exception') @@ -119,7 +122,9 @@ class CoreState(enum.Enum): class HomeAssistant: """Root object of the Home Assistant home automation.""" - def __init__(self, loop=None): + def __init__( + self, + loop: Optional[asyncio.events.AbstractEventLoop] = None) -> None: """Initialize new Home Assistant object.""" if sys.platform == 'win32': self.loop = loop or asyncio.ProactorEventLoop() @@ -170,7 +175,7 @@ class HomeAssistant: self.loop.close() return self.exit_code - async def async_start(self): + async def async_start(self) -> None: """Finalize startup from inside the event loop. This method is a coroutine. @@ -178,8 +183,7 @@ class HomeAssistant: _LOGGER.info("Starting Home Assistant") self.state = CoreState.starting - # pylint: disable=protected-access - self.loop._thread_ident = threading.get_ident() + setattr(self.loop, '_thread_ident', threading.get_ident()) self.bus.async_fire(EVENT_HOMEASSISTANT_START) try: @@ -230,7 +234,8 @@ class HomeAssistant: elif asyncio.iscoroutinefunction(target): task = self.loop.create_task(target(*args)) else: - task = self.loop.run_in_executor(None, target, *args) + task = self.loop.run_in_executor( # type: ignore + None, target, *args) # If a task is scheduled if self._track_task and task is not None: @@ -256,11 +261,11 @@ class HomeAssistant: @callback def async_add_executor_job( self, - target: Callable[..., Any], - *args: Any) -> asyncio.Future: + target: Callable[..., T], + *args: Any) -> Awaitable[T]: """Add an executor job from within the event loop.""" - task = self.loop.run_in_executor( # type: ignore - None, target, *args) # type: asyncio.Future + task = self.loop.run_in_executor( + None, target, *args) # If a task is scheduled if self._track_task: @@ -269,12 +274,12 @@ class HomeAssistant: return task @callback - def async_track_tasks(self): + def async_track_tasks(self) -> None: """Track tasks so you can wait for all tasks to be done.""" self._track_task = True @callback - def async_stop_track_tasks(self): + def async_stop_track_tasks(self) -> None: """Stop track tasks so you can't wait for all tasks to be done.""" self._track_task = False @@ -297,7 +302,7 @@ class HomeAssistant: run_coroutine_threadsafe( self.async_block_till_done(), loop=self.loop).result() - async def async_block_till_done(self): + async def async_block_till_done(self) -> None: """Block till all pending work is done.""" # To flush out any call_soon_threadsafe await asyncio.sleep(0, loop=self.loop) @@ -342,9 +347,9 @@ class EventOrigin(enum.Enum): local = 'LOCAL' remote = 'REMOTE' - def __str__(self): + def __str__(self) -> str: """Return the event.""" - return self.value + return self.value # type: ignore class Event: @@ -352,15 +357,16 @@ class Event: __slots__ = ['event_type', 'data', 'origin', 'time_fired'] - def __init__(self, event_type, data=None, origin=EventOrigin.local, - time_fired=None): + def __init__(self, event_type: str, data: Optional[Dict] = None, + origin: EventOrigin = EventOrigin.local, + time_fired: Optional[int] = None) -> None: """Initialize a new event.""" self.event_type = event_type self.data = data or {} self.origin = origin self.time_fired = time_fired or dt_util.utcnow() - def as_dict(self): + def as_dict(self) -> Dict: """Create a dict representation of this Event. Async friendly. @@ -372,7 +378,7 @@ class Event: 'time_fired': self.time_fired, } - def __repr__(self): + def __repr__(self) -> str: """Return the representation.""" # pylint: disable=maybe-no-member if self.data: @@ -383,9 +389,9 @@ class Event: return "<Event {}[{}]>".format(self.event_type, str(self.origin)[0]) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: """Return the comparison.""" - return (self.__class__ == other.__class__ and + return (self.__class__ == other.__class__ and # type: ignore self.event_type == other.event_type and self.data == other.data and self.origin == other.origin and @@ -401,7 +407,7 @@ class EventBus: self._hass = hass @callback - def async_listeners(self): + def async_listeners(self) -> Dict[str, int]: """Return dictionary with events and the number of listeners. This method must be run in the event loop. @@ -410,20 +416,21 @@ class EventBus: for key in self._listeners} @property - def listeners(self): + def listeners(self) -> Dict[str, int]: """Return dictionary with events and the number of listeners.""" - return run_callback_threadsafe( + return run_callback_threadsafe( # type: ignore self._hass.loop, self.async_listeners ).result() - def fire(self, event_type: str, event_data=None, origin=EventOrigin.local): + def fire(self, event_type: str, event_data: Optional[Dict] = None, + origin: EventOrigin = EventOrigin.local) -> None: """Fire an event.""" self._hass.loop.call_soon_threadsafe( self.async_fire, event_type, event_data, origin) @callback - def async_fire(self, event_type: str, event_data=None, - origin=EventOrigin.local): + def async_fire(self, event_type: str, event_data: Optional[Dict] = None, + origin: EventOrigin = EventOrigin.local) -> None: """Fire an event. This method must be run in the event loop. @@ -447,7 +454,8 @@ class EventBus: for func in listeners: self._hass.async_add_job(func, event) - def listen(self, event_type, listener): + def listen( + self, event_type: str, listener: Callable) -> CALLBACK_TYPE: """Listen for all events or events of a specific type. To listen to all events specify the constant ``MATCH_ALL`` @@ -456,7 +464,7 @@ class EventBus: async_remove_listener = run_callback_threadsafe( self._hass.loop, self.async_listen, event_type, listener).result() - def remove_listener(): + def remove_listener() -> None: """Remove the listener.""" run_callback_threadsafe( self._hass.loop, async_remove_listener).result() @@ -464,7 +472,8 @@ class EventBus: return remove_listener @callback - def async_listen(self, event_type, listener): + def async_listen( + self, event_type: str, listener: Callable) -> CALLBACK_TYPE: """Listen for all events or events of a specific type. To listen to all events specify the constant ``MATCH_ALL`` @@ -477,13 +486,14 @@ class EventBus: else: self._listeners[event_type] = [listener] - def remove_listener(): + def remove_listener() -> None: """Remove the listener.""" self._async_remove_listener(event_type, listener) return remove_listener - def listen_once(self, event_type, listener): + def listen_once( + self, event_type: str, listener: Callable) -> CALLBACK_TYPE: """Listen once for event of a specific type. To listen to all events specify the constant ``MATCH_ALL`` @@ -495,7 +505,7 @@ class EventBus: self._hass.loop, self.async_listen_once, event_type, listener, ).result() - def remove_listener(): + def remove_listener() -> None: """Remove the listener.""" run_callback_threadsafe( self._hass.loop, async_remove_listener).result() @@ -503,7 +513,8 @@ class EventBus: return remove_listener @callback - def async_listen_once(self, event_type, listener): + def async_listen_once( + self, event_type: str, listener: Callable) -> CALLBACK_TYPE: """Listen once for event of a specific type. To listen to all events specify the constant ``MATCH_ALL`` @@ -514,8 +525,8 @@ class EventBus: This method must be run in the event loop. """ @callback - def onetime_listener(event): - """Remove listener from eventbus and then fire listener.""" + def onetime_listener(event: Event) -> None: + """Remove listener from event bus and then fire listener.""" if hasattr(onetime_listener, 'run'): return # Set variable so that we will never run twice. @@ -530,7 +541,8 @@ class EventBus: return self.async_listen(event_type, onetime_listener) @callback - def _async_remove_listener(self, event_type, listener): + def _async_remove_listener( + self, event_type: str, listener: Callable) -> None: """Remove a listener of a specific event_type. This method must be run in the event loop. @@ -560,8 +572,10 @@ class State: __slots__ = ['entity_id', 'state', 'attributes', 'last_changed', 'last_updated'] - def __init__(self, entity_id, state, attributes=None, last_changed=None, - last_updated=None): + def __init__(self, entity_id: str, state: Any, + attributes: Optional[Dict] = None, + last_changed: Optional[datetime.datetime] = None, + last_updated: Optional[datetime.datetime] = None) -> None: """Initialize a new state.""" state = str(state) @@ -582,23 +596,23 @@ class State: self.last_changed = last_changed or self.last_updated @property - def domain(self): + def domain(self) -> str: """Domain of this state.""" return split_entity_id(self.entity_id)[0] @property - def object_id(self): + def object_id(self) -> str: """Object id of this state.""" return split_entity_id(self.entity_id)[1] @property - def name(self): + def name(self) -> str: """Name of this state.""" return ( self.attributes.get(ATTR_FRIENDLY_NAME) or self.object_id.replace('_', ' ')) - def as_dict(self): + def as_dict(self) -> Dict: """Return a dict representation of the State. Async friendly. @@ -613,7 +627,7 @@ class State: 'last_updated': self.last_updated} @classmethod - def from_dict(cls, json_dict): + def from_dict(cls, json_dict: Dict) -> Any: """Initialize a state from a dict. Async friendly. @@ -637,14 +651,14 @@ class State: return cls(json_dict['entity_id'], json_dict['state'], json_dict.get('attributes'), last_changed, last_updated) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: """Return the comparison of the state.""" - return (self.__class__ == other.__class__ and + return (self.__class__ == other.__class__ and # type: ignore self.entity_id == other.entity_id and self.state == other.state and self.attributes == other.attributes) - def __repr__(self): + def __repr__(self) -> str: """Return the representation of the states.""" attr = "; {}".format(util.repr_helper(self.attributes)) \ if self.attributes else "" @@ -657,21 +671,23 @@ class State: class StateMachine: """Helper class that tracks the state of different entities.""" - def __init__(self, bus, loop): + def __init__(self, bus: EventBus, + loop: asyncio.events.AbstractEventLoop) -> None: """Initialize state machine.""" self._states = {} # type: Dict[str, State] self._bus = bus self._loop = loop - def entity_ids(self, domain_filter=None): + def entity_ids(self, domain_filter: Optional[str] = None)-> List[str]: """List of entity ids that are being tracked.""" future = run_callback_threadsafe( self._loop, self.async_entity_ids, domain_filter ) - return future.result() + return future.result() # type: ignore @callback - def async_entity_ids(self, domain_filter=None): + def async_entity_ids( + self, domain_filter: Optional[str] = None) -> List[str]: """List of entity ids that are being tracked. This method must be run in the event loop. @@ -684,26 +700,27 @@ class StateMachine: return [state.entity_id for state in self._states.values() if state.domain == domain_filter] - def all(self): + def all(self)-> List[State]: """Create a list of all states.""" - return run_callback_threadsafe(self._loop, self.async_all).result() + return run_callback_threadsafe( # type: ignore + self._loop, self.async_all).result() @callback - def async_all(self): + def async_all(self)-> List[State]: """Create a list of all states. This method must be run in the event loop. """ return list(self._states.values()) - def get(self, entity_id): + def get(self, entity_id: str) -> Optional[State]: """Retrieve state of entity_id or None if not found. Async friendly. """ return self._states.get(entity_id.lower()) - def is_state(self, entity_id, state): + def is_state(self, entity_id: str, state: State) -> bool: """Test if entity exists and is specified state. Async friendly. @@ -711,16 +728,16 @@ class StateMachine: state_obj = self.get(entity_id) return state_obj is not None and state_obj.state == state - def remove(self, entity_id): + def remove(self, entity_id: str) -> bool: """Remove the state of an entity. Returns boolean to indicate if an entity was removed. """ - return run_callback_threadsafe( + return run_callback_threadsafe( # type: ignore self._loop, self.async_remove, entity_id).result() @callback - def async_remove(self, entity_id): + def async_remove(self, entity_id: str) -> bool: """Remove the state of an entity. Returns boolean to indicate if an entity was removed. @@ -740,7 +757,9 @@ class StateMachine: }) return True - def set(self, entity_id, new_state, attributes=None, force_update=False): + def set(self, entity_id: str, new_state: Any, + attributes: Optional[Dict] = None, + force_update: bool = False) -> None: """Set the state of an entity, add entity if it does not exist. Attributes is an optional dict to specify attributes of this state. @@ -754,8 +773,9 @@ class StateMachine: ).result() @callback - def async_set(self, entity_id, new_state, attributes=None, - force_update=False): + def async_set(self, entity_id: str, new_state: Any, + attributes: Optional[Dict] = None, + force_update: bool = False) -> None: """Set the state of an entity, add entity if it does not exist. Attributes is an optional dict to specify attributes of this state. @@ -769,15 +789,19 @@ class StateMachine: new_state = str(new_state) attributes = attributes or {} old_state = self._states.get(entity_id) - is_existing = old_state is not None - same_state = (is_existing and old_state.state == new_state and - not force_update) - same_attr = is_existing and old_state.attributes == attributes + if old_state is None: + same_state = False + same_attr = False + last_changed = None + else: + same_state = (old_state.state == new_state and + not force_update) + same_attr = old_state.attributes == attributes + last_changed = old_state.last_changed if same_state else None if same_state and same_attr: return - last_changed = old_state.last_changed if same_state else None state = State(entity_id, new_state, attributes, last_changed) self._states[entity_id] = state self._bus.async_fire(EVENT_STATE_CHANGED, { @@ -792,7 +816,7 @@ class Service: __slots__ = ['func', 'schema', 'is_callback', 'is_coroutinefunction'] - def __init__(self, func, schema): + def __init__(self, func: Callable, schema: Optional[vol.Schema]) -> None: """Initialize a service.""" self.func = func self.schema = schema @@ -805,14 +829,15 @@ class ServiceCall: __slots__ = ['domain', 'service', 'data', 'call_id'] - def __init__(self, domain, service, data=None, call_id=None): + def __init__(self, domain: str, service: str, data: Optional[Dict] = None, + call_id: Optional[str] = None) -> None: """Initialize a service call.""" self.domain = domain.lower() self.service = service.lower() self.data = MappingProxyType(data or {}) self.call_id = call_id - def __repr__(self): + def __repr__(self) -> str: """Return the representation of the service.""" if self.data: return "<ServiceCall {}.{}: {}>".format( @@ -824,13 +849,13 @@ class ServiceCall: class ServiceRegistry: """Offer the services over the eventbus.""" - def __init__(self, hass): + def __init__(self, hass: HomeAssistant) -> None: """Initialize a service registry.""" self._services = {} # type: Dict[str, Dict[str, Service]] self._hass = hass - self._async_unsub_call_event = None + self._async_unsub_call_event = None # type: Optional[CALLBACK_TYPE] - def _gen_unique_id(): + def _gen_unique_id() -> Iterator[str]: cur_id = 1 while True: yield '{}-{}'.format(id(self), cur_id) @@ -840,14 +865,14 @@ class ServiceRegistry: self._generate_unique_id = lambda: next(gen) @property - def services(self): + def services(self) -> Dict[str, Dict[str, Service]]: """Return dictionary with per domain a list of available services.""" - return run_callback_threadsafe( + return run_callback_threadsafe( # type: ignore self._hass.loop, self.async_services, ).result() @callback - def async_services(self): + def async_services(self) -> Dict[str, Dict[str, Service]]: """Return dictionary with per domain a list of available services. This method must be run in the event loop. @@ -855,14 +880,15 @@ class ServiceRegistry: return {domain: self._services[domain].copy() for domain in self._services} - def has_service(self, domain, service): + def has_service(self, domain: str, service: str) -> bool: """Test if specified service exists. Async friendly. """ return service.lower() in self._services.get(domain.lower(), []) - def register(self, domain, service, service_func, schema=None): + def register(self, domain: str, service: str, service_func: Callable, + schema: Optional[vol.Schema] = None) -> None: """ Register a service. @@ -874,7 +900,8 @@ class ServiceRegistry: ).result() @callback - def async_register(self, domain, service, service_func, schema=None): + def async_register(self, domain: str, service: str, service_func: Callable, + schema: Optional[vol.Schema] = None) -> None: """ Register a service. @@ -900,13 +927,13 @@ class ServiceRegistry: {ATTR_DOMAIN: domain, ATTR_SERVICE: service} ) - def remove(self, domain, service): + def remove(self, domain: str, service: str) -> None: """Remove a registered service from service handler.""" run_callback_threadsafe( self._hass.loop, self.async_remove, domain, service).result() @callback - def async_remove(self, domain, service): + def async_remove(self, domain: str, service: str) -> None: """Remove a registered service from service handler. This method must be run in the event loop. @@ -926,7 +953,9 @@ class ServiceRegistry: {ATTR_DOMAIN: domain, ATTR_SERVICE: service} ) - def call(self, domain, service, service_data=None, blocking=False): + def call(self, domain: str, service: str, + service_data: Optional[Dict] = None, + blocking: bool = False) -> Optional[bool]: """ Call a service. @@ -943,13 +972,14 @@ class ServiceRegistry: Because the service is sent as an event you are not allowed to use the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data. """ - return run_coroutine_threadsafe( + return run_coroutine_threadsafe( # type: ignore self.async_call(domain, service, service_data, blocking), self._hass.loop ).result() - async def async_call(self, domain, service, service_data=None, - blocking=False): + async def async_call(self, domain: str, service: str, + service_data: Optional[Dict] = None, + blocking: bool = False) -> Optional[bool]: """ Call a service. @@ -981,7 +1011,7 @@ class ServiceRegistry: fut = asyncio.Future(loop=self._hass.loop) # type: asyncio.Future @callback - def service_executed(event): + def service_executed(event: Event) -> None: """Handle an executed service.""" if event.data[ATTR_SERVICE_CALL_ID] == call_id: fut.set_result(True) @@ -989,20 +1019,22 @@ class ServiceRegistry: unsub = self._hass.bus.async_listen( EVENT_SERVICE_EXECUTED, service_executed) - self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data) + self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data) - if blocking: done, _ = await asyncio.wait( [fut], loop=self._hass.loop, timeout=SERVICE_CALL_LIMIT) success = bool(done) unsub() return success - async def _event_to_service_call(self, event): + self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data) + return None + + async def _event_to_service_call(self, event: Event) -> None: """Handle the SERVICE_CALLED events from the EventBus.""" service_data = event.data.get(ATTR_SERVICE_DATA) or {} - domain = event.data.get(ATTR_DOMAIN).lower() - service = event.data.get(ATTR_SERVICE).lower() + domain = event.data.get(ATTR_DOMAIN).lower() # type: ignore + service = event.data.get(ATTR_SERVICE).lower() # type: ignore call_id = event.data.get(ATTR_SERVICE_CALL_ID) if not self.has_service(domain, service): @@ -1013,7 +1045,7 @@ class ServiceRegistry: service_handler = self._services[domain][service] - def fire_service_executed(): + def fire_service_executed() -> None: """Fire service executed event.""" if not call_id: return @@ -1045,12 +1077,12 @@ class ServiceRegistry: await service_handler.func(service_call) fire_service_executed() else: - def execute_service(): + def execute_service() -> None: """Execute a service and fires a SERVICE_EXECUTED event.""" service_handler.func(service_call) fire_service_executed() - await self._hass.async_add_job(execute_service) + await self._hass.async_add_executor_job(execute_service) except Exception: # pylint: disable=broad-except _LOGGER.exception('Error executing service %s', service_call) @@ -1058,13 +1090,13 @@ class ServiceRegistry: class Config: """Configuration settings for Home Assistant.""" - def __init__(self): + def __init__(self) -> None: """Initialize a new config object.""" self.latitude = None # type: Optional[float] self.longitude = None # type: Optional[float] self.elevation = None # type: Optional[int] self.location_name = None # type: Optional[str] - self.time_zone = None # type: Optional[str] + self.time_zone = None # type: Optional[datetime.tzinfo] self.units = METRIC_SYSTEM # type: UnitSystem # If True, pip install is skipped for requirements on startup @@ -1090,7 +1122,7 @@ class Config: return self.units.length( location.distance(self.latitude, self.longitude, lat, lon), 'm') - def path(self, *path): + def path(self, *path: str) -> str: """Generate path to the file within the configuration directory. Async friendly. @@ -1122,12 +1154,14 @@ class Config: return False - def as_dict(self): + def as_dict(self) -> Dict: """Create a dictionary representation of this dict. Async friendly. """ - time_zone = self.time_zone or dt_util.UTC + time_zone = dt_util.UTC.zone + if self.time_zone and getattr(self.time_zone, 'zone'): + time_zone = getattr(self.time_zone, 'zone') return { 'latitude': self.latitude, @@ -1135,7 +1169,7 @@ class Config: 'elevation': self.elevation, 'unit_system': self.units.as_dict(), 'location_name': self.location_name, - 'time_zone': time_zone.zone, + 'time_zone': time_zone, 'components': self.components, 'config_dir': self.config_dir, 'whitelist_external_dirs': self.whitelist_external_dirs, @@ -1143,12 +1177,12 @@ class Config: } -def _async_create_timer(hass): +def _async_create_timer(hass: HomeAssistant) -> None: """Create a timer that will start on HOMEASSISTANT_START.""" handle = None @callback - def fire_time_event(nxt): + def fire_time_event(nxt: float) -> None: """Fire next time event.""" nonlocal handle @@ -1165,7 +1199,7 @@ def _async_create_timer(hass): handle = hass.loop.call_later(slp_seconds, fire_time_event, nxt) @callback - def stop_timer(event): + def stop_timer(_: Event) -> None: """Stop the timer.""" if handle is not None: handle.cancel() diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index 24dcb46bb681fffaa2d40328d5051f6af888fa87..f010ada02f3c7fe164948c186514fdc252c3b1bc 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -1,8 +1,9 @@ """Classes to help gather user submissions.""" import logging import uuid -from typing import Dict, Any # noqa pylint: disable=unused-import -from .core import callback +import voluptuous as vol +from typing import Dict, Any, Callable, List, Optional # noqa pylint: disable=unused-import +from .core import callback, HomeAssistant from .exceptions import HomeAssistantError _LOGGER = logging.getLogger(__name__) @@ -35,7 +36,8 @@ class UnknownStep(FlowError): class FlowManager: """Manage all the flows that are in progress.""" - def __init__(self, hass, async_create_flow, async_finish_flow): + def __init__(self, hass: HomeAssistant, async_create_flow: Callable, + async_finish_flow: Callable) -> None: """Initialize the flow manager.""" self.hass = hass self._progress = {} # type: Dict[str, Any] @@ -43,7 +45,7 @@ class FlowManager: self._async_finish_flow = async_finish_flow @callback - def async_progress(self): + def async_progress(self) -> List[Dict]: """Return the flows in progress.""" return [{ 'flow_id': flow.flow_id, @@ -51,7 +53,8 @@ class FlowManager: 'source': flow.source, } for flow in self._progress.values()] - async def async_init(self, handler, *, source=SOURCE_USER, data=None): + async def async_init(self, handler: Callable, *, source: str = SOURCE_USER, + data: str = None) -> Any: """Start a configuration flow.""" flow = await self._async_create_flow(handler, source=source, data=data) flow.hass = self.hass @@ -67,7 +70,8 @@ class FlowManager: return await self._async_handle_step(flow, step, data) - async def async_configure(self, flow_id, user_input=None): + async def async_configure( + self, flow_id: str, user_input: str = None) -> Any: """Continue a configuration flow.""" flow = self._progress.get(flow_id) @@ -83,12 +87,13 @@ class FlowManager: flow, step_id, user_input) @callback - def async_abort(self, flow_id): + def async_abort(self, flow_id: str) -> None: """Abort a flow.""" if self._progress.pop(flow_id, None) is None: raise UnknownFlow - async def _async_handle_step(self, flow, step_id, user_input): + async def _async_handle_step(self, flow: Any, step_id: str, + user_input: Optional[str]) -> Dict: """Handle a step of a flow.""" method = "async_step_{}".format(step_id) @@ -97,7 +102,7 @@ class FlowManager: raise UnknownStep("Handler {} doesn't support step {}".format( flow.__class__.__name__, step_id)) - result = await getattr(flow, method)(user_input) + result = await getattr(flow, method)(user_input) # type: Dict if result['type'] not in (RESULT_TYPE_FORM, RESULT_TYPE_CREATE_ENTRY, RESULT_TYPE_ABORT): @@ -133,8 +138,9 @@ class FlowHandler: VERSION = 1 @callback - def async_show_form(self, *, step_id, data_schema=None, errors=None, - description_placeholders=None): + def async_show_form(self, *, step_id: str, data_schema: vol.Schema = None, + errors: Dict = None, + description_placeholders: Dict = None) -> Dict: """Return the definition of a form to gather user input.""" return { 'type': RESULT_TYPE_FORM, @@ -147,7 +153,7 @@ class FlowHandler: } @callback - def async_create_entry(self, *, title, data): + def async_create_entry(self, *, title: str, data: Dict) -> Dict: """Finish config flow and create a config entry.""" return { 'version': self.VERSION, @@ -160,7 +166,7 @@ class FlowHandler: } @callback - def async_abort(self, *, reason): + def async_abort(self, *, reason: str) -> Dict: """Abort the config flow.""" return { 'type': RESULT_TYPE_ABORT, diff --git a/homeassistant/loader.py b/homeassistant/loader.py index 2b0f9ed18e45d72f66730e2443d6f847662ac4ad..c5cf99de234954988161fc69a0fecf51825af86e 100644 --- a/homeassistant/loader.py +++ b/homeassistant/loader.py @@ -17,7 +17,7 @@ import sys from types import ModuleType # pylint: disable=unused-import -from typing import Optional, Set, TYPE_CHECKING # NOQA +from typing import Optional, Set, TYPE_CHECKING, Callable, Any, TypeVar # NOQA from homeassistant.const import PLATFORM_FORMAT from homeassistant.util import OrderedSet @@ -27,6 +27,8 @@ from homeassistant.util import OrderedSet if TYPE_CHECKING: from homeassistant.core import HomeAssistant # NOQA +CALLABLE_T = TypeVar('CALLABLE_T', bound=Callable) + PREPARED = False DEPENDENCY_BLACKLIST = {'config'} @@ -51,7 +53,8 @@ def set_component(hass, # type: HomeAssistant cache[comp_name] = component -def get_platform(hass, domain: str, platform: str) -> Optional[ModuleType]: +def get_platform(hass, # type: HomeAssistant + domain: str, platform: str) -> Optional[ModuleType]: """Try to load specified platform. Async friendly. @@ -59,7 +62,8 @@ def get_platform(hass, domain: str, platform: str) -> Optional[ModuleType]: return get_component(hass, PLATFORM_FORMAT.format(domain, platform)) -def get_component(hass, comp_or_platform) -> Optional[ModuleType]: +def get_component(hass, # type: HomeAssistant + comp_or_platform: str) -> Optional[ModuleType]: """Try to load specified component. Looks in config dir first, then built-in components. @@ -73,6 +77,9 @@ def get_component(hass, comp_or_platform) -> Optional[ModuleType]: cache = hass.data.get(DATA_KEY) if cache is None: + if hass.config.config_dir is None: + _LOGGER.error("Can't load components - config dir is not set") + return None # Only insert if it's not there (happens during tests) if sys.path[0] != hass.config.config_dir: sys.path.insert(0, hass.config.config_dir) @@ -134,14 +141,38 @@ def get_component(hass, comp_or_platform) -> Optional[ModuleType]: return None +class ModuleWrapper: + """Class to wrap a Python module and auto fill in hass argument.""" + + def __init__(self, + hass, # type: HomeAssistant + module: ModuleType) -> None: + """Initialize the module wrapper.""" + self._hass = hass + self._module = module + + def __getattr__(self, attr: str) -> Any: + """Fetch an attribute.""" + value = getattr(self._module, attr) + + if hasattr(value, '__bind_hass'): + value = ft.partial(value, self._hass) + + setattr(self, attr, value) + return value + + class Components: """Helper to load components.""" - def __init__(self, hass): + def __init__( + self, + hass # type: HomeAssistant + ) -> None: """Initialize the Components class.""" self._hass = hass - def __getattr__(self, comp_name): + def __getattr__(self, comp_name: str) -> ModuleWrapper: """Fetch a component.""" component = get_component(self._hass, comp_name) if component is None: @@ -154,11 +185,14 @@ class Components: class Helpers: """Helper to load helpers.""" - def __init__(self, hass): + def __init__( + self, + hass # type: HomeAssistant + ) -> None: """Initialize the Helpers class.""" self._hass = hass - def __getattr__(self, helper_name): + def __getattr__(self, helper_name: str) -> ModuleWrapper: """Fetch a helper.""" helper = importlib.import_module( 'homeassistant.helpers.{}'.format(helper_name)) @@ -167,33 +201,14 @@ class Helpers: return wrapped -class ModuleWrapper: - """Class to wrap a Python module and auto fill in hass argument.""" - - def __init__(self, hass, module): - """Initialize the module wrapper.""" - self._hass = hass - self._module = module - - def __getattr__(self, attr): - """Fetch an attribute.""" - value = getattr(self._module, attr) - - if hasattr(value, '__bind_hass'): - value = ft.partial(value, self._hass) - - setattr(self, attr, value) - return value - - -def bind_hass(func): +def bind_hass(func: CALLABLE_T) -> CALLABLE_T: """Decorate function to indicate that first argument is hass.""" - # pylint: disable=protected-access - func.__bind_hass = True + setattr(func, '__bind_hass', True) return func -def load_order_component(hass, comp_name: str) -> OrderedSet: +def load_order_component(hass, # type: HomeAssistant + comp_name: str) -> OrderedSet: """Return an OrderedSet of components in the correct order of loading. Raises HomeAssistantError if a circular dependency is detected. @@ -204,7 +219,8 @@ def load_order_component(hass, comp_name: str) -> OrderedSet: return _load_order_component(hass, comp_name, OrderedSet(), set()) -def _load_order_component(hass, comp_name: str, load_order: OrderedSet, +def _load_order_component(hass, # type: HomeAssistant + comp_name: str, load_order: OrderedSet, loading: Set) -> OrderedSet: """Recursive function to get load order of components. diff --git a/homeassistant/monkey_patch.py b/homeassistant/monkey_patch.py index 17329fbddff4259439251664eeaa45069d2bf455..aa330ffec167c67b5a44f7846a5f9ea6d8eb12af 100644 --- a/homeassistant/monkey_patch.py +++ b/homeassistant/monkey_patch.py @@ -20,9 +20,10 @@ Related Python bugs: - https://bugs.python.org/issue26617 """ import sys +from typing import Any -def patch_weakref_tasks(): +def patch_weakref_tasks() -> None: """Replace weakref.WeakSet to address Python 3 bug.""" # pylint: disable=no-self-use, protected-access, bare-except import asyncio.tasks @@ -30,7 +31,7 @@ def patch_weakref_tasks(): class IgnoreCalls: """Ignore add calls.""" - def add(self, other): + def add(self, other: Any) -> None: """No-op add.""" return @@ -41,7 +42,7 @@ def patch_weakref_tasks(): pass -def disable_c_asyncio(): +def disable_c_asyncio() -> None: """Disable using C implementation of asyncio. Required to be able to apply the weakref monkey patch. @@ -53,12 +54,12 @@ def disable_c_asyncio(): PATH_TRIGGER = '_asyncio' - def __init__(self, path_entry): + def __init__(self, path_entry: str) -> None: if path_entry != self.PATH_TRIGGER: raise ImportError() return - def find_module(self, fullname, path=None): + def find_module(self, fullname: str, path: Any = None) -> None: """Find a module.""" if fullname == self.PATH_TRIGGER: # We lint in Py35, exception is introduced in Py36 diff --git a/homeassistant/remote.py b/homeassistant/remote.py index 7147fab108063c64aa927f0e82a16d58c4d3de52..313f98a890c74162a7aac557db44cf91daafabe8 100644 --- a/homeassistant/remote.py +++ b/homeassistant/remote.py @@ -13,7 +13,7 @@ import json import logging import urllib.parse -from typing import Optional +from typing import Optional, Dict, Any, List from aiohttp.hdrs import METH_GET, METH_POST, METH_DELETE, CONTENT_TYPE import requests @@ -62,7 +62,7 @@ class API: if port is not None: self.base_url += ':{}'.format(port) - self.status = None + self.status = None # type: Optional[APIStatus] self._headers = {CONTENT_TYPE: CONTENT_TYPE_JSON} if api_password is not None: @@ -75,20 +75,24 @@ class API: return self.status == APIStatus.OK - def __call__(self, method, path, data=None, timeout=5): + def __call__(self, method: str, path: str, data: Dict = None, + timeout: int = 5) -> requests.Response: """Make a call to the Home Assistant API.""" - if data is not None: - data = json.dumps(data, cls=JSONEncoder) + if data is None: + data_str = None + else: + data_str = json.dumps(data, cls=JSONEncoder) url = urllib.parse.urljoin(self.base_url, path) try: if method == METH_GET: return requests.get( - url, params=data, timeout=timeout, headers=self._headers) + url, params=data_str, timeout=timeout, + headers=self._headers) return requests.request( - method, url, data=data, timeout=timeout, + method, url, data=data_str, timeout=timeout, headers=self._headers) except requests.exceptions.ConnectionError: @@ -110,7 +114,7 @@ class JSONEncoder(json.JSONEncoder): """JSONEncoder that supports Home Assistant objects.""" # pylint: disable=method-hidden - def default(self, o): + def default(self, o: Any) -> Any: """Convert Home Assistant objects. Hand other objects to the original method. @@ -125,7 +129,7 @@ class JSONEncoder(json.JSONEncoder): return json.JSONEncoder.default(self, o) -def validate_api(api): +def validate_api(api: API) -> APIStatus: """Make a call to validate API.""" try: req = api(METH_GET, URL_API) @@ -142,12 +146,12 @@ def validate_api(api): return APIStatus.CANNOT_CONNECT -def get_event_listeners(api): +def get_event_listeners(api: API) -> Dict: """List of events that is being listened for.""" try: req = api(METH_GET, URL_API_EVENTS) - return req.json() if req.status_code == 200 else {} + return req.json() if req.status_code == 200 else {} # type: ignore except (HomeAssistantError, ValueError): # ValueError if req.json() can't parse the json @@ -156,7 +160,7 @@ def get_event_listeners(api): return {} -def fire_event(api, event_type, data=None): +def fire_event(api: API, event_type: str, data: Dict = None) -> None: """Fire an event at remote API.""" try: req = api(METH_POST, URL_API_EVENTS_EVENT.format(event_type), data) @@ -169,7 +173,7 @@ def fire_event(api, event_type, data=None): _LOGGER.exception("Error firing event") -def get_state(api, entity_id): +def get_state(api: API, entity_id: str) -> Optional[ha.State]: """Query given API for state of entity_id.""" try: req = api(METH_GET, URL_API_STATES_ENTITY.format(entity_id)) @@ -186,7 +190,7 @@ def get_state(api, entity_id): return None -def get_states(api): +def get_states(api: API) -> List[ha.State]: """Query given API for all states.""" try: req = api(METH_GET, @@ -202,7 +206,7 @@ def get_states(api): return [] -def remove_state(api, entity_id): +def remove_state(api: API, entity_id: str) -> bool: """Call API to remove state for entity_id. Return True if entity is gone (removed/never existed). @@ -222,7 +226,8 @@ def remove_state(api, entity_id): return False -def set_state(api, entity_id, new_state, attributes=None, force_update=False): +def set_state(api: API, entity_id: str, new_state: str, + attributes: Dict = None, force_update: bool = False) -> bool: """Tell API to update state for entity_id. Return True if success. @@ -249,14 +254,14 @@ def set_state(api, entity_id, new_state, attributes=None, force_update=False): return False -def is_state(api, entity_id, state): +def is_state(api: API, entity_id: str, state: str) -> bool: """Query API to see if entity_id is specified state.""" cur_state = get_state(api, entity_id) - return cur_state and cur_state.state == state + return bool(cur_state and cur_state.state == state) -def get_services(api): +def get_services(api: API) -> Dict: """Return a list of dicts. Each dict has a string "domain" and a list of strings "services". @@ -264,7 +269,7 @@ def get_services(api): try: req = api(METH_GET, URL_API_SERVICES) - return req.json() if req.status_code == 200 else {} + return req.json() if req.status_code == 200 else {} # type: ignore except (HomeAssistantError, ValueError): # ValueError if req.json() can't parse the json @@ -273,7 +278,9 @@ def get_services(api): return {} -def call_service(api, domain, service, service_data=None, timeout=5): +def call_service(api: API, domain: str, service: str, + service_data: Dict = None, + timeout: int = 5) -> None: """Call a service at the remote API.""" try: req = api(METH_POST, @@ -288,7 +295,7 @@ def call_service(api, domain, service, service_data=None, timeout=5): _LOGGER.exception("Error calling service") -def get_config(api): +def get_config(api: API) -> Dict: """Return configuration.""" try: req = api(METH_GET, URL_API_CONFIG) @@ -299,7 +306,7 @@ def get_config(api): result = req.json() if 'components' in result: result['components'] = set(result['components']) - return result + return result # type: ignore except (HomeAssistantError, ValueError): # ValueError if req.json() can't parse the JSON diff --git a/homeassistant/requirements.py b/homeassistant/requirements.py index 753947a2c12c69d073a4a3b58f12509e387f13ef..b73ec4e184e87d42ec60513695a94875cad39c1f 100644 --- a/homeassistant/requirements.py +++ b/homeassistant/requirements.py @@ -3,15 +3,18 @@ import asyncio from functools import partial import logging import os +from typing import List, Dict, Optional import homeassistant.util.package as pkg_util +from homeassistant.core import HomeAssistant DATA_PIP_LOCK = 'pip_lock' CONSTRAINT_FILE = 'package_constraints.txt' _LOGGER = logging.getLogger(__name__) -async def async_process_requirements(hass, name, requirements): +async def async_process_requirements(hass: HomeAssistant, name: str, + requirements: List[str]) -> bool: """Install the requirements for a component or platform. This method is a coroutine. @@ -25,7 +28,7 @@ async def async_process_requirements(hass, name, requirements): async with pip_lock: for req in requirements: - ret = await hass.async_add_job(pip_install, req) + ret = await hass.async_add_executor_job(pip_install, req) if not ret: _LOGGER.error("Not initializing %s because could not install " "requirement %s", name, req) @@ -34,11 +37,11 @@ async def async_process_requirements(hass, name, requirements): return True -def pip_kwargs(config_dir): +def pip_kwargs(config_dir: Optional[str]) -> Dict[str, str]: """Return keyword arguments for PIP install.""" kwargs = { 'constraints': os.path.join(os.path.dirname(__file__), CONSTRAINT_FILE) } - if not pkg_util.is_virtual_env(): + if not (config_dir is None or pkg_util.is_virtual_env()): kwargs['target'] = os.path.join(config_dir, 'deps') return kwargs diff --git a/homeassistant/setup.py b/homeassistant/setup.py index 0641a461130b6bfeb8d920110618561ebcd0490d..31404b978eb4069c1636146490a5abbb50b1b561 100644 --- a/homeassistant/setup.py +++ b/homeassistant/setup.py @@ -4,7 +4,7 @@ import logging.handlers from timeit import default_timer as timer from types import ModuleType -from typing import Optional, Dict +from typing import Optional, Dict, List from homeassistant import requirements, core, loader, config as conf_util from homeassistant.config import async_notify_setup_error @@ -56,7 +56,9 @@ async def async_setup_component(hass: core.HomeAssistant, domain: str, return await task # type: ignore -async def _async_process_dependencies(hass, config, name, dependencies): +async def _async_process_dependencies( + hass: core.HomeAssistant, config: Dict, name: str, + dependencies: List[str]) -> bool: """Ensure all dependencies are set up.""" blacklisted = [dep for dep in dependencies if dep in loader.DEPENDENCY_BLACKLIST] @@ -88,12 +90,12 @@ async def _async_process_dependencies(hass, config, name, dependencies): async def _async_setup_component(hass: core.HomeAssistant, - domain: str, config) -> bool: + domain: str, config: Dict) -> bool: """Set up a component for Home Assistant. This method is a coroutine. """ - def log_error(msg, link=True): + def log_error(msg: str, link: bool = True) -> None: """Log helper.""" _LOGGER.error("Setup failed for %s: %s", domain, msg) async_notify_setup_error(hass, domain, link) @@ -181,7 +183,7 @@ async def _async_setup_component(hass: core.HomeAssistant, return True -async def async_prepare_setup_platform(hass: core.HomeAssistant, config, +async def async_prepare_setup_platform(hass: core.HomeAssistant, config: Dict, domain: str, platform_name: str) \ -> Optional[ModuleType]: """Load a platform and makes sure dependencies are setup. @@ -190,7 +192,7 @@ async def async_prepare_setup_platform(hass: core.HomeAssistant, config, """ platform_path = PLATFORM_FORMAT.format(domain, platform_name) - def log_error(msg): + def log_error(msg: str) -> None: """Log helper.""" _LOGGER.error("Unable to prepare setup for platform %s: %s", platform_path, msg) @@ -217,7 +219,9 @@ async def async_prepare_setup_platform(hass: core.HomeAssistant, config, return platform -async def async_process_deps_reqs(hass, config, name, module): +async def async_process_deps_reqs( + hass: core.HomeAssistant, config: Dict, name: str, + module: ModuleType) -> None: """Process all dependencies and requirements for a module. Module is a Python module of either a component or platform. @@ -231,14 +235,14 @@ async def async_process_deps_reqs(hass, config, name, module): if hasattr(module, 'DEPENDENCIES'): dep_success = await _async_process_dependencies( - hass, config, name, module.DEPENDENCIES) + hass, config, name, module.DEPENDENCIES) # type: ignore if not dep_success: raise HomeAssistantError("Could not setup all dependencies.") if not hass.config.skip_pip and hasattr(module, 'REQUIREMENTS'): req_success = await requirements.async_process_requirements( - hass, name, module.REQUIREMENTS) + hass, name, module.REQUIREMENTS) # type: ignore if not req_success: raise HomeAssistantError("Could not install all requirements.") diff --git a/homeassistant/util/__init__.py b/homeassistant/util/__init__.py index 6b539e991865d2e61cddb8c7a19403424e28d6e8..37f669944d97c9ba8b0aaae71c3bfd512eecb025 100644 --- a/homeassistant/util/__init__.py +++ b/homeassistant/util/__init__.py @@ -1,9 +1,8 @@ """Helper methods for various modules.""" import asyncio -from collections.abc import MutableSet +from datetime import datetime, timedelta from itertools import chain import threading -from datetime import datetime import re import enum import socket @@ -14,12 +13,13 @@ from types import MappingProxyType from unicodedata import normalize from typing import (Any, Optional, TypeVar, Callable, KeysView, Union, # noqa - Iterable, List, Mapping) + Iterable, List, Dict, Iterator, Coroutine, MutableSet) from .dt import as_local, utcnow T = TypeVar('T') U = TypeVar('U') +ENUM_T = TypeVar('ENUM_T', bound=enum.Enum) RE_SANITIZE_FILENAME = re.compile(r'(~|\.\.|/|\\)') RE_SANITIZE_PATH = re.compile(r'(~|\.(\.)+)') @@ -91,7 +91,7 @@ def ensure_unique_string(preferred_string: str, current_strings: # Taken from: http://stackoverflow.com/a/11735897 -def get_local_ip(): +def get_local_ip() -> str: """Try to determine the local IP address of the machine.""" try: sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -99,7 +99,7 @@ def get_local_ip(): # Use Google Public DNS server to determine own IP sock.connect(('8.8.8.8', 80)) - return sock.getsockname()[0] + return sock.getsockname()[0] # type: ignore except socket.error: try: return socket.gethostbyname(socket.gethostname()) @@ -110,7 +110,7 @@ def get_local_ip(): # Taken from http://stackoverflow.com/a/23728630 -def get_random_string(length=10): +def get_random_string(length: int = 10) -> str: """Return a random string with letters and digits.""" generator = random.SystemRandom() source_chars = string.ascii_letters + string.digits @@ -121,59 +121,59 @@ def get_random_string(length=10): class OrderedEnum(enum.Enum): """Taken from Python 3.4.0 docs.""" - def __ge__(self, other): + def __ge__(self: ENUM_T, other: ENUM_T) -> bool: """Return the greater than element.""" if self.__class__ is other.__class__: - return self.value >= other.value + return bool(self.value >= other.value) return NotImplemented - def __gt__(self, other): + def __gt__(self: ENUM_T, other: ENUM_T) -> bool: """Return the greater element.""" if self.__class__ is other.__class__: - return self.value > other.value + return bool(self.value > other.value) return NotImplemented - def __le__(self, other): + def __le__(self: ENUM_T, other: ENUM_T) -> bool: """Return the lower than element.""" if self.__class__ is other.__class__: - return self.value <= other.value + return bool(self.value <= other.value) return NotImplemented - def __lt__(self, other): + def __lt__(self: ENUM_T, other: ENUM_T) -> bool: """Return the lower element.""" if self.__class__ is other.__class__: - return self.value < other.value + return bool(self.value < other.value) return NotImplemented -class OrderedSet(MutableSet): +class OrderedSet(MutableSet[T]): """Ordered set taken from http://code.activestate.com/recipes/576694/.""" - def __init__(self, iterable=None): + def __init__(self, iterable: Iterable[T] = None) -> None: """Initialize the set.""" self.end = end = [] # type: List[Any] end += [None, end, end] # sentinel node for doubly linked list - self.map = {} # type: Mapping[List, Any] # key --> [key, prev, next] + self.map = {} # type: Dict[T, List] # key --> [key, prev, next] if iterable is not None: - self |= iterable + self |= iterable # type: ignore - def __len__(self): + def __len__(self) -> int: """Return the length of the set.""" return len(self.map) - def __contains__(self, key): + def __contains__(self, key: T) -> bool: # type: ignore """Check if key is in set.""" return key in self.map # pylint: disable=arguments-differ - def add(self, key): + def add(self, key: T) -> None: """Add an element to the end of the set.""" if key not in self.map: end = self.end curr = end[1] curr[2] = end[1] = self.map[key] = [key, curr, end] - def promote(self, key): + def promote(self, key: T) -> None: """Promote element to beginning of the set, add if not there.""" if key in self.map: self.discard(key) @@ -183,14 +183,14 @@ class OrderedSet(MutableSet): curr[2] = begin[1] = self.map[key] = [key, curr, begin] # pylint: disable=arguments-differ - def discard(self, key): + def discard(self, key: T) -> None: """Discard an element from the set.""" if key in self.map: key, prev_item, next_item = self.map.pop(key) prev_item[2] = next_item next_item[1] = prev_item - def __iter__(self): + def __iter__(self) -> Iterator[T]: """Iterate of the set.""" end = self.end curr = end[2] @@ -198,7 +198,7 @@ class OrderedSet(MutableSet): yield curr[0] curr = curr[2] - def __reversed__(self): + def __reversed__(self) -> Iterator[T]: """Reverse the ordering.""" end = self.end curr = end[1] @@ -207,7 +207,7 @@ class OrderedSet(MutableSet): curr = curr[1] # pylint: disable=arguments-differ - def pop(self, last=True): + def pop(self, last: bool = True) -> T: """Pop element of the end of the set. Set last=False to pop from the beginning. @@ -216,20 +216,20 @@ class OrderedSet(MutableSet): raise KeyError('set is empty') key = self.end[1][0] if last else self.end[2][0] self.discard(key) - return key + return key # type: ignore - def update(self, *args): + def update(self, *args: Any) -> None: """Add elements from args to the set.""" for item in chain(*args): self.add(item) - def __repr__(self): + def __repr__(self) -> str: """Return the representation.""" if not self: return '%s()' % (self.__class__.__name__,) return '%s(%r)' % (self.__class__.__name__, list(self)) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: """Return the comparison.""" if isinstance(other, OrderedSet): return len(self) == len(other) and list(self) == list(other) @@ -254,20 +254,21 @@ class Throttle: Adds a datetime attribute `last_call` to the method. """ - def __init__(self, min_time, limit_no_throttle=None): + def __init__(self, min_time: timedelta, + limit_no_throttle: timedelta = None) -> None: """Initialize the throttle.""" self.min_time = min_time self.limit_no_throttle = limit_no_throttle - def __call__(self, method): + def __call__(self, method: Callable) -> Callable: """Caller for the throttle.""" # Make sure we return a coroutine if the method is async. if asyncio.iscoroutinefunction(method): - async def throttled_value(): + async def throttled_value() -> None: """Stand-in function for when real func is being throttled.""" return None else: - def throttled_value(): + def throttled_value() -> None: # type: ignore """Stand-in function for when real func is being throttled.""" return None @@ -288,14 +289,14 @@ class Throttle: '.' not in method.__qualname__.split('.<locals>.')[-1]) @wraps(method) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Union[Callable, Coroutine]: """Wrap that allows wrapped to be called only once per min_time. If we cannot acquire the lock, it is running so return None. """ # pylint: disable=protected-access if hasattr(method, '__self__'): - host = method.__self__ + host = getattr(method, '__self__') elif is_func: host = wrapper else: @@ -318,7 +319,7 @@ class Throttle: if force or utcnow() - throttle[1] > self.min_time: result = method(*args, **kwargs) throttle[1] = utcnow() - return result + return result # type: ignore return throttled_value() finally: diff --git a/homeassistant/util/async_.py b/homeassistant/util/async_.py index 334b4f4548341798f568e610abf7dd6e1ab7a456..1e2eb25245a9d8bc0914166dad14154b309574e5 100644 --- a/homeassistant/util/async_.py +++ b/homeassistant/util/async_.py @@ -3,22 +3,25 @@ import concurrent.futures import threading import logging from asyncio import coroutines +from asyncio.events import AbstractEventLoop from asyncio.futures import Future from asyncio import ensure_future - +from typing import Any, Union, Coroutine, Callable, Generator _LOGGER = logging.getLogger(__name__) -def _set_result_unless_cancelled(fut, result): +def _set_result_unless_cancelled(fut: Future, result: Any) -> None: """Set the result only if the Future was not cancelled.""" if fut.cancelled(): return fut.set_result(result) -def _set_concurrent_future_state(concurr, source): +def _set_concurrent_future_state( + concurr: concurrent.futures.Future, + source: Union[concurrent.futures.Future, Future]) -> None: """Copy state from a future to a concurrent.futures.Future.""" assert source.done() if source.cancelled(): @@ -33,7 +36,8 @@ def _set_concurrent_future_state(concurr, source): concurr.set_result(result) -def _copy_future_state(source, dest): +def _copy_future_state(source: Union[concurrent.futures.Future, Future], + dest: Union[concurrent.futures.Future, Future]) -> None: """Copy state from another Future. The other Future may be a concurrent.futures.Future. @@ -53,7 +57,9 @@ def _copy_future_state(source, dest): dest.set_result(result) -def _chain_future(source, destination): +def _chain_future( + source: Union[concurrent.futures.Future, Future], + destination: Union[concurrent.futures.Future, Future]) -> None: """Chain two futures so that when one completes, so does the other. The result (or exception) of source will be copied to destination. @@ -74,20 +80,23 @@ def _chain_future(source, destination): else: dest_loop = None - def _set_state(future, other): + def _set_state(future: Union[concurrent.futures.Future, Future], + other: Union[concurrent.futures.Future, Future]) -> None: if isinstance(future, Future): _copy_future_state(other, future) else: _set_concurrent_future_state(future, other) - def _call_check_cancel(destination): + def _call_check_cancel( + destination: Union[concurrent.futures.Future, Future]) -> None: if destination.cancelled(): if source_loop is None or source_loop is dest_loop: source.cancel() else: source_loop.call_soon_threadsafe(source.cancel) - def _call_set_state(source): + def _call_set_state( + source: Union[concurrent.futures.Future, Future]) -> None: if dest_loop is None or dest_loop is source_loop: _set_state(destination, source) else: @@ -97,7 +106,9 @@ def _chain_future(source, destination): source.add_done_callback(_call_set_state) -def run_coroutine_threadsafe(coro, loop): +def run_coroutine_threadsafe( + coro: Union[Coroutine, Generator], + loop: AbstractEventLoop) -> concurrent.futures.Future: """Submit a coroutine object to a given event loop. Return a concurrent.futures.Future to access the result. @@ -110,7 +121,7 @@ def run_coroutine_threadsafe(coro, loop): raise TypeError('A coroutine object is required') future = concurrent.futures.Future() # type: concurrent.futures.Future - def callback(): + def callback() -> None: """Handle the call to the coroutine.""" try: _chain_future(ensure_future(coro, loop=loop), future) @@ -125,7 +136,8 @@ def run_coroutine_threadsafe(coro, loop): return future -def fire_coroutine_threadsafe(coro, loop): +def fire_coroutine_threadsafe(coro: Coroutine, + loop: AbstractEventLoop) -> None: """Submit a coroutine object to a given event loop. This method does not provide a way to retrieve the result and @@ -139,7 +151,7 @@ def fire_coroutine_threadsafe(coro, loop): if not coroutines.iscoroutine(coro): raise TypeError('A coroutine object is required: %s' % coro) - def callback(): + def callback() -> None: """Handle the firing of a coroutine.""" ensure_future(coro, loop=loop) @@ -147,7 +159,8 @@ def fire_coroutine_threadsafe(coro, loop): return -def run_callback_threadsafe(loop, callback, *args): +def run_callback_threadsafe(loop: AbstractEventLoop, callback: Callable, + *args: Any) -> concurrent.futures.Future: """Submit a callback object to a given event loop. Return a concurrent.futures.Future to access the result. @@ -158,7 +171,7 @@ def run_callback_threadsafe(loop, callback, *args): future = concurrent.futures.Future() # type: concurrent.futures.Future - def run_callback(): + def run_callback() -> None: """Run callback and store result.""" try: future.set_result(callback(*args)) diff --git a/homeassistant/util/color.py b/homeassistant/util/color.py index a26f7014444cfdc776726d9a17af69f7da35e244..0538bfbf369ffdd1766e98ce19a31ac7a9562e8a 100644 --- a/homeassistant/util/color.py +++ b/homeassistant/util/color.py @@ -2,7 +2,7 @@ import math import colorsys -from typing import Tuple +from typing import Tuple, List # Official CSS3 colors from w3.org: # https://www.w3.org/TR/2010/PR-css3-color-20101028/#html4 @@ -162,7 +162,7 @@ COLORS = { } -def color_name_to_rgb(color_name): +def color_name_to_rgb(color_name: str) -> Tuple[int, int, int]: """Convert color name to RGB hex value.""" # COLORS map has no spaces in it, so make the color_name have no # spaces in it as well for matching purposes @@ -305,7 +305,8 @@ def color_hsb_to_RGB(fH: float, fS: float, fB: float) -> Tuple[int, int, int]: return (r, g, b) -def color_RGB_to_hsv(iR: int, iG: int, iB: int) -> Tuple[float, float, float]: +def color_RGB_to_hsv( + iR: float, iG: float, iB: float) -> Tuple[float, float, float]: """Convert an rgb color to its hsv representation. Hue is scaled 0-360 @@ -316,7 +317,7 @@ def color_RGB_to_hsv(iR: int, iG: int, iB: int) -> Tuple[float, float, float]: return round(fHSV[0]*360, 3), round(fHSV[1]*100, 3), round(fHSV[2]*100, 3) -def color_RGB_to_hs(iR: int, iG: int, iB: int) -> Tuple[float, float]: +def color_RGB_to_hs(iR: float, iG: float, iB: float) -> Tuple[float, float]: """Convert an rgb color to its hs representation.""" return color_RGB_to_hsv(iR, iG, iB)[:2] @@ -340,7 +341,7 @@ def color_hs_to_RGB(iH: float, iS: float) -> Tuple[int, int, int]: def color_xy_to_hs(vX: float, vY: float) -> Tuple[float, float]: """Convert an xy color to its hs representation.""" h, s, _ = color_RGB_to_hsv(*color_xy_to_RGB(vX, vY)) - return (h, s) + return h, s def color_hs_to_xy(iH: float, iS: float) -> Tuple[float, float]: @@ -348,8 +349,7 @@ def color_hs_to_xy(iH: float, iS: float) -> Tuple[float, float]: return color_RGB_to_xy(*color_hs_to_RGB(iH, iS)) -def _match_max_scale(input_colors: Tuple[int, ...], - output_colors: Tuple[int, ...]) -> Tuple[int, ...]: +def _match_max_scale(input_colors: Tuple, output_colors: Tuple) -> Tuple: """Match the maximum value of the output to the input.""" max_in = max(input_colors) max_out = max(output_colors) @@ -360,7 +360,7 @@ def _match_max_scale(input_colors: Tuple[int, ...], return tuple(int(round(i * factor)) for i in output_colors) -def color_rgb_to_rgbw(r, g, b): +def color_rgb_to_rgbw(r: int, g: int, b: int) -> Tuple[int, int, int, int]: """Convert an rgb color to an rgbw representation.""" # Calculate the white channel as the minimum of input rgb channels. # Subtract the white portion from the remaining rgb channels. @@ -369,25 +369,25 @@ def color_rgb_to_rgbw(r, g, b): # Match the output maximum value to the input. This ensures the full # channel range is used. - return _match_max_scale((r, g, b), rgbw) + return _match_max_scale((r, g, b), rgbw) # type: ignore -def color_rgbw_to_rgb(r, g, b, w): +def color_rgbw_to_rgb(r: int, g: int, b: int, w: int) -> Tuple[int, int, int]: """Convert an rgbw color to an rgb representation.""" # Add the white channel back into the rgb channels. rgb = (r + w, g + w, b + w) # Match the output maximum value to the input. This ensures the # output doesn't overflow. - return _match_max_scale((r, g, b, w), rgb) + return _match_max_scale((r, g, b, w), rgb) # type: ignore -def color_rgb_to_hex(r, g, b): +def color_rgb_to_hex(r: int, g: int, b: int) -> str: """Return a RGB color from a hex color string.""" return '{0:02x}{1:02x}{2:02x}'.format(round(r), round(g), round(b)) -def rgb_hex_to_rgb_list(hex_string): +def rgb_hex_to_rgb_list(hex_string: str) -> List[int]: """Return an RGB color value list from a hex color string.""" return [int(hex_string[i:i + len(hex_string) // 3], 16) for i in range(0, @@ -395,12 +395,14 @@ def rgb_hex_to_rgb_list(hex_string): len(hex_string) // 3)] -def color_temperature_to_hs(color_temperature_kelvin): +def color_temperature_to_hs( + color_temperature_kelvin: float) -> Tuple[float, float]: """Return an hs color from a color temperature in Kelvin.""" return color_RGB_to_hs(*color_temperature_to_rgb(color_temperature_kelvin)) -def color_temperature_to_rgb(color_temperature_kelvin): +def color_temperature_to_rgb( + color_temperature_kelvin: float) -> Tuple[float, float, float]: """ Return an RGB color from a color temperature in Kelvin. @@ -421,7 +423,7 @@ def color_temperature_to_rgb(color_temperature_kelvin): blue = _get_blue(tmp_internal) - return (red, green, blue) + return red, green, blue def _bound(color_component: float, minimum: float = 0, @@ -464,11 +466,11 @@ def _get_blue(temperature: float) -> float: return _bound(blue) -def color_temperature_mired_to_kelvin(mired_temperature): +def color_temperature_mired_to_kelvin(mired_temperature: float) -> float: """Convert absolute mired shift to degrees kelvin.""" return math.floor(1000000 / mired_temperature) -def color_temperature_kelvin_to_mired(kelvin_temperature): +def color_temperature_kelvin_to_mired(kelvin_temperature: float) -> float: """Convert degrees kelvin to mired shift.""" return math.floor(1000000 / kelvin_temperature) diff --git a/homeassistant/util/decorator.py b/homeassistant/util/decorator.py index c26606d52cffa795afc99264aade06f0a60c3a48..9d2a4600a64ce46730cd615d67bfdc41429747e2 100644 --- a/homeassistant/util/decorator.py +++ b/homeassistant/util/decorator.py @@ -1,12 +1,14 @@ """Decorator utility functions.""" +from typing import Callable, TypeVar +CALLABLE_T = TypeVar('CALLABLE_T', bound=Callable) class Registry(dict): """Registry of items.""" - def register(self, name): + def register(self, name: str) -> Callable[[CALLABLE_T], CALLABLE_T]: """Return decorator to register item with a specific name.""" - def decorator(func): + def decorator(func: CALLABLE_T) -> CALLABLE_T: """Register decorated function.""" self[name] = func return func diff --git a/homeassistant/util/dt.py b/homeassistant/util/dt.py index bae38f27ee25186f53e061f315d94b1e604e0c8b..06159a944a26b4fa61086944968be9b512b0d8ee 100644 --- a/homeassistant/util/dt.py +++ b/homeassistant/util/dt.py @@ -71,14 +71,14 @@ def as_utc(dattim: dt.datetime) -> dt.datetime: return dattim.astimezone(UTC) -def as_timestamp(dt_value): +def as_timestamp(dt_value: dt.datetime) -> float: """Convert a date/time into a unix time (seconds since 1970).""" if hasattr(dt_value, "timestamp"): - parsed_dt = dt_value + parsed_dt = dt_value # type: Optional[dt.datetime] else: parsed_dt = parse_datetime(str(dt_value)) - if not parsed_dt: - raise ValueError("not a valid date/time.") + if parsed_dt is None: + raise ValueError("not a valid date/time.") return parsed_dt.timestamp() @@ -150,7 +150,7 @@ def parse_date(dt_str: str) -> Optional[dt.date]: return None -def parse_time(time_str): +def parse_time(time_str: str) -> Optional[dt.time]: """Parse a time string (00:20:00) into Time object. Return None if invalid. diff --git a/homeassistant/util/json.py b/homeassistant/util/json.py index 1029e58c1186c7cc999bdc9c0a4c511dadf9f759..8ecfebd5b33162e16366acfddd8f710d78449646 100644 --- a/homeassistant/util/json.py +++ b/homeassistant/util/json.py @@ -38,7 +38,7 @@ def load_json(filename: str, default: Union[List, Dict, None] = None) \ return {} if default is None else default -def save_json(filename: str, data: Union[List, Dict]): +def save_json(filename: str, data: Union[List, Dict]) -> None: """Save JSON data to a file. Returns True on success. diff --git a/homeassistant/util/location.py b/homeassistant/util/location.py index 9fc87b24a9b70e0db83d768df04fa3751b6a1f2b..16aec2ec6172ebcd629b93a3dc16123f7d006eb4 100644 --- a/homeassistant/util/location.py +++ b/homeassistant/util/location.py @@ -33,7 +33,7 @@ LocationInfo = collections.namedtuple( 'use_metric']) -def detect_location_info(): +def detect_location_info() -> Optional[LocationInfo]: """Detect location information.""" data = _get_freegeoip() @@ -63,7 +63,7 @@ def distance(lat1: Optional[float], lon1: Optional[float], return result * 1000 -def elevation(latitude, longitude): +def elevation(latitude: float, longitude: float) -> int: """Return elevation for given latitude and longitude.""" try: req = requests.get( diff --git a/homeassistant/util/logging.py b/homeassistant/util/logging.py index 7ce98fc2f2a2c4801e250de217188a7725e47159..f2bf15d8a03b4d68adff8792989667e96abbcc09 100644 --- a/homeassistant/util/logging.py +++ b/homeassistant/util/logging.py @@ -1,7 +1,9 @@ """Logging utilities.""" import asyncio +from asyncio.events import AbstractEventLoop import logging import threading +from typing import Optional from .async_ import run_coroutine_threadsafe @@ -9,12 +11,12 @@ from .async_ import run_coroutine_threadsafe class HideSensitiveDataFilter(logging.Filter): """Filter API password calls.""" - def __init__(self, text): + def __init__(self, text: str) -> None: """Initialize sensitive data filter.""" super().__init__() self.text = text - def filter(self, record): + def filter(self, record: logging.LogRecord) -> bool: """Hide sensitive data in messages.""" record.msg = record.msg.replace(self.text, '*******') @@ -25,7 +27,8 @@ class HideSensitiveDataFilter(logging.Filter): class AsyncHandler: """Logging handler wrapper to add an async layer.""" - def __init__(self, loop, handler): + def __init__( + self, loop: AbstractEventLoop, handler: logging.Handler) -> None: """Initialize async logging handler wrapper.""" self.handler = handler self.loop = loop @@ -45,11 +48,11 @@ class AsyncHandler: self._thread.start() - def close(self): + def close(self) -> None: """Wrap close to handler.""" self.emit(None) - async def async_close(self, blocking=False): + async def async_close(self, blocking: bool = False) -> None: """Close the handler. When blocking=True, will wait till closed. @@ -60,7 +63,7 @@ class AsyncHandler: while self._thread.is_alive(): await asyncio.sleep(0, loop=self.loop) - def emit(self, record): + def emit(self, record: Optional[logging.LogRecord]) -> None: """Process a record.""" ident = self.loop.__dict__.get("_thread_ident") @@ -71,11 +74,11 @@ class AsyncHandler: else: self.loop.call_soon_threadsafe(self._queue.put_nowait, record) - def __repr__(self): + def __repr__(self) -> str: """Return the string names.""" return str(self.handler) - def _process(self): + def _process(self) -> None: """Process log in a thread.""" while True: record = run_coroutine_threadsafe( @@ -87,34 +90,34 @@ class AsyncHandler: self.handler.emit(record) - def createLock(self): + def createLock(self) -> None: """Ignore lock stuff.""" pass - def acquire(self): + def acquire(self) -> None: """Ignore lock stuff.""" pass - def release(self): + def release(self) -> None: """Ignore lock stuff.""" pass @property - def level(self): + def level(self) -> int: """Wrap property level to handler.""" return self.handler.level @property - def formatter(self): + def formatter(self) -> Optional[logging.Formatter]: """Wrap property formatter to handler.""" return self.handler.formatter @property - def name(self): + def name(self) -> str: """Wrap property set_name to handler.""" - return self.handler.get_name() + return self.handler.get_name() # type: ignore @name.setter - def name(self, name): + def name(self, name: str) -> None: """Wrap property get_name to handler.""" - self.handler.name = name + self.handler.set_name(name) # type: ignore diff --git a/homeassistant/util/package.py b/homeassistant/util/package.py index d1d398020dee1758601477898bc4756301a29f2b..9433046e6881beed5453900ef7616f5342bed2e7 100644 --- a/homeassistant/util/package.py +++ b/homeassistant/util/package.py @@ -16,7 +16,7 @@ _LOGGER = logging.getLogger(__name__) INSTALL_LOCK = threading.Lock() -def is_virtual_env(): +def is_virtual_env() -> bool: """Return if we run in a virtual environtment.""" # Check supports venv && virtualenv return (getattr(sys, 'base_prefix', sys.prefix) != sys.prefix or diff --git a/homeassistant/util/ssl.py b/homeassistant/util/ssl.py index 4f528cfcb51644ae821cca808271ce821e01aa85..392c5986c8914bf3c0c4b775edd321627bacde95 100644 --- a/homeassistant/util/ssl.py +++ b/homeassistant/util/ssl.py @@ -4,7 +4,7 @@ import ssl import certifi -def client_context(): +def client_context() -> ssl.SSLContext: """Return an SSL context for making requests.""" context = ssl.create_default_context( purpose=ssl.Purpose.SERVER_AUTH, @@ -13,7 +13,7 @@ def client_context(): return context -def server_context(): +def server_context() -> ssl.SSLContext: """Return an SSL context following the Mozilla recommendations. TLS configuration follows the best-practice guidelines specified here: diff --git a/homeassistant/util/yaml.py b/homeassistant/util/yaml.py index 7ce16600e1b61300bbb5c30ba07cec0d568cb74a..40ddfdf7b966a5aad001f6b0973264baaaa05626 100644 --- a/homeassistant/util/yaml.py +++ b/homeassistant/util/yaml.py @@ -4,7 +4,7 @@ import os import sys import fnmatch from collections import OrderedDict -from typing import Union, List, Dict +from typing import Union, List, Dict, Iterator, overload, TypeVar import yaml try: @@ -22,7 +22,10 @@ from homeassistant.exceptions import HomeAssistantError _LOGGER = logging.getLogger(__name__) _SECRET_NAMESPACE = 'homeassistant' SECRET_YAML = 'secrets.yaml' -__SECRET_CACHE = {} # type: Dict +__SECRET_CACHE = {} # type: Dict[str, JSON_TYPE] + +JSON_TYPE = Union[List, Dict, str] +DICT_T = TypeVar('DICT_T', bound=Dict) class NodeListClass(list): @@ -37,22 +40,12 @@ class NodeStrClass(str): pass -def _add_reference(obj, loader, node): - """Add file reference information to an object.""" - if isinstance(obj, list): - obj = NodeListClass(obj) - if isinstance(obj, str): - obj = NodeStrClass(obj) - setattr(obj, '__config_file__', loader.name) - setattr(obj, '__line__', node.start_mark.line) - return obj - - # pylint: disable=too-many-ancestors class SafeLineLoader(yaml.SafeLoader): """Loader class that keeps track of line numbers.""" - def compose_node(self, parent: yaml.nodes.Node, index) -> yaml.nodes.Node: + def compose_node(self, parent: yaml.nodes.Node, + index: int) -> yaml.nodes.Node: """Annotate a node with the first line it was seen.""" last_line = self.line # type: int node = super(SafeLineLoader, @@ -61,7 +54,39 @@ class SafeLineLoader(yaml.SafeLoader): return node -def load_yaml(fname: str) -> Union[List, Dict]: +# pylint: disable=pointless-statement +@overload +def _add_reference(obj: Union[list, NodeListClass], + loader: yaml.SafeLoader, + node: yaml.nodes.Node) -> NodeListClass: ... + + +@overload # noqa: F811 +def _add_reference(obj: Union[str, NodeStrClass], + loader: yaml.SafeLoader, + node: yaml.nodes.Node) -> NodeStrClass: ... + + +@overload # noqa: F811 +def _add_reference(obj: DICT_T, + loader: yaml.SafeLoader, + node: yaml.nodes.Node) -> DICT_T: ... +# pylint: enable=pointless-statement + + +def _add_reference(obj, loader: SafeLineLoader, # type: ignore # noqa: F811 + node: yaml.nodes.Node): + """Add file reference information to an object.""" + if isinstance(obj, list): + obj = NodeListClass(obj) + if isinstance(obj, str): + obj = NodeStrClass(obj) + setattr(obj, '__config_file__', loader.name) + setattr(obj, '__line__', node.start_mark.line) + return obj + + +def load_yaml(fname: str) -> JSON_TYPE: """Load a YAML file.""" try: with open(fname, encoding='utf-8') as conf_file: @@ -83,12 +108,12 @@ def dump(_dict: dict) -> str: .replace(': null\n', ':\n') -def save_yaml(path, data): +def save_yaml(path: str, data: dict) -> None: """Save YAML to a file.""" # Dump before writing to not truncate the file if dumping fails - data = dump(data) + str_data = dump(data) with open(path, 'w', encoding='utf-8') as outfile: - outfile.write(data) + outfile.write(str_data) def clear_secret_cache() -> None: @@ -100,7 +125,7 @@ def clear_secret_cache() -> None: def _include_yaml(loader: SafeLineLoader, - node: yaml.nodes.Node) -> Union[List, Dict]: + node: yaml.nodes.Node) -> JSON_TYPE: """Load another YAML file and embeds it using the !include tag. Example: @@ -115,7 +140,7 @@ def _is_file_valid(name: str) -> bool: return not name.startswith('.') -def _find_files(directory: str, pattern: str): +def _find_files(directory: str, pattern: str) -> Iterator[str]: """Recursively load files in a directory.""" for root, dirs, files in os.walk(directory, topdown=True): dirs[:] = [d for d in dirs if _is_file_valid(d)] @@ -151,7 +176,7 @@ def _include_dir_merge_named_yaml(loader: SafeLineLoader, def _include_dir_list_yaml(loader: SafeLineLoader, - node: yaml.nodes.Node): + node: yaml.nodes.Node) -> List[JSON_TYPE]: """Load multiple files from directory as a list.""" loc = os.path.join(os.path.dirname(loader.name), node.value) return [load_yaml(f) for f in _find_files(loc, '*.yaml') @@ -159,11 +184,11 @@ def _include_dir_list_yaml(loader: SafeLineLoader, def _include_dir_merge_list_yaml(loader: SafeLineLoader, - node: yaml.nodes.Node): + node: yaml.nodes.Node) -> JSON_TYPE: """Load multiple files from directory as a merged list.""" loc = os.path.join(os.path.dirname(loader.name), node.value) # type: str - merged_list = [] # type: List + merged_list = [] # type: List[JSON_TYPE] for fname in _find_files(loc, '*.yaml'): if os.path.basename(fname) == SECRET_YAML: continue @@ -202,14 +227,14 @@ def _ordered_dict(loader: SafeLineLoader, return _add_reference(OrderedDict(nodes), loader, node) -def _construct_seq(loader: SafeLineLoader, node: yaml.nodes.Node): +def _construct_seq(loader: SafeLineLoader, node: yaml.nodes.Node) -> JSON_TYPE: """Add line number and file name to Load YAML sequence.""" obj, = loader.construct_yaml_seq(node) return _add_reference(obj, loader, node) def _env_var_yaml(loader: SafeLineLoader, - node: yaml.nodes.Node): + node: yaml.nodes.Node) -> str: """Load environment variables and embed it into the configuration YAML.""" args = node.value.split() @@ -222,7 +247,7 @@ def _env_var_yaml(loader: SafeLineLoader, raise HomeAssistantError(node.value) -def _load_secret_yaml(secret_path: str) -> Dict: +def _load_secret_yaml(secret_path: str) -> JSON_TYPE: """Load the secrets yaml from path.""" secret_path = os.path.join(secret_path, SECRET_YAML) if secret_path in __SECRET_CACHE: @@ -248,7 +273,7 @@ def _load_secret_yaml(secret_path: str) -> Dict: def _secret_yaml(loader: SafeLineLoader, - node: yaml.nodes.Node): + node: yaml.nodes.Node) -> JSON_TYPE: """Load secrets and embed it into the configuration YAML.""" secret_path = os.path.dirname(loader.name) while True: @@ -308,7 +333,8 @@ yaml.SafeLoader.add_constructor('!include_dir_merge_named', # From: https://gist.github.com/miracle2k/3184458 # pylint: disable=redefined-outer-name -def represent_odict(dump, tag, mapping, flow_style=None): +def represent_odict(dump, tag, mapping, # type: ignore + flow_style=None) -> yaml.MappingNode: """Like BaseRepresenter.represent_mapping but does not issue the sort().""" value = [] # type: list node = yaml.MappingNode(tag, value, flow_style=flow_style) diff --git a/mypy.ini b/mypy.ini index 5a597994d6bdd06035936c02c487a1c2d1f58815..c92786e643fd5be6c75067e8631d7b5c6ef02d65 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2,11 +2,18 @@ check_untyped_defs = true follow_imports = silent ignore_missing_imports = true +warn_incomplete_stub = true warn_redundant_casts = true warn_return_any = true warn_unused_configs = true warn_unused_ignores = true +[mypy-homeassistant.*] +disallow_untyped_defs = true + +[mypy-homeassistant.config_entries] +disallow_untyped_defs = false + [mypy-homeassistant.util.yaml] warn_return_any = false diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index 33f1a7aa704c33665f91cb7b621b744c12b507e9..b1990fb80aac5023ab105087581d9b3992d03079 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -437,10 +437,12 @@ class TestAutomation(unittest.TestCase): } } }}): - automation.reload(self.hass) - self.hass.block_till_done() - # De-flake ?! - self.hass.block_till_done() + with patch('homeassistant.config.find_config_file', + return_value=''): + automation.reload(self.hass) + self.hass.block_till_done() + # De-flake ?! + self.hass.block_till_done() assert self.hass.states.get('automation.hello') is None assert self.hass.states.get('automation.bye') is not None @@ -485,8 +487,10 @@ class TestAutomation(unittest.TestCase): with patch('homeassistant.config.load_yaml_config_file', autospec=True, return_value={automation.DOMAIN: 'not valid'}): - automation.reload(self.hass) - self.hass.block_till_done() + with patch('homeassistant.config.find_config_file', + return_value=''): + automation.reload(self.hass) + self.hass.block_till_done() assert self.hass.states.get('automation.hello') is None @@ -521,8 +525,10 @@ class TestAutomation(unittest.TestCase): with patch('homeassistant.config.load_yaml_config_file', side_effect=HomeAssistantError('bla')): - automation.reload(self.hass) - self.hass.block_till_done() + with patch('homeassistant.config.find_config_file', + return_value=''): + automation.reload(self.hass) + self.hass.block_till_done() assert self.hass.states.get('automation.hello') is not None diff --git a/tests/components/group/test_init.py b/tests/components/group/test_init.py index 31ad70e8abac22aca7f87cfa51ed7709b9ade87b..a5e9bbc0b820c2c33524056e2de3779ba7ef2984 100644 --- a/tests/components/group/test_init.py +++ b/tests/components/group/test_init.py @@ -365,8 +365,10 @@ class TestComponentsGroup(unittest.TestCase): 'icon': 'mdi:work', 'view': True, }}}): - group.reload(self.hass) - self.hass.block_till_done() + with patch('homeassistant.config.find_config_file', + return_value=''): + group.reload(self.hass) + self.hass.block_till_done() assert sorted(self.hass.states.entity_ids()) == \ ['group.all_tests', 'group.hello'] diff --git a/tests/components/test_script.py b/tests/components/test_script.py index fcb0047c135bb7e10b7ffcd0f17ccd4f00916b77..c4282cdfbaf201de2f298f6ec04f8ce95b208d90 100644 --- a/tests/components/test_script.py +++ b/tests/components/test_script.py @@ -199,8 +199,10 @@ class TestScriptComponent(unittest.TestCase): } }] }}}): - script.reload(self.hass) - self.hass.block_till_done() + with patch('homeassistant.config.find_config_file', + return_value=''): + script.reload(self.hass) + self.hass.block_till_done() assert self.hass.states.get(ENTITY_ID) is None assert not self.hass.services.has_service(script.DOMAIN, 'test')