From e776f88eecc6cd543a7749cc9b0a38660eb33e77 Mon Sep 17 00:00:00 2001
From: Paulus Schoutsen <paulus@paulusschoutsen.nl>
Date: Tue, 14 Aug 2018 21:14:12 +0200
Subject: [PATCH] Use JWT for access tokens (#15972)

* Use JWT for access tokens

* Update requirements

* Improvements
---
 homeassistant/auth/__init__.py                | 64 +++++++++++++------
 homeassistant/auth/auth_store.py              | 56 ++++++++--------
 homeassistant/auth/models.py                  | 22 +------
 homeassistant/components/auth/__init__.py     |  6 +-
 homeassistant/components/http/auth.py         |  6 +-
 homeassistant/components/websocket_api.py     |  9 +--
 homeassistant/package_constraints.txt         |  1 +
 requirements_all.txt                          |  1 +
 setup.py                                      |  1 +
 tests/auth/test_init.py                       | 45 ++++---------
 tests/common.py                               | 14 ++--
 tests/components/auth/test_init.py            | 24 +++++--
 tests/components/auth/test_init_link_user.py  |  2 +-
 tests/components/config/test_auth.py          | 16 +++--
 .../test_auth_provider_homeassistant.py       | 38 ++++++++---
 tests/components/conftest.py                  |  2 +-
 tests/components/hassio/test_init.py          |  6 +-
 tests/components/http/test_auth.py            |  8 ++-
 tests/components/test_api.py                  | 22 +++++--
 tests/components/test_websocket_api.py        | 15 +++--
 20 files changed, 203 insertions(+), 155 deletions(-)

diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py
index 9695e77f6f1..148f97702e3 100644
--- a/homeassistant/auth/__init__.py
+++ b/homeassistant/auth/__init__.py
@@ -4,10 +4,12 @@ import logging
 from collections import OrderedDict
 from typing import List, Awaitable
 
+import jwt
+
 from homeassistant import data_entry_flow
 from homeassistant.core import callback, HomeAssistant
+from homeassistant.util import dt as dt_util
 
-from . import models
 from . import auth_store
 from .providers import auth_provider_from_config
 
@@ -54,7 +56,6 @@ class AuthManager:
         self.login_flow = data_entry_flow.FlowManager(
             hass, self._async_create_login_flow,
             self._async_finish_login_flow)
-        self._access_tokens = OrderedDict()
 
     @property
     def active(self):
@@ -181,35 +182,56 @@ class AuthManager:
 
         return await self._store.async_create_refresh_token(user, client_id)
 
-    async def async_get_refresh_token(self, token):
+    async def async_get_refresh_token(self, token_id):
+        """Get refresh token by id."""
+        return await self._store.async_get_refresh_token(token_id)
+
+    async def async_get_refresh_token_by_token(self, token):
         """Get refresh token by token."""
-        return await self._store.async_get_refresh_token(token)
+        return await self._store.async_get_refresh_token_by_token(token)
 
     @callback
     def async_create_access_token(self, refresh_token):
         """Create a new access token."""
-        access_token = models.AccessToken(refresh_token=refresh_token)
-        self._access_tokens[access_token.token] = access_token
-        return access_token
-
-    @callback
-    def async_get_access_token(self, token):
-        """Get an access token."""
-        tkn = self._access_tokens.get(token)
+        # pylint: disable=no-self-use
+        return jwt.encode({
+            'iss': refresh_token.id,
+            'iat': dt_util.utcnow(),
+            'exp': dt_util.utcnow() + refresh_token.access_token_expiration,
+        }, refresh_token.jwt_key, algorithm='HS256').decode()
+
+    async def async_validate_access_token(self, token):
+        """Return if an access token is valid."""
+        try:
+            unverif_claims = jwt.decode(token, verify=False)
+        except jwt.InvalidTokenError:
+            return None
 
-        if tkn is None:
-            _LOGGER.debug('Attempt to get non-existing access token')
+        refresh_token = await self.async_get_refresh_token(
+            unverif_claims.get('iss'))
+
+        if refresh_token is None:
+            jwt_key = ''
+            issuer = ''
+        else:
+            jwt_key = refresh_token.jwt_key
+            issuer = refresh_token.id
+
+        try:
+            jwt.decode(
+                token,
+                jwt_key,
+                leeway=10,
+                issuer=issuer,
+                algorithms=['HS256']
+            )
+        except jwt.InvalidTokenError:
             return None
 
-        if tkn.expired or not tkn.refresh_token.user.is_active:
-            if tkn.expired:
-                _LOGGER.debug('Attempt to get expired access token')
-            else:
-                _LOGGER.debug('Attempt to get access token for inactive user')
-            self._access_tokens.pop(token)
+        if not refresh_token.user.is_active:
             return None
 
-        return tkn
+        return refresh_token
 
     async def _async_create_login_flow(self, handler, *, context, data):
         """Create a login flow."""
diff --git a/homeassistant/auth/auth_store.py b/homeassistant/auth/auth_store.py
index 8fd66d4bbb7..806cd109d78 100644
--- a/homeassistant/auth/auth_store.py
+++ b/homeassistant/auth/auth_store.py
@@ -1,6 +1,7 @@
 """Storage for auth models."""
 from collections import OrderedDict
 from datetime import timedelta
+import hmac
 
 from homeassistant.util import dt as dt_util
 
@@ -110,22 +111,36 @@ class AuthStore:
     async def async_create_refresh_token(self, user, client_id=None):
         """Create a new token for a user."""
         refresh_token = models.RefreshToken(user=user, client_id=client_id)
-        user.refresh_tokens[refresh_token.token] = refresh_token
+        user.refresh_tokens[refresh_token.id] = refresh_token
         await self.async_save()
         return refresh_token
 
-    async def async_get_refresh_token(self, token):
-        """Get refresh token by token."""
+    async def async_get_refresh_token(self, token_id):
+        """Get refresh token by id."""
         if self._users is None:
             await self.async_load()
 
         for user in self._users.values():
-            refresh_token = user.refresh_tokens.get(token)
+            refresh_token = user.refresh_tokens.get(token_id)
             if refresh_token is not None:
                 return refresh_token
 
         return None
 
+    async def async_get_refresh_token_by_token(self, token):
+        """Get refresh token by token."""
+        if self._users is None:
+            await self.async_load()
+
+        found = None
+
+        for user in self._users.values():
+            for refresh_token in user.refresh_tokens.values():
+                if hmac.compare_digest(refresh_token.token, token):
+                    found = refresh_token
+
+        return found
+
     async def async_load(self):
         """Load the users."""
         data = await self._store.async_load()
@@ -153,9 +168,11 @@ class AuthStore:
                 data=cred_dict['data'],
             ))
 
-        refresh_tokens = OrderedDict()
-
         for rt_dict in data['refresh_tokens']:
+            # Filter out the old keys that don't have jwt_key (pre-0.76)
+            if 'jwt_key' not in rt_dict:
+                continue
+
             token = models.RefreshToken(
                 id=rt_dict['id'],
                 user=users[rt_dict['user_id']],
@@ -164,18 +181,9 @@ class AuthStore:
                 access_token_expiration=timedelta(
                     seconds=rt_dict['access_token_expiration']),
                 token=rt_dict['token'],
+                jwt_key=rt_dict['jwt_key']
             )
-            refresh_tokens[token.id] = token
-            users[rt_dict['user_id']].refresh_tokens[token.token] = token
-
-        for ac_dict in data['access_tokens']:
-            refresh_token = refresh_tokens[ac_dict['refresh_token_id']]
-            token = models.AccessToken(
-                refresh_token=refresh_token,
-                created_at=dt_util.parse_datetime(ac_dict['created_at']),
-                token=ac_dict['token'],
-            )
-            refresh_token.access_tokens.append(token)
+            users[rt_dict['user_id']].refresh_tokens[token.id] = token
 
         self._users = users
 
@@ -213,27 +221,15 @@ class AuthStore:
                 'access_token_expiration':
                     refresh_token.access_token_expiration.total_seconds(),
                 'token': refresh_token.token,
+                'jwt_key': refresh_token.jwt_key,
             }
             for user in self._users.values()
             for refresh_token in user.refresh_tokens.values()
         ]
 
-        access_tokens = [
-            {
-                'id': user.id,
-                'refresh_token_id': refresh_token.id,
-                'created_at': access_token.created_at.isoformat(),
-                'token': access_token.token,
-            }
-            for user in self._users.values()
-            for refresh_token in user.refresh_tokens.values()
-            for access_token in refresh_token.access_tokens
-        ]
-
         data = {
             'users': users,
             'credentials': credentials,
-            'access_tokens': access_tokens,
             'refresh_tokens': refresh_tokens,
         }
 
