Skip to content
Snippets Groups Projects
Commit ff78a5b0 authored by Jason Hu's avatar Jason Hu Committed by Paulus Schoutsen
Browse files

Track refresh token last usage information (#16408)

* Extend refresh_token to support last_used_at and last_used_by

* Address code review comment

* Remove unused code

* Add it to websocket response

* Fix typing
parent 453cbb7c
No related branches found
No related tags found
No related merge requests found
......@@ -309,8 +309,11 @@ class AuthManager:
@callback
def async_create_access_token(self,
refresh_token: models.RefreshToken) -> str:
refresh_token: models.RefreshToken,
remote_ip: Optional[str] = None) -> str:
"""Create a new access token."""
self._store.async_log_refresh_token_usage(refresh_token, remote_ip)
# pylint: disable=no-self-use
now = dt_util.utcnow()
return jwt.encode({
......
......@@ -195,6 +195,15 @@ class AuthStore:
return found
@callback
def async_log_refresh_token_usage(
self, refresh_token: models.RefreshToken,
remote_ip: Optional[str] = None) -> None:
"""Update refresh token last used information."""
refresh_token.last_used_at = dt_util.utcnow()
refresh_token.last_used_ip = remote_ip
self._async_schedule_save()
async def _async_load(self) -> None:
"""Load the users."""
data = await self._store.async_load()
......@@ -233,12 +242,21 @@ class AuthStore:
'Ignoring refresh token %(id)s with invalid created_at '
'%(created_at)s for user_id %(user_id)s', rt_dict)
continue
token_type = rt_dict.get('token_type')
if token_type is None:
if rt_dict['client_id'] is None:
token_type = models.TOKEN_TYPE_SYSTEM
else:
token_type = models.TOKEN_TYPE_NORMAL
# old refresh_token don't have last_used_at (pre-0.78)
last_used_at_str = rt_dict.get('last_used_at')
if last_used_at_str:
last_used_at = dt_util.parse_datetime(last_used_at_str)
else:
last_used_at = None
token = models.RefreshToken(
id=rt_dict['id'],
user=users[rt_dict['user_id']],
......@@ -251,7 +269,9 @@ class AuthStore:
access_token_expiration=timedelta(
seconds=rt_dict['access_token_expiration']),
token=rt_dict['token'],
jwt_key=rt_dict['jwt_key']
jwt_key=rt_dict['jwt_key'],
last_used_at=last_used_at,
last_used_ip=rt_dict.get('last_used_ip'),
)
users[rt_dict['user_id']].refresh_tokens[token.id] = token
......@@ -306,6 +326,10 @@ class AuthStore:
refresh_token.access_token_expiration.total_seconds(),
'token': refresh_token.token,
'jwt_key': refresh_token.jwt_key,
'last_used_at':
refresh_token.last_used_at.isoformat()
if refresh_token.last_used_at else None,
'last_used_ip': refresh_token.last_used_ip,
}
for user in self._users.values()
for refresh_token in user.refresh_tokens.values()
......
......@@ -55,13 +55,16 @@ class RefreshToken:
jwt_key = attr.ib(type=str,
default=attr.Factory(lambda: generate_secret(64)))
last_used_at = attr.ib(type=Optional[datetime], default=None)
last_used_ip = attr.ib(type=Optional[str], default=None)
@attr.s(slots=True)
class Credentials:
"""Credentials for a user on an auth provider."""
auth_provider_type = attr.ib(type=str)
auth_provider_id = attr.ib(type=str) # type: Optional[str]
auth_provider_id = attr.ib(type=Optional[str])
# Allow the auth provider to store data to represent their auth.
data = attr.ib(type=dict)
......
......@@ -129,6 +129,7 @@ import voluptuous as vol
from homeassistant.auth.models import User, Credentials, \
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
from homeassistant.components import websocket_api
from homeassistant.components.http import KEY_REAL_IP
from homeassistant.components.http.ban import log_invalid_auth
from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.components.http.view import HomeAssistantView
......@@ -236,10 +237,12 @@ class TokenView(HomeAssistantView):
return await self._async_handle_revoke_token(hass, data)
if grant_type == 'authorization_code':
return await self._async_handle_auth_code(hass, data)
return await self._async_handle_auth_code(
hass, data, str(request[KEY_REAL_IP]))
if grant_type == 'refresh_token':
return await self._async_handle_refresh_token(hass, data)
return await self._async_handle_refresh_token(
hass, data, str(request[KEY_REAL_IP]))
return self.json({
'error': 'unsupported_grant_type',
......@@ -264,7 +267,7 @@ class TokenView(HomeAssistantView):
await hass.auth.async_remove_refresh_token(refresh_token)
return web.Response(status=200)
async def _async_handle_auth_code(self, hass, data):
async def _async_handle_auth_code(self, hass, data, remote_addr):
"""Handle authorization code request."""
client_id = data.get('client_id')
if client_id is None or not indieauth.verify_client_id(client_id):
......@@ -300,7 +303,8 @@ class TokenView(HomeAssistantView):
refresh_token = await hass.auth.async_create_refresh_token(user,
client_id)
access_token = hass.auth.async_create_access_token(refresh_token)
access_token = hass.auth.async_create_access_token(
refresh_token, remote_addr)
return self.json({
'access_token': access_token,
......@@ -310,7 +314,7 @@ class TokenView(HomeAssistantView):
int(refresh_token.access_token_expiration.total_seconds()),
})
async def _async_handle_refresh_token(self, hass, data):
async def _async_handle_refresh_token(self, hass, data, remote_addr):
"""Handle authorization code request."""
client_id = data.get('client_id')
if client_id is not None and not indieauth.verify_client_id(client_id):
......@@ -338,7 +342,8 @@ class TokenView(HomeAssistantView):
'error': 'invalid_request',
}, status_code=400)
access_token = hass.auth.async_create_access_token(refresh_token)
access_token = hass.auth.async_create_access_token(
refresh_token, remote_addr)
return self.json({
'access_token': access_token,
......@@ -484,6 +489,8 @@ def websocket_refresh_tokens(
'type': refresh.token_type,
'created_at': refresh.created_at,
'is_current': refresh.id == current_id,
'last_used_at': refresh.last_used_at,
'last_used_ip': refresh.last_used_ip,
} for refresh in connection.user.refresh_tokens.values()]))
......
......@@ -278,7 +278,11 @@ async def test_saving_loading(hass, hass_storage):
})
user = step['result']
await manager.async_activate_user(user)
await manager.async_create_refresh_token(user, CLIENT_ID)
# the first refresh token will be used to create access token
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
manager.async_create_access_token(refresh_token, '192.168.0.1')
# the second refresh token will not be used
await manager.async_create_refresh_token(user, 'dummy-client')
await flush_store(manager._store._store)
......@@ -286,6 +290,18 @@ async def test_saving_loading(hass, hass_storage):
users = await store2.async_get_users()
assert len(users) == 1
assert users[0] == user
assert len(users[0].refresh_tokens) == 2
for r_token in users[0].refresh_tokens.values():
if r_token.client_id == CLIENT_ID:
# verify the first refresh token
assert r_token.last_used_at is not None
assert r_token.last_used_ip == '192.168.0.1'
elif r_token.client_id == 'dummy-client':
# verify the second refresh token
assert r_token.last_used_at is None
assert r_token.last_used_ip is None
else:
assert False, 'Unknown client_id: %s' % r_token.client_id
async def test_cannot_retrieve_expired_access_token(hass):
......
......@@ -321,6 +321,8 @@ async def test_ws_refresh_tokens(hass, hass_ws_client, hass_access_token):
assert token['client_icon'] == refresh_token.client_icon
assert token['created_at'] == refresh_token.created_at.isoformat()
assert token['is_current'] is True
assert token['last_used_at'] == refresh_token.last_used_at.isoformat()
assert token['last_used_ip'] == refresh_token.last_used_ip
async def test_ws_delete_refresh_token(hass, hass_ws_client,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment