Skip to content
Snippets Groups Projects
Commit a64a66dd authored by Jason Hu's avatar Jason Hu Committed by Paulus Schoutsen
Browse files

Only create front-end client_id once (#15214)

* Only create frontend client_id once

* Check user and client_id before create refresh token

* Lint

* Follow code review comment

* Minor clenaup

* Update doc string
parent dffe3676
Branches
Tags
No related merge requests found
"""Provide an authentication layer for Home Assistant.""" """Provide an authentication layer for Home Assistant."""
import asyncio import asyncio
import binascii import binascii
from collections import OrderedDict
from datetime import datetime, timedelta
import os
import importlib import importlib
import logging import logging
import os
import uuid import uuid
from collections import OrderedDict
from datetime import datetime, timedelta
import attr import attr
import voluptuous as vol import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
from homeassistant import data_entry_flow, requirements from homeassistant import data_entry_flow, requirements
from homeassistant.core import callback
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
from homeassistant.util.decorator import Registry from homeassistant.core import callback
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from homeassistant.util.decorator import Registry
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
...@@ -349,6 +348,16 @@ class AuthManager: ...@@ -349,6 +348,16 @@ class AuthManager:
return await self._store.async_create_client( return await self._store.async_create_client(
name, redirect_uris, no_secret) name, redirect_uris, no_secret)
async def async_get_or_create_client(self, name, *, redirect_uris=None,
no_secret=False):
"""Find a client, if not exists, create a new one."""
for client in await self._store.async_get_clients():
if client.name == name:
return client
return await self._store.async_create_client(
name, redirect_uris, no_secret)
async def async_get_client(self, client_id): async def async_get_client(self, client_id):
"""Get a client.""" """Get a client."""
return await self._store.async_get_client(client_id) return await self._store.async_get_client(client_id)
...@@ -392,29 +401,36 @@ class AuthStore: ...@@ -392,29 +401,36 @@ class AuthStore:
def __init__(self, hass): def __init__(self, hass):
"""Initialize the auth store.""" """Initialize the auth store."""
self.hass = hass self.hass = hass
self.users = None self._users = None
self.clients = None self._clients = None
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
async def credentials_for_provider(self, provider_type, provider_id): async def credentials_for_provider(self, provider_type, provider_id):
"""Return credentials for specific auth provider type and id.""" """Return credentials for specific auth provider type and id."""
if self.users is None: if self._users is None:
await self.async_load() await self.async_load()
return [ return [
credentials credentials
for user in self.users.values() for user in self._users.values()
for credentials in user.credentials for credentials in user.credentials
if (credentials.auth_provider_type == provider_type and if (credentials.auth_provider_type == provider_type and
credentials.auth_provider_id == provider_id) credentials.auth_provider_id == provider_id)
] ]
async def async_get_users(self):
"""Retrieve all users."""
if self._users is None:
await self.async_load()
return list(self._users.values())
async def async_get_user(self, user_id): async def async_get_user(self, user_id):
"""Retrieve a user.""" """Retrieve a user."""
if self.users is None: if self._users is None:
await self.async_load() await self.async_load()
return self.users.get(user_id) return self._users.get(user_id)
async def async_get_or_create_user(self, credentials, auth_provider): async def async_get_or_create_user(self, credentials, auth_provider):
"""Get or create a new user for given credentials. """Get or create a new user for given credentials.
...@@ -422,7 +438,7 @@ class AuthStore: ...@@ -422,7 +438,7 @@ class AuthStore:
If link_user is passed in, the credentials will be linked to the passed If link_user is passed in, the credentials will be linked to the passed
in user if the credentials are new. in user if the credentials are new.
""" """
if self.users is None: if self._users is None:
await self.async_load() await self.async_load()
# New credentials, store in user # New credentials, store in user
...@@ -430,7 +446,7 @@ class AuthStore: ...@@ -430,7 +446,7 @@ class AuthStore:
info = await auth_provider.async_user_meta_for_credentials( info = await auth_provider.async_user_meta_for_credentials(
credentials) credentials)
# Make owner and activate user if it's the first user. # Make owner and activate user if it's the first user.
if self.users: if self._users:
is_owner = False is_owner = False
is_active = False is_active = False
else: else:
...@@ -442,11 +458,11 @@ class AuthStore: ...@@ -442,11 +458,11 @@ class AuthStore:
is_active=is_active, is_active=is_active,
name=info.get('name'), name=info.get('name'),
) )
self.users[new_user.id] = new_user self._users[new_user.id] = new_user
await self.async_link_user(new_user, credentials) await self.async_link_user(new_user, credentials)
return new_user return new_user
for user in self.users.values(): for user in self._users.values():
for creds in user.credentials: for creds in user.credentials:
if (creds.auth_provider_type == credentials.auth_provider_type if (creds.auth_provider_type == credentials.auth_provider_type
and creds.auth_provider_id == and creds.auth_provider_id ==
...@@ -463,11 +479,19 @@ class AuthStore: ...@@ -463,11 +479,19 @@ class AuthStore:
async def async_remove_user(self, user): async def async_remove_user(self, user):
"""Remove a user.""" """Remove a user."""
self.users.pop(user.id) self._users.pop(user.id)
await self.async_save() await self.async_save()
async def async_create_refresh_token(self, user, client_id): async def async_create_refresh_token(self, user, client_id):
"""Create a new token for a user.""" """Create a new token for a user."""
local_user = await self.async_get_user(user.id)
if local_user is None:
raise ValueError('Invalid user')
local_client = await self.async_get_client(client_id)
if local_client is None:
raise ValueError('Invalid client_id')
refresh_token = RefreshToken(user, client_id) refresh_token = RefreshToken(user, client_id)
user.refresh_tokens[refresh_token.token] = refresh_token user.refresh_tokens[refresh_token.token] = refresh_token
await self.async_save() await self.async_save()
...@@ -475,10 +499,10 @@ class AuthStore: ...@@ -475,10 +499,10 @@ class AuthStore:
async def async_get_refresh_token(self, token): async def async_get_refresh_token(self, token):
"""Get refresh token by token.""" """Get refresh token by token."""
if self.users is None: if self._users is None:
await self.async_load() await self.async_load()
for user in self.users.values(): for user in self._users.values():
refresh_token = user.refresh_tokens.get(token) refresh_token = user.refresh_tokens.get(token)
if refresh_token is not None: if refresh_token is not None:
return refresh_token return refresh_token
...@@ -487,7 +511,7 @@ class AuthStore: ...@@ -487,7 +511,7 @@ class AuthStore:
async def async_create_client(self, name, redirect_uris, no_secret): async def async_create_client(self, name, redirect_uris, no_secret):
"""Create a new client.""" """Create a new client."""
if self.clients is None: if self._clients is None:
await self.async_load() await self.async_load()
kwargs = { kwargs = {
...@@ -499,16 +523,23 @@ class AuthStore: ...@@ -499,16 +523,23 @@ class AuthStore:
kwargs['secret'] = None kwargs['secret'] = None
client = Client(**kwargs) client = Client(**kwargs)
self.clients[client.id] = client self._clients[client.id] = client
await self.async_save() await self.async_save()
return client return client
async def async_get_clients(self):
"""Return all clients."""
if self._clients is None:
await self.async_load()
return list(self._clients.values())
async def async_get_client(self, client_id): async def async_get_client(self, client_id):
"""Get a client.""" """Get a client."""
if self.clients is None: if self._clients is None:
await self.async_load() await self.async_load()
return self.clients.get(client_id) return self._clients.get(client_id)
async def async_load(self): async def async_load(self):
"""Load the users.""" """Load the users."""
...@@ -516,12 +547,12 @@ class AuthStore: ...@@ -516,12 +547,12 @@ class AuthStore:
# Make sure that we're not overriding data if 2 loads happened at the # Make sure that we're not overriding data if 2 loads happened at the
# same time # same time
if self.users is not None: if self._users is not None:
return return
if data is None: if data is None:
self.users = {} self._users = {}
self.clients = {} self._clients = {}
return return
users = { users = {
...@@ -565,8 +596,8 @@ class AuthStore: ...@@ -565,8 +596,8 @@ class AuthStore:
cl_dict['id']: Client(**cl_dict) for cl_dict in data['clients'] cl_dict['id']: Client(**cl_dict) for cl_dict in data['clients']
} }
self.users = users self._users = users
self.clients = clients self._clients = clients
async def async_save(self): async def async_save(self):
"""Save users.""" """Save users."""
...@@ -577,7 +608,7 @@ class AuthStore: ...@@ -577,7 +608,7 @@ class AuthStore:
'is_active': user.is_active, 'is_active': user.is_active,
'name': user.name, 'name': user.name,
} }
for user in self.users.values() for user in self._users.values()
] ]
credentials = [ credentials = [
...@@ -588,7 +619,7 @@ class AuthStore: ...@@ -588,7 +619,7 @@ class AuthStore:
'auth_provider_id': credential.auth_provider_id, 'auth_provider_id': credential.auth_provider_id,
'data': credential.data, 'data': credential.data,
} }
for user in self.users.values() for user in self._users.values()
for credential in user.credentials for credential in user.credentials
] ]
...@@ -602,7 +633,7 @@ class AuthStore: ...@@ -602,7 +633,7 @@ class AuthStore:
refresh_token.access_token_expiration.total_seconds(), refresh_token.access_token_expiration.total_seconds(),
'token': refresh_token.token, 'token': refresh_token.token,
} }
for user in self.users.values() for user in self._users.values()
for refresh_token in user.refresh_tokens.values() for refresh_token in user.refresh_tokens.values()
] ]
...@@ -613,7 +644,7 @@ class AuthStore: ...@@ -613,7 +644,7 @@ class AuthStore:
'created_at': access_token.created_at.isoformat(), 'created_at': access_token.created_at.isoformat(),
'token': access_token.token, 'token': access_token.token,
} }
for user in self.users.values() for user in self._users.values()
for refresh_token in user.refresh_tokens.values() for refresh_token in user.refresh_tokens.values()
for access_token in refresh_token.access_tokens for access_token in refresh_token.access_tokens
] ]
...@@ -625,7 +656,7 @@ class AuthStore: ...@@ -625,7 +656,7 @@ class AuthStore:
'secret': client.secret, 'secret': client.secret,
'redirect_uris': client.redirect_uris, 'redirect_uris': client.redirect_uris,
} }
for client in self.clients.values() for client in self._clients.values()
] ]
data = { data = {
......
...@@ -201,7 +201,7 @@ def add_manifest_json_key(key, val): ...@@ -201,7 +201,7 @@ def add_manifest_json_key(key, val):
async def async_setup(hass, config): async def async_setup(hass, config):
"""Set up the serving of the frontend.""" """Set up the serving of the frontend."""
if hass.auth.active: if hass.auth.active:
client = await hass.auth.async_create_client( client = await hass.auth.async_get_or_create_client(
'Home Assistant Frontend', 'Home Assistant Frontend',
redirect_uris=['/'], redirect_uris=['/'],
no_secret=True, no_secret=True,
......
...@@ -321,7 +321,7 @@ class MockUser(auth.User): ...@@ -321,7 +321,7 @@ class MockUser(auth.User):
def add_to_auth_manager(self, auth_mgr): def add_to_auth_manager(self, auth_mgr):
"""Test helper to add entry to hass.""" """Test helper to add entry to hass."""
ensure_auth_manager_loaded(auth_mgr) ensure_auth_manager_loaded(auth_mgr)
auth_mgr._store.users[self.id] = self auth_mgr._store._users[self.id] = self
return self return self
...@@ -329,10 +329,10 @@ class MockUser(auth.User): ...@@ -329,10 +329,10 @@ class MockUser(auth.User):
def ensure_auth_manager_loaded(auth_mgr): def ensure_auth_manager_loaded(auth_mgr):
"""Ensure an auth manager is considered loaded.""" """Ensure an auth manager is considered loaded."""
store = auth_mgr._store store = auth_mgr._store
if store.clients is None: if store._clients is None:
store.clients = {} store._clients = {}
if store.users is None: if store._users is None:
store.users = {} store._users = {}
class MockModule(object): class MockModule(object):
......
...@@ -34,7 +34,7 @@ async def async_setup_auth(hass, aiohttp_client, provider_configs=BASE_CONFIG, ...@@ -34,7 +34,7 @@ async def async_setup_auth(hass, aiohttp_client, provider_configs=BASE_CONFIG,
}) })
client = auth.Client('Test Client', CLIENT_ID, CLIENT_SECRET, client = auth.Client('Test Client', CLIENT_ID, CLIENT_SECRET,
redirect_uris=[CLIENT_REDIRECT_URI]) redirect_uris=[CLIENT_REDIRECT_URI])
hass.auth._store.clients[client.id] = client hass.auth._store._clients[client.id] = client
if setup_api: if setup_api:
await async_setup_component(hass, 'api', {}) await async_setup_component(hass, 'api', {})
return await aiohttp_client(hass.http.app) return await aiohttp_client(hass.http.app)
...@@ -191,12 +191,13 @@ async def test_saving_loading(hass, hass_storage): ...@@ -191,12 +191,13 @@ async def test_saving_loading(hass, hass_storage):
await flush_store(manager._store._store) await flush_store(manager._store._store)
store2 = auth.AuthStore(hass) store2 = auth.AuthStore(hass)
await store2.async_load() users = await store2.async_get_users()
assert len(store2.users) == 1 assert len(users) == 1
assert store2.users[user.id] == user assert users[0] == user
assert len(store2.clients) == 1 clients = await store2.async_get_clients()
assert store2.clients[client.id] == client assert len(clients) == 1
assert clients[0] == client
def test_access_token_expired(): def test_access_token_expired():
...@@ -224,15 +225,18 @@ def test_access_token_expired(): ...@@ -224,15 +225,18 @@ def test_access_token_expired():
async def test_cannot_retrieve_expired_access_token(hass): async def test_cannot_retrieve_expired_access_token(hass):
"""Test that we cannot retrieve expired access tokens.""" """Test that we cannot retrieve expired access tokens."""
manager = await auth.auth_manager_from_config(hass, []) manager = await auth.auth_manager_from_config(hass, [])
client = await manager.async_create_client('test')
user = MockUser( user = MockUser(
id='mock-user', id='mock-user',
is_owner=False, is_owner=False,
is_active=False, is_active=False,
name='Paulus', name='Paulus',
).add_to_auth_manager(manager) ).add_to_auth_manager(manager)
refresh_token = await manager.async_create_refresh_token(user, 'bla') refresh_token = await manager.async_create_refresh_token(user, client.id)
access_token = manager.async_create_access_token(refresh_token) assert refresh_token.user.id is user.id
assert refresh_token.client_id is client.id
access_token = manager.async_create_access_token(refresh_token)
assert manager.async_get_access_token(access_token.token) is access_token assert manager.async_get_access_token(access_token.token) is access_token
with patch('homeassistant.auth.dt_util.utcnow', with patch('homeassistant.auth.dt_util.utcnow',
...@@ -241,3 +245,38 @@ async def test_cannot_retrieve_expired_access_token(hass): ...@@ -241,3 +245,38 @@ async def test_cannot_retrieve_expired_access_token(hass):
# Even with unpatched time, it should have been removed from manager # Even with unpatched time, it should have been removed from manager
assert manager.async_get_access_token(access_token.token) is None assert manager.async_get_access_token(access_token.token) is None
async def test_get_or_create_client(hass):
"""Test that get_or_create_client works."""
manager = await auth.auth_manager_from_config(hass, [])
client1 = await manager.async_get_or_create_client(
'Test Client', redirect_uris=['https://test.com/1'])
assert client1.name is 'Test Client'
client2 = await manager.async_get_or_create_client(
'Test Client', redirect_uris=['https://test.com/1'])
assert client2.id is client1.id
async def test_cannot_create_refresh_token_with_invalide_client_id(hass):
"""Test that we cannot create refresh token with invalid client id."""
manager = await auth.auth_manager_from_config(hass, [])
user = MockUser(
id='mock-user',
is_owner=False,
is_active=False,
name='Paulus',
).add_to_auth_manager(manager)
with pytest.raises(ValueError):
await manager.async_create_refresh_token(user, 'bla')
async def test_cannot_create_refresh_token_with_invalide_user(hass):
"""Test that we cannot create refresh token with invalid client id."""
manager = await auth.auth_manager_from_config(hass, [])
client = await manager.async_create_client('test')
user = MockUser(id='invalid-user')
with pytest.raises(ValueError):
await manager.async_create_refresh_token(user, client.id)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment