From ff78a5b04b0831a5075b49508effc922392ca2c5 Mon Sep 17 00:00:00 2001
From: Jason Hu <awarecan@users.noreply.github.com>
Date: Wed, 12 Sep 2018 04:24:16 -0700
Subject: [PATCH] Track refresh token last usage information (#16408)

* Extend refresh_token to support last_used_at and last_used_by

* Address code review comment

* Remove unused code

* Add it to websocket response

* Fix typing
---
 homeassistant/auth/__init__.py            |  5 ++++-
 homeassistant/auth/auth_store.py          | 26 ++++++++++++++++++++++-
 homeassistant/auth/models.py              |  5 ++++-
 homeassistant/components/auth/__init__.py | 19 +++++++++++------
 tests/auth/test_init.py                   | 18 +++++++++++++++-
 tests/components/auth/test_init.py        |  2 ++
 6 files changed, 65 insertions(+), 10 deletions(-)

diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py
index b0cebb5fd6c..c6f978640f6 100644
--- a/homeassistant/auth/__init__.py
+++ b/homeassistant/auth/__init__.py
@@ -309,8 +309,11 @@ class AuthManager:
 
     @callback
     def async_create_access_token(self,
-                                  refresh_token: models.RefreshToken) -> str:
+                                  refresh_token: models.RefreshToken,
+                                  remote_ip: Optional[str] = None) -> str:
         """Create a new access token."""
+        self._store.async_log_refresh_token_usage(refresh_token, remote_ip)
+
         # pylint: disable=no-self-use
         now = dt_util.utcnow()
         return jwt.encode({
diff --git a/homeassistant/auth/auth_store.py b/homeassistant/auth/auth_store.py
index 8e8d03253e5..fb4700c806f 100644
--- a/homeassistant/auth/auth_store.py
+++ b/homeassistant/auth/auth_store.py
@@ -195,6 +195,15 @@ class AuthStore:
 
         return found
 
+    @callback
+    def async_log_refresh_token_usage(
+            self, refresh_token: models.RefreshToken,
+            remote_ip: Optional[str] = None) -> None:
+        """Update refresh token last used information."""
+        refresh_token.last_used_at = dt_util.utcnow()
+        refresh_token.last_used_ip = remote_ip
+        self._async_schedule_save()
+
     async def _async_load(self) -> None:
         """Load the users."""
         data = await self._store.async_load()
@@ -233,12 +242,21 @@ class AuthStore:
                     'Ignoring refresh token %(id)s with invalid created_at '
                     '%(created_at)s for user_id %(user_id)s', rt_dict)
                 continue
+
             token_type = rt_dict.get('token_type')
             if token_type is None:
                 if rt_dict['client_id'] is None:
                     token_type = models.TOKEN_TYPE_SYSTEM
                 else:
                     token_type = models.TOKEN_TYPE_NORMAL
+
+            # old refresh_token don't have last_used_at (pre-0.78)
+            last_used_at_str = rt_dict.get('last_used_at')
+            if last_used_at_str:
+                last_used_at = dt_util.parse_datetime(last_used_at_str)
+            else:
+                last_used_at = None
+
             token = models.RefreshToken(
                 id=rt_dict['id'],
                 user=users[rt_dict['user_id']],
@@ -251,7 +269,9 @@ class AuthStore:
                 access_token_expiration=timedelta(
                     seconds=rt_dict['access_token_expiration']),
                 token=rt_dict['token'],
-                jwt_key=rt_dict['jwt_key']
+                jwt_key=rt_dict['jwt_key'],
+                last_used_at=last_used_at,
+                last_used_ip=rt_dict.get('last_used_ip'),
             )
             users[rt_dict['user_id']].refresh_tokens[token.id] = token
 
@@ -306,6 +326,10 @@ class AuthStore:
                     refresh_token.access_token_expiration.total_seconds(),
                 'token': refresh_token.token,
                 'jwt_key': refresh_token.jwt_key,
+                'last_used_at':
+                    refresh_token.last_used_at.isoformat()
+                    if refresh_token.last_used_at else None,
+                'last_used_ip': refresh_token.last_used_ip,
             }
             for user in self._users.values()
             for refresh_token in user.refresh_tokens.values()
diff --git a/homeassistant/auth/models.py b/homeassistant/auth/models.py
index c5273d7fa1d..b0f4024c3ab 100644
--- a/homeassistant/auth/models.py
+++ b/homeassistant/auth/models.py
@@ -55,13 +55,16 @@ class RefreshToken:
     jwt_key = attr.ib(type=str,
                       default=attr.Factory(lambda: generate_secret(64)))
 
+    last_used_at = attr.ib(type=Optional[datetime], default=None)
+    last_used_ip = attr.ib(type=Optional[str], default=None)
+
 
 @attr.s(slots=True)
 class Credentials:
     """Credentials for a user on an auth provider."""
 
     auth_provider_type = attr.ib(type=str)
-    auth_provider_id = attr.ib(type=str)  # type: Optional[str]
+    auth_provider_id = attr.ib(type=Optional[str])
 
     # Allow the auth provider to store data to represent their auth.
     data = attr.ib(type=dict)
