From cf57db919e450d2898a079fb9beb5467663de256 Mon Sep 17 00:00:00 2001
From: Pascal Vizeli <pascal.vizeli@syshack.ch>
Date: Mon, 28 Nov 2016 01:26:46 +0100
Subject: [PATCH] Refactory aiohttp clientsession handling in HA (#4602)

* Refactory aiohttp clientsession handling in HA

* remove from core / update platforms / rename file
---
 homeassistant/components/camera/generic.py    |   4 +-
 homeassistant/components/camera/mjpeg.py      |   4 +-
 homeassistant/components/camera/synology.py   |  38 ++----
 .../components/media_player/__init__.py       |   4 +-
 homeassistant/components/sensor/yr.py         |   4 +-
 homeassistant/components/switch/hook.py       |  13 +-
 homeassistant/core.py                         |  12 --
 homeassistant/helpers/aiohttp_client.py       | 119 ++++++++++++++++++
 homeassistant/remote.py                       |   1 -
 tests/components/media_player/test_demo.py    |   3 +-
 tests/helpers/test_aiohttp_client.py          |  81 ++++++++++++
 11 files changed, 229 insertions(+), 54 deletions(-)
 create mode 100644 homeassistant/helpers/aiohttp_client.py
 create mode 100644 tests/helpers/test_aiohttp_client.py

diff --git a/homeassistant/components/camera/generic.py b/homeassistant/components/camera/generic.py
index ec85e6306d4..a73132282bf 100644
--- a/homeassistant/components/camera/generic.py
+++ b/homeassistant/components/camera/generic.py
@@ -18,6 +18,7 @@ from homeassistant.const import (
     HTTP_BASIC_AUTHENTICATION, HTTP_DIGEST_AUTHENTICATION)
 from homeassistant.exceptions import TemplateError
 from homeassistant.components.camera import (PLATFORM_SCHEMA, Camera)
+from homeassistant.helpers.aiohttp_client import async_get_clientsession
 from homeassistant.helpers import config_validation as cv
 from homeassistant.util.async import run_coroutine_threadsafe
 
@@ -108,8 +109,9 @@ class GenericCamera(Camera):
         # async
         else:
             try:
+                websession = async_get_clientsession(self.hass)
                 with async_timeout.timeout(10, loop=self.hass.loop):
-                    response = yield from self.hass.websession.get(
+                    response = yield from websession.get(
                         url, auth=self._auth)
                     self._last_image = yield from response.read()
                     yield from response.release()
diff --git a/homeassistant/components/camera/mjpeg.py b/homeassistant/components/camera/mjpeg.py
index a2c35410c55..d96ea4ab0a3 100644
--- a/homeassistant/components/camera/mjpeg.py
+++ b/homeassistant/components/camera/mjpeg.py
@@ -20,6 +20,7 @@ from homeassistant.const import (
     CONF_NAME, CONF_USERNAME, CONF_PASSWORD, CONF_AUTHENTICATION,
     HTTP_BASIC_AUTHENTICATION, HTTP_DIGEST_AUTHENTICATION)
 from homeassistant.components.camera import (PLATFORM_SCHEMA, Camera)
+from homeassistant.helpers.aiohttp_client import async_get_clientsession
 from homeassistant.helpers import config_validation as cv
 
 _LOGGER = logging.getLogger(__name__)
@@ -101,9 +102,10 @@ class MjpegCamera(Camera):
             return
 
         # connect to stream
+        websession = async_get_clientsession(self.hass)
         try:
             with async_timeout.timeout(10, loop=self.hass.loop):
-                stream = yield from self.hass.websession.get(
+                stream = yield from websession.get(
                     self._mjpeg_url,
                     auth=self._auth
                 )
diff --git a/homeassistant/components/camera/synology.py b/homeassistant/components/camera/synology.py
index 9292e839b53..1db83ddf762 100644
--- a/homeassistant/components/camera/synology.py
+++ b/homeassistant/components/camera/synology.py
@@ -14,12 +14,13 @@ from aiohttp import web
 from aiohttp.web_exceptions import HTTPGatewayTimeout
 import async_timeout
 
-from homeassistant.core import callback
 from homeassistant.const import (
     CONF_NAME, CONF_USERNAME, CONF_PASSWORD,
-    CONF_URL, CONF_WHITELIST, CONF_VERIFY_SSL, EVENT_HOMEASSISTANT_STOP)
+    CONF_URL, CONF_WHITELIST, CONF_VERIFY_SSL)
 from homeassistant.components.camera import (
     Camera, PLATFORM_SCHEMA)
+from homeassistant.helpers.aiohttp_client import (
+    async_get_clientsession, async_create_clientsession)
 import homeassistant.helpers.config_validation as cv
 from homeassistant.util.async import run_coroutine_threadsafe
 
@@ -59,23 +60,8 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
 @asyncio.coroutine
 def async_setup_platform(hass, config, async_add_devices, discovery_info=None):
     """Setup a Synology IP Camera."""
-    if not config.get(CONF_VERIFY_SSL):
-        connector = aiohttp.TCPConnector(verify_ssl=False)
-
-        @asyncio.coroutine
-        def _async_close_connector(event):
-            """Close websession on shutdown."""
-            yield from connector.close()
-
-        hass.bus.async_listen_once(
-            EVENT_HOMEASSISTANT_STOP, _async_close_connector)
-    else:
-        connector = hass.websession.connector
-
-    websession_init = aiohttp.ClientSession(
-        loop=hass.loop,
-        connector=connector
-    )
+    verify_ssl = config.get(CONF_VERIFY_SSL)
+    websession_init = async_get_clientsession(hass, verify_ssl)
 
     # Determine API to use for authentication
     syno_api_url = SYNO_API_URL.format(
@@ -118,19 +104,9 @@ def async_setup_platform(hass, config, async_add_devices, discovery_info=None):
         syno_auth_url
     )
 
-    websession_init.detach()
-
     # init websession
-    websession = aiohttp.ClientSession(
-        loop=hass.loop, connector=connector, cookies={'id': session_id})
-
-    @callback
-    def _async_close_websession(event):
-        """Close websession on shutdown."""
-        websession.detach()
-
-    hass.bus.async_listen_once(
-        EVENT_HOMEASSISTANT_STOP, _async_close_websession)
+    websession = async_create_clientsession(
+        hass, verify_ssl, cookies={'id': session_id})
 
     # Use SessionID to get cameras in system
     syno_camera_url = SYNO_API_URL.format(
diff --git a/homeassistant/components/media_player/__init__.py b/homeassistant/components/media_player/__init__.py
index 5665699d4f3..c9df431965b 100644
--- a/homeassistant/components/media_player/__init__.py
+++ b/homeassistant/components/media_player/__init__.py
@@ -18,6 +18,7 @@ from homeassistant.helpers.entity import Entity
 from homeassistant.helpers.entity_component import EntityComponent
 from homeassistant.helpers.config_validation import PLATFORM_SCHEMA  # noqa
 from homeassistant.components.http import HomeAssistantView, KEY_AUTHENTICATED
+from homeassistant.helpers.aiohttp_client import async_get_clientsession
 import homeassistant.helpers.config_validation as cv
 from homeassistant.util.async import run_coroutine_threadsafe
 from homeassistant.const import (
@@ -705,8 +706,9 @@ def _async_fetch_image(hass, url):
 
     content, content_type = (None, None)
     try:
+        websession = async_get_clientsession(hass)
         with async_timeout.timeout(10, loop=hass.loop):
-            response = yield from hass.websession.get(url)
+            response = yield from websession.get(url)
             if response.status == 200:
                 content = yield from response.read()
                 content_type = response.headers.get(CONTENT_TYPE_HEADER)
diff --git a/homeassistant/components/sensor/yr.py b/homeassistant/components/sensor/yr.py
index 6429c9dcaad..e3cc5186230 100644
--- a/homeassistant/components/sensor/yr.py
+++ b/homeassistant/components/sensor/yr.py
@@ -18,6 +18,7 @@ from homeassistant.components.sensor import PLATFORM_SCHEMA
 from homeassistant.const import (
     CONF_LATITUDE, CONF_LONGITUDE, CONF_ELEVATION, CONF_MONITORED_CONDITIONS,
     ATTR_ATTRIBUTION)
+from homeassistant.helpers.aiohttp_client import async_get_clientsession
 from homeassistant.helpers.entity import Entity
 from homeassistant.helpers.event import (
     async_track_point_in_utc_time, async_track_utc_time_change)
@@ -155,8 +156,9 @@ class YrData(object):
 
         if self._nextrun is None or dt_util.utcnow() >= self._nextrun:
             try:
+                websession = async_get_clientsession(self.hass)
                 with async_timeout.timeout(10, loop=self.hass.loop):
-                    resp = yield from self.hass.websession.get(self._url)
+                    resp = yield from websession.get(self._url)
                 if resp.status != 200:
                     try_again('{} returned {}'.format(self._url, resp.status))
                     return
diff --git a/homeassistant/components/switch/hook.py b/homeassistant/components/switch/hook.py
index 8f24842212d..29fe8372fab 100644
--- a/homeassistant/components/switch/hook.py
+++ b/homeassistant/components/switch/hook.py
@@ -13,6 +13,7 @@ import aiohttp
 
 from homeassistant.components.switch import (SwitchDevice, PLATFORM_SCHEMA)
 from homeassistant.const import CONF_PASSWORD, CONF_USERNAME
+from homeassistant.helpers.aiohttp_client import async_get_clientsession
 import homeassistant.helpers.config_validation as cv
 
 _LOGGER = logging.getLogger(__name__)
@@ -31,10 +32,11 @@ def async_setup_platform(hass, config, async_add_devices, discovery_info=None):
     """Set up Hook by getting the access token and list of actions."""
     username = config.get(CONF_USERNAME)
     password = config.get(CONF_PASSWORD)
+    websession = async_get_clientsession(hass)
 
     try:
         with async_timeout.timeout(TIMEOUT, loop=hass.loop):
-            response = yield from hass.websession.post(
+            response = yield from websession.post(
                 '{}{}'.format(HOOK_ENDPOINT, 'user/login'),
                 data={
                     'username': username,
@@ -54,7 +56,7 @@ def async_setup_platform(hass, config, async_add_devices, discovery_info=None):
 
     try:
         with async_timeout.timeout(TIMEOUT, loop=hass.loop):
-            response = yield from hass.websession.get(
+            response = yield from websession.get(
                 '{}{}'.format(HOOK_ENDPOINT, 'device'),
                 params={"token": data['data']['token']})
             data = yield from response.json()
@@ -79,7 +81,7 @@ class HookSmartHome(SwitchDevice):
 
     def __init__(self, hass, token, device_id, device_name):
         """Initialize the switch."""
-        self._hass = hass
+        self.hass = hass
         self._token = token
         self._state = False
         self._id = device_id
@@ -102,8 +104,9 @@ class HookSmartHome(SwitchDevice):
         """Send the url to the Hook API."""
         try:
             _LOGGER.debug("Sending: %s", url)
-            with async_timeout.timeout(TIMEOUT, loop=self._hass.loop):
-                response = yield from self._hass.websession.get(
+            websession = async_get_clientsession(self.hass)
+            with async_timeout.timeout(TIMEOUT, loop=self.hass.loop):
+                response = yield from websession.get(
                     url, params={"token": self._token})
                 data = yield from response.json()
         except (asyncio.TimeoutError,
diff --git a/homeassistant/core.py b/homeassistant/core.py
index 42ab117eadc..f358903735b 100644
--- a/homeassistant/core.py
+++ b/homeassistant/core.py
@@ -18,7 +18,6 @@ import threading
 from types import MappingProxyType
 from typing import Optional, Any, Callable, List  # NOQA
 
-import aiohttp
 import voluptuous as vol
 from voluptuous.humanize import humanize_error
 
@@ -121,21 +120,12 @@ class HomeAssistant(object):
         self.data = {}
         self.state = CoreState.not_running
         self.exit_code = None
-        self._websession = None
 
     @property
     def is_running(self) -> bool:
         """Return if Home Assistant is running."""
         return self.state in (CoreState.starting, CoreState.running)
 
-    @property
-    def websession(self):
-        """Return an aiohttp session to make web requests."""
-        if self._websession is None:
-            self._websession = aiohttp.ClientSession(loop=self.loop)
-
-        return self._websession
-
     def start(self) -> None:
         """Start home assistant."""
         # Register the async start
@@ -295,8 +285,6 @@ class HomeAssistant(object):
         self.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
         yield from self.async_block_till_done()
         self.executor.shutdown()
-        if self._websession is not None:
-            yield from self._websession.close()
         self.state = CoreState.not_running
         self.loop.stop()
 
diff --git a/homeassistant/helpers/aiohttp_client.py b/homeassistant/helpers/aiohttp_client.py
new file mode 100644
index 00000000000..a1ec8ac85da
--- /dev/null
+++ b/homeassistant/helpers/aiohttp_client.py
@@ -0,0 +1,119 @@
+"""Helper for aiohttp webclient stuff."""
+import asyncio
+
+import aiohttp
+
+from homeassistant.core import callback
+from homeassistant.const import EVENT_HOMEASSISTANT_STOP
+
+
+DATA_CONNECTOR = 'aiohttp_connector'
+DATA_CONNECTOR_NOTVERIFY = 'aiohttp_connector_notverify'
+DATA_CLIENTSESSION = 'aiohttp_clientsession'
+DATA_CLIENTSESSION_NOTVERIFY = 'aiohttp_clientsession_notverify'
+
+
+@callback
+def async_get_clientsession(hass, verify_ssl=True):
+    """Return default aiohttp ClientSession.
+
+    This method must be run in the event loop.
+    """
+    if verify_ssl:
+        key = DATA_CLIENTSESSION
+    else:
+        key = DATA_CLIENTSESSION_NOTVERIFY
+
+    if key not in hass.data:
+        connector = _async_get_connector(hass, verify_ssl)
+        clientsession = aiohttp.ClientSession(
+            loop=hass.loop,
+            connector=connector
+        )
+        _async_register_clientsession_shutdown(hass, clientsession)
+        hass.data[key] = clientsession
+
+    return hass.data[key]
+
+
+@callback
+def async_create_clientsession(hass, verify_ssl=True, auto_cleanup=True,
+                               **kwargs):
+    """Create a new ClientSession with kwargs, i.e. for cookies.
+
+    If auto_cleanup is False, you need to call detach() after the session
+    returned is no longer used. Default is True, the session will be
+    automatically detached on homeassistant_stop.
+
+    This method must be run in the event loop.
+    """
+    connector = _async_get_connector(hass, verify_ssl)
+
+    clientsession = aiohttp.ClientSession(
+        loop=hass.loop,
+        connector=connector,
+        **kwargs
+    )
+
+    if auto_cleanup:
+        _async_register_clientsession_shutdown(hass, clientsession)
+
+    return clientsession
+
+
+@callback
+# pylint: disable=invalid-name
+def _async_register_clientsession_shutdown(hass, clientsession):
+    """Register ClientSession close on homeassistant shutdown.
+
+    This method must be run in the event loop.
+    """
+    @callback
+    def _async_close_websession(event):
+        """Close websession."""
+        clientsession.detach()
+
+    hass.bus.async_listen_once(
+        EVENT_HOMEASSISTANT_STOP, _async_close_websession)
+
+
+@callback
+def _async_get_connector(hass, verify_ssl=True):
+    """Return the connector pool for aiohttp.
+
+    This method must be run in the event loop.
+    """
+    if verify_ssl:
+        if DATA_CONNECTOR not in hass.data:
+            connector = aiohttp.TCPConnector(loop=hass.loop)
+            hass.data[DATA_CONNECTOR] = connector
+
+            _async_register_connector_shutdown(hass, connector)
+        else:
+            connector = hass.data[DATA_CONNECTOR]
+    else:
+        if DATA_CONNECTOR_NOTVERIFY not in hass.data:
+            connector = aiohttp.TCPConnector(loop=hass.loop, verify_ssl=False)
+            hass.data[DATA_CONNECTOR_NOTVERIFY] = connector
+
+            _async_register_connector_shutdown(hass, connector)
+        else:
+            connector = hass.data[DATA_CONNECTOR_NOTVERIFY]
+
+    return connector
+
+
+@callback
+# pylint: disable=invalid-name
+def _async_register_connector_shutdown(hass, connector):
+    """Register connector pool close on homeassistant shutdown.
+
+    This method must be run in the event loop.
+    """
+    @asyncio.coroutine
+    def _async_close_connector(event):
+        """Close websession on shutdown."""
+        yield from connector.close()
+
+    hass.bus.async_listen_once(
+        EVENT_HOMEASSISTANT_STOP, _async_close_connector)
diff --git a/homeassistant/remote.py b/homeassistant/remote.py
index fa6cb446c67..c9270e2032f 100644
--- a/homeassistant/remote.py
+++ b/homeassistant/remote.py
@@ -138,7 +138,6 @@ class HomeAssistant(ha.HomeAssistant):
         self.data = {}
         self.state = ha.CoreState.not_running
         self.exit_code = None
-        self._websession = None
         self.config.api = local_api
 
     def start(self):
diff --git a/tests/components/media_player/test_demo.py b/tests/components/media_player/test_demo.py
index 3539c73b7dd..c9fb3ad6ff8 100644
--- a/tests/components/media_player/test_demo.py
+++ b/tests/components/media_player/test_demo.py
@@ -7,6 +7,7 @@ from homeassistant.bootstrap import setup_component
 from homeassistant.const import HTTP_HEADER_HA_AUTH
 import homeassistant.components.media_player as mp
 import homeassistant.components.http as http
+from homeassistant.helpers.aiohttp_client import DATA_CLIENTSESSION
 
 import requests
 
@@ -289,7 +290,7 @@ class TestMediaPlayerWeb(unittest.TestCase):
             def close(self):
                 pass
 
-        self.hass._websession = MockWebsession()
+        self.hass.data[DATA_CLIENTSESSION] = MockWebsession()
 
         assert self.hass.states.is_state(entity_id, 'playing')
         state = self.hass.states.get(entity_id)
diff --git a/tests/helpers/test_aiohttp_client.py b/tests/helpers/test_aiohttp_client.py
new file mode 100644
index 00000000000..83e1275819b
--- /dev/null
+++ b/tests/helpers/test_aiohttp_client.py
@@ -0,0 +1,81 @@
+"""Test the aiohttp client helper."""
+import unittest
+
+import aiohttp
+
+import homeassistant.helpers.aiohttp_client as client
+from homeassistant.util.async import run_callback_threadsafe
+
+from tests.common import get_test_home_assistant
+
+
+class TestHelpersAiohttpClient(unittest.TestCase):
+    """Test homeassistant.helpers.aiohttp_client module."""
+
+    def setup_method(self, method):
+        """Setup things to be run when tests are started."""
+        self.hass = get_test_home_assistant()
+
+    def teardown_method(self, method):
+        """Stop everything that was started."""
+        self.hass.stop()
+
+    def test_get_clientsession_with_ssl(self):
+        """Test init clientsession with ssl."""
+        run_callback_threadsafe(self.hass.loop, client.async_get_clientsession,
+                                self.hass).result()
+
+        assert isinstance(
+            self.hass.data[client.DATA_CLIENTSESSION], aiohttp.ClientSession)
+        assert isinstance(
+            self.hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector)
+
+    def test_get_clientsession_without_ssl(self):
+        """Test init clientsession without ssl."""
+        run_callback_threadsafe(self.hass.loop, client.async_get_clientsession,
+                                self.hass, False).result()
+
+        assert isinstance(
+            self.hass.data[client.DATA_CLIENTSESSION_NOTVERIFY],
+            aiohttp.ClientSession)
+        assert isinstance(
+            self.hass.data[client.DATA_CONNECTOR_NOTVERIFY],
+            aiohttp.TCPConnector)
+
+    def test_create_clientsession_with_ssl_and_cookies(self):
+        """Test create clientsession with ssl."""
+        def _async_helper():
+            return client.async_create_clientsession(
+                self.hass,
+                cookies={'bla': True}
+            )
+
+        session = run_callback_threadsafe(
+            self.hass.loop,
+            _async_helper,
+        ).result()
+
+        assert isinstance(
+            session, aiohttp.ClientSession)
+        assert isinstance(
+            self.hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector)
+
+    def test_create_clientsession_without_ssl_and_cookies(self):
+        """Test create clientsession without ssl."""
+        def _async_helper():
+            return client.async_create_clientsession(
+                self.hass,
+                False,
+                cookies={'bla': True}
+            )
+
+        session = run_callback_threadsafe(
+            self.hass.loop,
+            _async_helper,
+        ).result()
+
+        assert isinstance(
+            session, aiohttp.ClientSession)
+        assert isinstance(
+            self.hass.data[client.DATA_CONNECTOR_NOTVERIFY],
+            aiohttp.TCPConnector)
-- 
GitLab