diff --git a/homeassistant/auth/models.py b/homeassistant/auth/models.py
index 38e054dc7cf..3f49c56bce6 100644
--- a/homeassistant/auth/models.py
+++ b/homeassistant/auth/models.py
@@ -39,26 +39,8 @@ class RefreshToken:
                                       default=ACCESS_TOKEN_EXPIRATION)
     token = attr.ib(type=str,
                     default=attr.Factory(lambda: generate_secret(64)))
-    access_tokens = attr.ib(type=list, default=attr.Factory(list), cmp=False)
-
-
-@attr.s(slots=True)
-class AccessToken:
-    """Access token to access the API.
-
-    These will only ever be stored in memory and not be persisted.
-    """
-
-    refresh_token = attr.ib(type=RefreshToken)
-    created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
-    token = attr.ib(type=str,
-                    default=attr.Factory(generate_secret))
-
-    @property
-    def expired(self):
-        """Return if this token has expired."""
-        expires = self.created_at + self.refresh_token.access_token_expiration
-        return dt_util.utcnow() > expires
+    jwt_key = attr.ib(type=str,
+                      default=attr.Factory(lambda: generate_secret(64)))
 
 
 @attr.s(slots=True)
diff --git a/homeassistant/components/auth/__init__.py b/homeassistant/components/auth/__init__.py
index 0b2b4fb1a2e..102bfe58b55 100644
--- a/homeassistant/components/auth/__init__.py
+++ b/homeassistant/components/auth/__init__.py
@@ -155,7 +155,7 @@ class GrantTokenView(HomeAssistantView):
         access_token = hass.auth.async_create_access_token(refresh_token)
 
         return self.json({
-            'access_token': access_token.token,
+            'access_token': access_token,
             'token_type': 'Bearer',
             'refresh_token': refresh_token.token,
             'expires_in':
@@ -178,7 +178,7 @@ class GrantTokenView(HomeAssistantView):
                 'error': 'invalid_request',
             }, status_code=400)
 
-        refresh_token = await hass.auth.async_get_refresh_token(token)
+        refresh_token = await hass.auth.async_get_refresh_token_by_token(token)
 
         if refresh_token is None:
             return self.json({
@@ -193,7 +193,7 @@ class GrantTokenView(HomeAssistantView):
         access_token = hass.auth.async_create_access_token(refresh_token)
 
         return self.json({
-            'access_token': access_token.token,
+            'access_token': access_token,
             'token_type': 'Bearer',
             'expires_in':
                 int(refresh_token.access_token_expiration.total_seconds()),
diff --git a/homeassistant/components/http/auth.py b/homeassistant/components/http/auth.py
index 77621e3bc7c..d01d1b50c5a 100644
--- a/homeassistant/components/http/auth.py
+++ b/homeassistant/components/http/auth.py
@@ -106,11 +106,11 @@ async def async_validate_auth_header(request, api_password=None):
 
     if auth_type == 'Bearer':
         hass = request.app['hass']
-        access_token = hass.auth.async_get_access_token(auth_val)
-        if access_token is None:
+        refresh_token = await hass.auth.async_validate_access_token(auth_val)
+        if refresh_token is None:
             return False
 
-        request['hass_user'] = access_token.refresh_token.user
+        request['hass_user'] = refresh_token.user
         return True
 
     if auth_type == 'Basic' and api_password is not None:
diff --git a/homeassistant/components/websocket_api.py b/homeassistant/components/websocket_api.py
index 2a1e808188a..36811337ec1 100644
--- a/homeassistant/components/websocket_api.py
+++ b/homeassistant/components/websocket_api.py
@@ -355,11 +355,12 @@ class ActiveConnection:
 
                 if self.hass.auth.active and 'access_token' in msg:
                     self.debug("Received access_token")
-                    token = self.hass.auth.async_get_access_token(
-                        msg['access_token'])
-                    authenticated = token is not None
+                    refresh_token = \
+                        await self.hass.auth.async_validate_access_token(
+                            msg['access_token'])
+                    authenticated = refresh_token is not None
                     if authenticated:
-                        request['hass_user'] = token.refresh_token.user
+                        request['hass_user'] = refresh_token.user
 
                 elif ((not self.hass.auth.active or
                        self.hass.auth.support_legacy) and
diff --git a/homeassistant/package_constraints.txt b/homeassistant/package_constraints.txt
index 29e10838f21..3aa1e3643c6 100644
--- a/homeassistant/package_constraints.txt
+++ b/homeassistant/package_constraints.txt
@@ -4,6 +4,7 @@ async_timeout==3.0.0
 attrs==18.1.0
 certifi>=2018.04.16
 jinja2>=2.10
+PyJWT==1.6.4
 pip>=8.0.3
 pytz>=2018.04
 pyyaml>=3.13,<4
diff --git a/requirements_all.txt b/requirements_all.txt
index 52c3168991d..cf64fde7c64 100644
--- a/requirements_all.txt
+++ b/requirements_all.txt
@@ -5,6 +5,7 @@ async_timeout==3.0.0
 attrs==18.1.0
 certifi>=2018.04.16
 jinja2>=2.10
+PyJWT==1.6.4
 pip>=8.0.3
 pytz>=2018.04
 pyyaml>=3.13,<4
diff --git a/setup.py b/setup.py
index b319df9067d..bd1e70aa8ae 100755
--- a/setup.py
+++ b/setup.py
@@ -38,6 +38,7 @@ REQUIRES = [
     'attrs==18.1.0',
     'certifi>=2018.04.16',
     'jinja2>=2.10',
+    'PyJWT==1.6.4',
     'pip>=8.0.3',
     'pytz>=2018.04',
     'pyyaml>=3.13,<4',
diff --git a/tests/auth/test_init.py b/tests/auth/test_init.py
index cad4bbdbd71..da5daca7cf6 100644
--- a/tests/auth/test_init.py
+++ b/tests/auth/test_init.py
@@ -199,9 +199,7 @@ async def test_saving_loading(hass, hass_storage):
     })
     user = await manager.async_get_or_create_user(step['result'])
     await manager.async_activate_user(user)
-    refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
-
-    manager.async_create_access_token(refresh_token)
+    await manager.async_create_refresh_token(user, CLIENT_ID)
 
     await flush_store(manager._store._store)
 
@@ -211,30 +209,6 @@ async def test_saving_loading(hass, hass_storage):
     assert users[0] == user
 
 
-def test_access_token_expired():
-    """Test that the expired property on access tokens work."""
-    refresh_token = auth_models.RefreshToken(
-        user=None,
-        client_id='bla'
-    )
-
-    access_token = auth_models.AccessToken(
-        refresh_token=refresh_token
-    )
-
-    assert access_token.expired is False
-
-    with patch('homeassistant.util.dt.utcnow',
-               return_value=dt_util.utcnow() +
-               auth_const.ACCESS_TOKEN_EXPIRATION):
-        assert access_token.expired is True
-
-    almost_exp = \
-        dt_util.utcnow() + auth_const.ACCESS_TOKEN_EXPIRATION - timedelta(1)
-    with patch('homeassistant.util.dt.utcnow', return_value=almost_exp):
-        assert access_token.expired is False
-
-
 async def test_cannot_retrieve_expired_access_token(hass):
     """Test that we cannot retrieve expired access tokens."""
     manager = await auth.auth_manager_from_config(hass, [])
@@ -244,15 +218,20 @@ async def test_cannot_retrieve_expired_access_token(hass):
     assert refresh_token.client_id == CLIENT_ID
 
     access_token = manager.async_create_access_token(refresh_token)
-    assert manager.async_get_access_token(access_token.token) is access_token
+    assert (
+        await manager.async_validate_access_token(access_token)
+        is refresh_token
+    )
 
     with patch('homeassistant.util.dt.utcnow',
-               return_value=dt_util.utcnow() +
-               auth_const.ACCESS_TOKEN_EXPIRATION):
-        assert manager.async_get_access_token(access_token.token) is None
+               return_value=dt_util.utcnow() -
+               auth_const.ACCESS_TOKEN_EXPIRATION - timedelta(seconds=11)):
+        access_token = manager.async_create_access_token(refresh_token)
 
-    # Even with unpatched time, it should have been removed from manager
-    assert manager.async_get_access_token(access_token.token) is None
+    assert (
+        await manager.async_validate_access_token(access_token)
+        is None
+    )
 
 
 async def test_generating_system_user(hass):
diff --git a/tests/common.py b/tests/common.py
index df333cca735..81e4774ccd4 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -314,12 +314,18 @@ def mock_registry(hass, mock_entries=None):
 class MockUser(auth_models.User):
     """Mock a user in Home Assistant."""
 
-    def __init__(self, id='mock-id', is_owner=False, is_active=True,
+    def __init__(self, id=None, is_owner=False, is_active=True,
                  name='Mock User', system_generated=False):
         """Initialize mock user."""
-        super().__init__(
-            id=id, is_owner=is_owner, is_active=is_active, name=name,
-            system_generated=system_generated)
+        kwargs = {
+            'is_owner': is_owner,
+            'is_active': is_active,
+            'name': name,
+            'system_generated': system_generated
+        }
+        if id is not None:
+            kwargs['id'] = id
+        super().__init__(**kwargs)
 
     def add_to_hass(self, hass):
         """Test helper to add entry to hass."""
diff --git a/tests/components/auth/test_init.py b/tests/components/auth/test_init.py
index eea768c96a0..f1a1bb5bd3c 100644
--- a/tests/components/auth/test_init.py
+++ b/tests/components/auth/test_init.py
@@ -44,7 +44,10 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
     assert resp.status == 200
     tokens = await resp.json()
 
-    assert hass.auth.async_get_access_token(tokens['access_token']) is not None
+    assert (
+        await hass.auth.async_validate_access_token(tokens['access_token'])
+        is not None
+    )
 
     # Use refresh token to get more tokens.
     resp = await client.post('/auth/token', data={
@@ -56,7 +59,10 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
     assert resp.status == 200
     tokens = await resp.json()
     assert 'refresh_token' not in tokens
-    assert hass.auth.async_get_access_token(tokens['access_token']) is not None
+    assert (
+        await hass.auth.async_validate_access_token(tokens['access_token'])
+        is not None
+    )
 
     # Test using access token to hit API.
     resp = await client.get('/api/')
@@ -98,7 +104,9 @@ async def test_ws_current_user(hass, hass_ws_client, hass_access_token):
         }
     })
 
-    user = hass_access_token.refresh_token.user
+    refresh_token = await hass.auth.async_validate_access_token(
+        hass_access_token)
+    user = refresh_token.user
     credential = Credentials(auth_provider_type='homeassistant',
                              auth_provider_id=None,
                              data={}, id='test-id')
@@ -169,7 +177,10 @@ async def test_refresh_token_system_generated(hass, aiohttp_client):
 
     assert resp.status == 200
     tokens = await resp.json()
-    assert hass.auth.async_get_access_token(tokens['access_token']) is not None
+    assert (
+        await hass.auth.async_validate_access_token(tokens['access_token'])
+        is not None
+    )
 
 
 async def test_refresh_token_different_client_id(hass, aiohttp_client):
@@ -208,4 +219,7 @@ async def test_refresh_token_different_client_id(hass, aiohttp_client):
 
     assert resp.status == 200
     tokens = await resp.json()
-    assert hass.auth.async_get_access_token(tokens['access_token']) is not None
+    assert (
+        await hass.auth.async_validate_access_token(tokens['access_token'])
+        is not None
+    )
diff --git a/tests/components/auth/test_init_link_user.py b/tests/components/auth/test_init_link_user.py
index 13515db87fa..e209e0ee856 100644
--- a/tests/components/auth/test_init_link_user.py
+++ b/tests/components/auth/test_init_link_user.py
@@ -52,7 +52,7 @@ async def async_get_code(hass, aiohttp_client):
         'user': user,
         'code': step['result'],
         'client': client,
-        'access_token': access_token.token,
+        'access_token': access_token,
     }
 
 
diff --git a/tests/components/config/test_auth.py b/tests/components/config/test_auth.py
index fe8f351955f..cd04eedf08e 100644
--- a/tests/components/config/test_auth.py
+++ b/tests/components/config/test_auth.py
@@ -122,11 +122,13 @@ async def test_delete_unable_self_account(hass, hass_ws_client,
                                           hass_access_token):
     """Test we cannot delete our own account."""
     client = await hass_ws_client(hass, hass_access_token)
+    refresh_token = await hass.auth.async_validate_access_token(
+        hass_access_token)
 
     await client.send_json({
         'id': 5,
         'type': auth_config.WS_TYPE_DELETE,
-        'user_id': hass_access_token.refresh_token.user.id,
+        'user_id': refresh_token.user.id,
     })
 
     result = await client.receive_json()
@@ -137,7 +139,9 @@ async def test_delete_unable_self_account(hass, hass_ws_client,
 async def test_delete_unknown_user(hass, hass_ws_client, hass_access_token):
     """Test we cannot delete an unknown user."""
     client = await hass_ws_client(hass, hass_access_token)
-    hass_access_token.refresh_token.user.is_owner = True
+    refresh_token = await hass.auth.async_validate_access_token(
+        hass_access_token)
+    refresh_token.user.is_owner = True
 
     await client.send_json({
         'id': 5,
@@ -153,7 +157,9 @@ async def test_delete_unknown_user(hass, hass_ws_client, hass_access_token):
 async def test_delete(hass, hass_ws_client, hass_access_token):
     """Test delete command works."""
     client = await hass_ws_client(hass, hass_access_token)
-    hass_access_token.refresh_token.user.is_owner = True
+    refresh_token = await hass.auth.async_validate_access_token(
+        hass_access_token)
+    refresh_token.user.is_owner = True
     test_user = MockUser(
         id='efg',
     ).add_to_hass(hass)
@@ -174,7 +180,9 @@ async def test_delete(hass, hass_ws_client, hass_access_token):
 async def test_create(hass, hass_ws_client, hass_access_token):
     """Test create command works."""
     client = await hass_ws_client(hass, hass_access_token)
-    hass_access_token.refresh_token.user.is_owner = True
+    refresh_token = await hass.auth.async_validate_access_token(
+        hass_access_token)
+    refresh_token.user.is_owner = True
 
     assert len(await hass.auth.async_get_users()) == 1
 
diff --git a/tests/components/config/test_auth_provider_homeassistant.py b/tests/components/config/test_auth_provider_homeassistant.py
index cd2cbc44539..a374083c2ab 100644
--- a/tests/components/config/test_auth_provider_homeassistant.py
+++ b/tests/components/config/test_auth_provider_homeassistant.py
@@ -9,7 +9,7 @@ from tests.common import MockUser, register_auth_provider
 
 
 @pytest.fixture(autouse=True)
-def setup_config(hass, aiohttp_client):
+def setup_config(hass):
     """Fixture that sets up the auth provider homeassistant module."""
     hass.loop.run_until_complete(register_auth_provider(hass, {
         'type': 'homeassistant'
@@ -22,7 +22,9 @@ async def test_create_auth_system_generated_user(hass, hass_access_token,
     """Test we can't add auth to system generated users."""
     system_user = MockUser(system_generated=True).add_to_hass(hass)
     client = await hass_ws_client(hass, hass_access_token)
-    hass_access_token.refresh_token.user.is_owner = True
+    refresh_token = await hass.auth.async_validate_access_token(
+        hass_access_token)
+    refresh_token.user.is_owner = True
 
     await client.send_json({
         'id': 5,
@@ -47,7 +49,9 @@ async def test_create_auth_unknown_user(hass_ws_client, hass,
                                         hass_access_token):
     """Test create pointing at unknown user."""
     client = await hass_ws_client(hass, hass_access_token)
-    hass_access_token.refresh_token.user.is_owner = True
+    refresh_token = await hass.auth.async_validate_access_token(
+        hass_access_token)
+    refresh_token.user.is_owner = True
 
     await client.send_json({
         'id': 5,
@@ -86,7 +90,9 @@ async def test_create_auth(hass, hass_ws_client, hass_access_token,
     """Test create auth command works."""
     client = await hass_ws_client(hass, hass_access_token)
     user = MockUser().add_to_hass(hass)
-    hass_access_token.refresh_token.user.is_owner = True
+    refresh_token = await hass.auth.async_validate_access_token(
+        hass_access_token)
+    refresh_token.user.is_owner = True
 
     assert len(user.credentials) == 0
 
@@ -117,7 +123,9 @@ async def test_create_auth_duplicate_username(hass, hass_ws_client,
     """Test we can't create auth with a duplicate username."""
     client = await hass_ws_client(hass, hass_access_token)
     user = MockUser().add_to_hass(hass)
-    hass_access_token.refresh_token.user.is_owner = True
+    refresh_token = await hass.auth.async_validate_access_token(
+        hass_access_token)
+    refresh_token.user.is_owner = True
 
     hass_storage[prov_ha.STORAGE_KEY] = {
         'version': 1,
@@ -145,7 +153,9 @@ async def test_delete_removes_just_auth(hass_ws_client, hass, hass_storage,
                                         hass_access_token):
     """Test deleting an auth without being connected to a user."""
     client = await hass_ws_client(hass, hass_access_token)
-    hass_access_token.refresh_token.user.is_owner = True
+    refresh_token = await hass.auth.async_validate_access_token(
+        hass_access_token)
+    refresh_token.user.is_owner = True
 
     hass_storage[prov_ha.STORAGE_KEY] = {
         'version': 1,
@@ -171,7 +181,9 @@ async def test_delete_removes_credential(hass, hass_ws_client,
                                          hass_access_token, hass_storage):
     """Test deleting auth that is connected to a user."""
     client = await hass_ws_client(hass, hass_access_token)
-    hass_access_token.refresh_token.user.is_owner = True
+    refresh_token = await hass.auth.async_validate_access_token(
+        hass_access_token)
+    refresh_token.user.is_owner = True
 
     user = MockUser().add_to_hass(hass)
     user.credentials.append(
@@ -216,7 +228,9 @@ async def test_delete_requires_owner(hass, hass_ws_client, hass_access_token):
 async def test_delete_unknown_auth(hass, hass_ws_client, hass_access_token):
     """Test trying to delete an unknown auth username."""
     client = await hass_ws_client(hass, hass_access_token)
-    hass_access_token.refresh_token.user.is_owner = True
+    refresh_token = await hass.auth.async_validate_access_token(
+        hass_access_token)
+    refresh_token.user.is_owner = True
 
     await client.send_json({
         'id': 5,
@@ -240,7 +254,9 @@ async def test_change_password(hass, hass_ws_client, hass_access_token):
         'username': 'test-user'
     })
 
-    user = hass_access_token.refresh_token.user
+    refresh_token = await hass.auth.async_validate_access_token(
+        hass_access_token)
+    user = refresh_token.user
     await hass.auth.async_link_user(user, credentials)
 
     client = await hass_ws_client(hass, hass_access_token)
@@ -268,7 +284,9 @@ async def test_change_password_wrong_pw(hass, hass_ws_client,
         'username': 'test-user'
     })
 
-    user = hass_access_token.refresh_token.user
+    refresh_token = await hass.auth.async_validate_access_token(
+        hass_access_token)
+    user = refresh_token.user
     await hass.auth.async_link_user(user, credentials)
 
     client = await hass_ws_client(hass, hass_access_token)
diff --git a/tests/components/conftest.py b/tests/components/conftest.py
index 5f6a17a4101..bb9b643296e 100644
--- a/tests/components/conftest.py
+++ b/tests/components/conftest.py
@@ -28,7 +28,7 @@ def hass_ws_client(aiohttp_client):
 
         await websocket.send_json({
             'type': websocket_api.TYPE_AUTH,
-            'access_token': access_token.token
+            'access_token': access_token
         })
 
         auth_ok = await websocket.receive_json()
diff --git a/tests/components/hassio/test_init.py b/tests/components/hassio/test_init.py
index b1975669731..4fd59dd3f7a 100644
--- a/tests/components/hassio/test_init.py
+++ b/tests/components/hassio/test_init.py
@@ -106,7 +106,11 @@ async def test_setup_api_push_api_data_default(hass, aioclient_mock,
     )
     assert hassio_user is not None
     assert hassio_user.system_generated
-    assert refresh_token in hassio_user.refresh_tokens
+    for token in hassio_user.refresh_tokens.values():
+        if token.token == refresh_token:
+            break
+    else:
+        assert False, 'refresh token not found'
 
 
 async def test_setup_api_push_api_data_no_auth(hass, aioclient_mock,
diff --git a/tests/components/http/test_auth.py b/tests/components/http/test_auth.py
index 31cba79a6c8..8e7a62e2e9f 100644
--- a/tests/components/http/test_auth.py
+++ b/tests/components/http/test_auth.py
@@ -156,9 +156,9 @@ async def test_access_with_trusted_ip(app2, aiohttp_client):
 
 
 async def test_auth_active_access_with_access_token_in_header(
-        app, aiohttp_client, hass_access_token):
+        hass, app, aiohttp_client, hass_access_token):
     """Test access with access token in header."""
-    token = hass_access_token.token
+    token = hass_access_token
     setup_auth(app, [], True, api_password=None)
     client = await aiohttp_client(app)
 
@@ -182,7 +182,9 @@ async def test_auth_active_access_with_access_token_in_header(
         '/', headers={'Authorization': 'BEARER {}'.format(token)})
     assert req.status == 401
 
-    hass_access_token.refresh_token.user.is_active = False
+    refresh_token = await hass.auth.async_validate_access_token(
+        hass_access_token)
+    refresh_token.user.is_active = False
     req = await client.get(
         '/', headers={'Authorization': 'Bearer {}'.format(token)})
     assert req.status == 401
diff --git a/tests/components/test_api.py b/tests/components/test_api.py
index 09dc27e97c1..2be1168b86a 100644
--- a/tests/components/test_api.py
+++ b/tests/components/test_api.py
@@ -448,13 +448,15 @@ async def test_api_fire_event_context(hass, mock_api_client,
     await mock_api_client.post(
         const.URL_API_EVENTS_EVENT.format("test.event"),
         headers={
-            'authorization': 'Bearer {}'.format(hass_access_token.token)
+            'authorization': 'Bearer {}'.format(hass_access_token)
         })
     await hass.async_block_till_done()
 
+    refresh_token = await hass.auth.async_validate_access_token(
+        hass_access_token)
+
     assert len(test_value) == 1
-    assert test_value[0].context.user_id == \
-        hass_access_token.refresh_token.user.id
+    assert test_value[0].context.user_id == refresh_token.user.id
 
 
 async def test_api_call_service_context(hass, mock_api_client,
@@ -465,12 +467,15 @@ async def test_api_call_service_context(hass, mock_api_client,
     await mock_api_client.post(
         '/api/services/test_domain/test_service',
         headers={
-            'authorization': 'Bearer {}'.format(hass_access_token.token)
+            'authorization': 'Bearer {}'.format(hass_access_token)
         })
     await hass.async_block_till_done()
 
+    refresh_token = await hass.auth.async_validate_access_token(
+        hass_access_token)
+
     assert len(calls) == 1
-    assert calls[0].context.user_id == hass_access_token.refresh_token.user.id
+    assert calls[0].context.user_id == refresh_token.user.id
 
 
 async def test_api_set_state_context(hass, mock_api_client, hass_access_token):
@@ -481,8 +486,11 @@ async def test_api_set_state_context(hass, mock_api_client, hass_access_token):
             'state': 'on'
         },
         headers={
-            'authorization': 'Bearer {}'.format(hass_access_token.token)
+            'authorization': 'Bearer {}'.format(hass_access_token)
         })
 
+    refresh_token = await hass.auth.async_validate_access_token(
+        hass_access_token)
+
     state = hass.states.get('light.kitchen')
-    assert state.context.user_id == hass_access_token.refresh_token.user.id
+    assert state.context.user_id == refresh_token.user.id
diff --git a/tests/components/test_websocket_api.py b/tests/components/test_websocket_api.py
index 1fac1af9f64..199a9d804f8 100644
--- a/tests/components/test_websocket_api.py
+++ b/tests/components/test_websocket_api.py
@@ -334,7 +334,7 @@ async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token):
 
             await ws.send_json({
                 'type': wapi.TYPE_AUTH,
-                'access_token': hass_access_token.token
+                'access_token': hass_access_token
             })
 
             auth_msg = await ws.receive_json()
@@ -344,7 +344,9 @@ async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token):
 async def test_auth_active_user_inactive(hass, aiohttp_client,
                                          hass_access_token):
     """Test authenticating with a token."""
-    hass_access_token.refresh_token.user.is_active = False
+    refresh_token = await hass.auth.async_validate_access_token(
+        hass_access_token)
+    refresh_token.user.is_active = False
     assert await async_setup_component(hass, 'websocket_api', {
         'http': {
             'api_password': API_PASSWORD
@@ -361,7 +363,7 @@ async def test_auth_active_user_inactive(hass, aiohttp_client,
 
             await ws.send_json({
                 'type': wapi.TYPE_AUTH,
-                'access_token': hass_access_token.token
+                'access_token': hass_access_token
             })
 
             auth_msg = await ws.receive_json()
@@ -465,7 +467,7 @@ async def test_call_service_context_with_user(hass, aiohttp_client,
 
             await ws.send_json({
                 'type': wapi.TYPE_AUTH,
-                'access_token': hass_access_token.token
+                'access_token': hass_access_token
             })
 
             auth_msg = await ws.receive_json()
@@ -484,12 +486,15 @@ async def test_call_service_context_with_user(hass, aiohttp_client,
         msg = await ws.receive_json()
         assert msg['success']
 
+        refresh_token = await hass.auth.async_validate_access_token(
+            hass_access_token)
+
         assert len(calls) == 1
         call = calls[0]
         assert call.domain == 'domain_test'
         assert call.service == 'test_service'
         assert call.data == {'hello': 'world'}
-        assert call.context.user_id == hass_access_token.refresh_token.user.id
+        assert call.context.user_id == refresh_token.user.id
 
 
 async def test_call_service_context_no_user(hass, aiohttp_client):
-- 
GitLab