From 940d5fb2eeaf8103ca990c84a05869d849375074 Mon Sep 17 00:00:00 2001
From: Patrik <21142447+ggravlingen@users.noreply.github.com>
Date: Sun, 30 Sep 2018 22:22:07 +0200
Subject: [PATCH] Add basic support for Tradfri switches (#17007)

* Initial commit

* Sockets have been moved to separate component

* Sockets have been moved to separate component

* Fix const PLATFORM_SCHEMA

* Fix unique id

* Fix async_create_task

* Fix PLATFORM_SCHEMA

* Fix typo

* Remove pylint disable
---
 homeassistant/components/sensor/tradfri.py   |   3 +-
 homeassistant/components/switch/tradfri.py   | 137 +++++++++++++++++++
 homeassistant/components/tradfri/__init__.py |   5 +-
 requirements_all.txt                         |   2 +-
 requirements_test_all.txt                    |   2 +-
 5 files changed, 145 insertions(+), 4 deletions(-)
 create mode 100644 homeassistant/components/switch/tradfri.py

diff --git a/homeassistant/components/sensor/tradfri.py b/homeassistant/components/sensor/tradfri.py
index 86d0c1abc19..769857cb6df 100644
--- a/homeassistant/components/sensor/tradfri.py
+++ b/homeassistant/components/sensor/tradfri.py
@@ -26,7 +26,8 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
 
     devices_commands = await api(gateway.get_devices())
     all_devices = await api(devices_commands)
-    devices = (dev for dev in all_devices if not dev.has_light_control)
+    devices = (dev for dev in all_devices if not dev.has_light_control and
+               not dev.has_socket_control)
     async_add_entities(TradfriDevice(device, api) for device in devices)
 
 
diff --git a/homeassistant/components/switch/tradfri.py b/homeassistant/components/switch/tradfri.py
new file mode 100644
index 00000000000..74997332b07
--- /dev/null
+++ b/homeassistant/components/switch/tradfri.py
@@ -0,0 +1,137 @@
+"""
+Support for the IKEA Tradfri platform.
+
+For more details about this platform, please refer to the documentation at
+https://home-assistant.io/components/switch.tradfri/
+"""
+import logging
+
+from homeassistant.core import callback
+from homeassistant.components.switch import SwitchDevice
+from homeassistant.components.tradfri import (
+    KEY_GATEWAY, KEY_API, DOMAIN as TRADFRI_DOMAIN)
+from homeassistant.components.tradfri.const import (
+    CONF_GATEWAY_ID)
+
+_LOGGER = logging.getLogger(__name__)
+
+DEPENDENCIES = ['tradfri']
+IKEA = 'IKEA of Sweden'
+TRADFRI_SWITCH_MANAGER = 'Tradfri Switch Manager'
+
+
+async def async_setup_entry(hass, config_entry, async_add_entities):
+    """Load Tradfri switches based on a config entry."""
+    gateway_id = config_entry.data[CONF_GATEWAY_ID]
+    api = hass.data[KEY_API][config_entry.entry_id]
+    gateway = hass.data[KEY_GATEWAY][config_entry.entry_id]
+
+    devices_commands = await api(gateway.get_devices())
+    devices = await api(devices_commands)
+    switches = [dev for dev in devices if dev.has_socket_control]
+    if switches:
+        async_add_entities(
+            TradfriSwitch(switch, api, gateway_id) for switch in switches)
+
+
+class TradfriSwitch(SwitchDevice):
+    """The platform class required by Home Assistant."""
+
+    def __init__(self, switch, api, gateway_id):
+        """Initialize a switch."""
+        self._api = api
+        self._unique_id = "{}-{}".format(gateway_id, switch.id)
+        self._switch = None
+        self._socket_control = None
+        self._switch_data = None
+        self._name = None
+        self._available = True
+        self._gateway_id = gateway_id
+
+        self._refresh(switch)
+
+    @property
+    def unique_id(self):
+        """Return unique ID for switch."""
+        return self._unique_id
+
+    @property
+    def device_info(self):
+        """Return the device info."""
+        info = self._switch.device_info
+
+        return {
+            'identifiers': {
+                (TRADFRI_DOMAIN, self._switch.id)
+            },
+            'name': self._name,
+            'manufacturer': info.manufacturer,
+            'model': info.model_number,
+            'sw_version': info.firmware_version,
+            'via_hub': (TRADFRI_DOMAIN, self._gateway_id),
+        }
+
+    async def async_added_to_hass(self):
+        """Start thread when added to hass."""
+        self._async_start_observe()
+
+    @property
+    def available(self):
+        """Return True if entity is available."""
+        return self._available
+
+    @property
+    def should_poll(self):
+        """No polling needed for tradfri switch."""
+        return False
+
+    @property
+    def name(self):
+        """Return the display name of this switch."""
+        return self._name
+
+    @property
+    def is_on(self):
+        """Return true if switch is on."""
+        return self._switch_data.state
+
+    async def async_turn_off(self, **kwargs):
+        """Instruct the switch to turn off."""
+        await self._api(self._socket_control.set_state(False))
+
+    async def async_turn_on(self, **kwargs):
+        """Instruct the switch to turn on."""
+        await self._api(self._socket_control.set_state(True))
+
+    @callback
+    def _async_start_observe(self, exc=None):
+        """Start observation of switch."""
+        from pytradfri.error import PytradfriError
+        if exc:
+            _LOGGER.warning("Observation failed for %s", self._name,
+                            exc_info=exc)
+
+        try:
+            cmd = self._switch.observe(callback=self._observe_update,
+                                       err_callback=self._async_start_observe,
+                                       duration=0)
+            self.hass.async_create_task(self._api(cmd))
+        except PytradfriError as err:
+            _LOGGER.warning("Observation failed, trying again", exc_info=err)
+            self._async_start_observe()
+
+    def _refresh(self, switch):
+        """Refresh the switch data."""
+        self._switch = switch
+
+        # Caching of switchControl and switch object
+        self._available = switch.reachable
+        self._socket_control = switch.socket_control
+        self._switch_data = switch.socket_control.sockets[0]
+        self._name = switch.name
+
+    @callback
+    def _observe_update(self, tradfri_device):
+        """Receive new state data for this switch."""
+        self._refresh(tradfri_device)
+        self.async_schedule_update_ha_state()
diff --git a/homeassistant/components/tradfri/__init__.py b/homeassistant/components/tradfri/__init__.py
index 6e91ab338a3..51195d0a168 100644
--- a/homeassistant/components/tradfri/__init__.py
+++ b/homeassistant/components/tradfri/__init__.py
@@ -17,7 +17,7 @@ from .const import (
 
 from . import config_flow  # noqa  pylint_disable=unused-import
 
-REQUIREMENTS = ['pytradfri[async]==5.5.1']
+REQUIREMENTS = ['pytradfri[async]==5.6.0']
 
 DOMAIN = 'tradfri'
 CONFIG_FILE = '.tradfri_psk.conf'
@@ -119,5 +119,8 @@ async def async_setup_entry(hass, entry):
     hass.async_create_task(hass.config_entries.async_forward_entry_setup(
         entry, 'sensor'
     ))
+    hass.async_create_task(hass.config_entries.async_forward_entry_setup(
+        entry, 'switch'
+    ))
 
     return True
diff --git a/requirements_all.txt b/requirements_all.txt
index 638599b9642..c56425a9671 100644
--- a/requirements_all.txt
+++ b/requirements_all.txt
@@ -1214,7 +1214,7 @@ pytouchline==0.7
 pytrackr==0.0.5
 
 # homeassistant.components.tradfri
-pytradfri[async]==5.5.1
+pytradfri[async]==5.6.0
 
 # homeassistant.components.sensor.trafikverket_weatherstation
 pytrafikverket==0.1.5.8
diff --git a/requirements_test_all.txt b/requirements_test_all.txt
index 540a200ad32..7c3ace8e8cf 100644
--- a/requirements_test_all.txt
+++ b/requirements_test_all.txt
@@ -189,7 +189,7 @@ python-nest==4.0.3
 pythonwhois==2.4.3
 
 # homeassistant.components.tradfri
-pytradfri[async]==5.5.1
+pytradfri[async]==5.6.0
 
 # homeassistant.components.device_tracker.unifi
 pyunifi==2.13
-- 
GitLab