diff --git a/homeassistant/components/auth/__init__.py b/homeassistant/components/auth/__init__.py
index 01cfe4724bf..bee72d8e4fc 100644
--- a/homeassistant/components/auth/__init__.py
+++ b/homeassistant/components/auth/__init__.py
@@ -129,6 +129,7 @@ import voluptuous as vol
 from homeassistant.auth.models import User, Credentials, \
     TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
 from homeassistant.components import websocket_api
+from homeassistant.components.http import KEY_REAL_IP
 from homeassistant.components.http.ban import log_invalid_auth
 from homeassistant.components.http.data_validator import RequestDataValidator
 from homeassistant.components.http.view import HomeAssistantView
@@ -236,10 +237,12 @@ class TokenView(HomeAssistantView):
             return await self._async_handle_revoke_token(hass, data)
 
         if grant_type == 'authorization_code':
-            return await self._async_handle_auth_code(hass, data)
+            return await self._async_handle_auth_code(
+                hass, data, str(request[KEY_REAL_IP]))
 
         if grant_type == 'refresh_token':
-            return await self._async_handle_refresh_token(hass, data)
+            return await self._async_handle_refresh_token(
+                hass, data, str(request[KEY_REAL_IP]))
 
         return self.json({
             'error': 'unsupported_grant_type',
@@ -264,7 +267,7 @@ class TokenView(HomeAssistantView):
         await hass.auth.async_remove_refresh_token(refresh_token)
         return web.Response(status=200)
 
-    async def _async_handle_auth_code(self, hass, data):
+    async def _async_handle_auth_code(self, hass, data, remote_addr):
         """Handle authorization code request."""
         client_id = data.get('client_id')
         if client_id is None or not indieauth.verify_client_id(client_id):
@@ -300,7 +303,8 @@ class TokenView(HomeAssistantView):
 
         refresh_token = await hass.auth.async_create_refresh_token(user,
                                                                    client_id)
-        access_token = hass.auth.async_create_access_token(refresh_token)
+        access_token = hass.auth.async_create_access_token(
+            refresh_token, remote_addr)
 
         return self.json({
             'access_token': access_token,
@@ -310,7 +314,7 @@ class TokenView(HomeAssistantView):
                 int(refresh_token.access_token_expiration.total_seconds()),
         })
 
-    async def _async_handle_refresh_token(self, hass, data):
+    async def _async_handle_refresh_token(self, hass, data, remote_addr):
         """Handle authorization code request."""
         client_id = data.get('client_id')
         if client_id is not None and not indieauth.verify_client_id(client_id):
@@ -338,7 +342,8 @@ class TokenView(HomeAssistantView):
                 'error': 'invalid_request',
             }, status_code=400)
 
-        access_token = hass.auth.async_create_access_token(refresh_token)
+        access_token = hass.auth.async_create_access_token(
+            refresh_token, remote_addr)
 
         return self.json({
             'access_token': access_token,
@@ -484,6 +489,8 @@ def websocket_refresh_tokens(
         'type': refresh.token_type,
         'created_at': refresh.created_at,
         'is_current': refresh.id == current_id,
+        'last_used_at': refresh.last_used_at,
+        'last_used_ip': refresh.last_used_ip,
     } for refresh in connection.user.refresh_tokens.values()]))
 
 
diff --git a/tests/auth/test_init.py b/tests/auth/test_init.py
index 765199b256c..8325bd2551a 100644
--- a/tests/auth/test_init.py
+++ b/tests/auth/test_init.py
@@ -278,7 +278,11 @@ async def test_saving_loading(hass, hass_storage):
     })
     user = step['result']
     await manager.async_activate_user(user)
-    await manager.async_create_refresh_token(user, CLIENT_ID)
+    # the first refresh token will be used to create access token
+    refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
+    manager.async_create_access_token(refresh_token, '192.168.0.1')
+    # the second refresh token will not be used
+    await manager.async_create_refresh_token(user, 'dummy-client')
 
     await flush_store(manager._store._store)
 
@@ -286,6 +290,18 @@ async def test_saving_loading(hass, hass_storage):
     users = await store2.async_get_users()
     assert len(users) == 1
     assert users[0] == user
+    assert len(users[0].refresh_tokens) == 2
+    for r_token in users[0].refresh_tokens.values():
+        if r_token.client_id == CLIENT_ID:
+            # verify the first refresh token
+            assert r_token.last_used_at is not None
+            assert r_token.last_used_ip == '192.168.0.1'
+        elif r_token.client_id == 'dummy-client':
+            # verify the second refresh token
+            assert r_token.last_used_at is None
+            assert r_token.last_used_ip is None
+        else:
+            assert False, 'Unknown client_id: %s' % r_token.client_id
 
 
 async def test_cannot_retrieve_expired_access_token(hass):
diff --git a/tests/components/auth/test_init.py b/tests/components/auth/test_init.py
index ad2aa01737b..a3974553661 100644
--- a/tests/components/auth/test_init.py
+++ b/tests/components/auth/test_init.py
@@ -321,6 +321,8 @@ async def test_ws_refresh_tokens(hass, hass_ws_client, hass_access_token):
     assert token['client_icon'] == refresh_token.client_icon
     assert token['created_at'] == refresh_token.created_at.isoformat()
     assert token['is_current'] is True
+    assert token['last_used_at'] == refresh_token.last_used_at.isoformat()
+    assert token['last_used_ip'] == refresh_token.last_used_ip
 
 
 async def test_ws_delete_refresh_token(hass, hass_ws_client,
-- 
GitLab