diff --git a/homeassistant/auth.py b/homeassistant/auth.py index f56e00bf31e55311df805a16045f260c7bd08911..a4e8ee05943b4c96fd5ff33068bcb04dd6969458 100644 --- a/homeassistant/auth.py +++ b/homeassistant/auth.py @@ -1,23 +1,22 @@ """Provide an authentication layer for Home Assistant.""" import asyncio import binascii -from collections import OrderedDict -from datetime import datetime, timedelta -import os import importlib import logging +import os import uuid +from collections import OrderedDict +from datetime import datetime, timedelta import attr import voluptuous as vol from voluptuous.humanize import humanize_error from homeassistant import data_entry_flow, requirements -from homeassistant.core import callback from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID -from homeassistant.util.decorator import Registry +from homeassistant.core import callback from homeassistant.util import dt as dt_util - +from homeassistant.util.decorator import Registry _LOGGER = logging.getLogger(__name__) @@ -349,6 +348,16 @@ class AuthManager: return await self._store.async_create_client( name, redirect_uris, no_secret) + async def async_get_or_create_client(self, name, *, redirect_uris=None, + no_secret=False): + """Find a client, if not exists, create a new one.""" + for client in await self._store.async_get_clients(): + if client.name == name: + return client + + return await self._store.async_create_client( + name, redirect_uris, no_secret) + async def async_get_client(self, client_id): """Get a client.""" return await self._store.async_get_client(client_id) @@ -392,29 +401,36 @@ class AuthStore: def __init__(self, hass): """Initialize the auth store.""" self.hass = hass - self.users = None - self.clients = None + self._users = None + self._clients = None self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) async def credentials_for_provider(self, provider_type, provider_id): """Return credentials for specific auth provider type and id.""" - if self.users is None: + if self._users is None: await self.async_load() return [ credentials - for user in self.users.values() + for user in self._users.values() for credentials in user.credentials if (credentials.auth_provider_type == provider_type and credentials.auth_provider_id == provider_id) ] + async def async_get_users(self): + """Retrieve all users.""" + if self._users is None: + await self.async_load() + + return list(self._users.values()) + async def async_get_user(self, user_id): """Retrieve a user.""" - if self.users is None: + if self._users is None: await self.async_load() - return self.users.get(user_id) + return self._users.get(user_id) async def async_get_or_create_user(self, credentials, auth_provider): """Get or create a new user for given credentials. @@ -422,7 +438,7 @@ class AuthStore: If link_user is passed in, the credentials will be linked to the passed in user if the credentials are new. """ - if self.users is None: + if self._users is None: await self.async_load() # New credentials, store in user @@ -430,7 +446,7 @@ class AuthStore: info = await auth_provider.async_user_meta_for_credentials( credentials) # Make owner and activate user if it's the first user. - if self.users: + if self._users: is_owner = False is_active = False else: @@ -442,11 +458,11 @@ class AuthStore: is_active=is_active, name=info.get('name'), ) - self.users[new_user.id] = new_user + self._users[new_user.id] = new_user await self.async_link_user(new_user, credentials) return new_user - for user in self.users.values(): + for user in self._users.values(): for creds in user.credentials: if (creds.auth_provider_type == credentials.auth_provider_type and creds.auth_provider_id == @@ -463,11 +479,19 @@ class AuthStore: async def async_remove_user(self, user): """Remove a user.""" - self.users.pop(user.id) + self._users.pop(user.id) await self.async_save() async def async_create_refresh_token(self, user, client_id): """Create a new token for a user.""" + local_user = await self.async_get_user(user.id) + if local_user is None: + raise ValueError('Invalid user') + + local_client = await self.async_get_client(client_id) + if local_client is None: + raise ValueError('Invalid client_id') + refresh_token = RefreshToken(user, client_id) user.refresh_tokens[refresh_token.token] = refresh_token await self.async_save() @@ -475,10 +499,10 @@ class AuthStore: async def async_get_refresh_token(self, token): """Get refresh token by token.""" - if self.users is None: + if self._users is None: await self.async_load() - for user in self.users.values(): + for user in self._users.values(): refresh_token = user.refresh_tokens.get(token) if refresh_token is not None: return refresh_token @@ -487,7 +511,7 @@ class AuthStore: async def async_create_client(self, name, redirect_uris, no_secret): """Create a new client.""" - if self.clients is None: + if self._clients is None: await self.async_load() kwargs = { @@ -499,16 +523,23 @@ class AuthStore: kwargs['secret'] = None client = Client(**kwargs) - self.clients[client.id] = client + self._clients[client.id] = client await self.async_save() return client + async def async_get_clients(self): + """Return all clients.""" + if self._clients is None: + await self.async_load() + + return list(self._clients.values()) + async def async_get_client(self, client_id): """Get a client.""" - if self.clients is None: + if self._clients is None: await self.async_load() - return self.clients.get(client_id) + return self._clients.get(client_id) async def async_load(self): """Load the users.""" @@ -516,12 +547,12 @@ class AuthStore: # Make sure that we're not overriding data if 2 loads happened at the # same time - if self.users is not None: + if self._users is not None: return if data is None: - self.users = {} - self.clients = {} + self._users = {} + self._clients = {} return users = { @@ -565,8 +596,8 @@ class AuthStore: cl_dict['id']: Client(**cl_dict) for cl_dict in data['clients'] } - self.users = users - self.clients = clients + self._users = users + self._clients = clients async def async_save(self): """Save users.""" @@ -577,7 +608,7 @@ class AuthStore: 'is_active': user.is_active, 'name': user.name, } - for user in self.users.values() + for user in self._users.values() ] credentials = [ @@ -588,7 +619,7 @@ class AuthStore: 'auth_provider_id': credential.auth_provider_id, 'data': credential.data, } - for user in self.users.values() + for user in self._users.values() for credential in user.credentials ] @@ -602,7 +633,7 @@ class AuthStore: refresh_token.access_token_expiration.total_seconds(), 'token': refresh_token.token, } - for user in self.users.values() + for user in self._users.values() for refresh_token in user.refresh_tokens.values() ] @@ -613,7 +644,7 @@ class AuthStore: 'created_at': access_token.created_at.isoformat(), 'token': access_token.token, } - for user in self.users.values() + for user in self._users.values() for refresh_token in user.refresh_tokens.values() for access_token in refresh_token.access_tokens ] @@ -625,7 +656,7 @@ class AuthStore: 'secret': client.secret, 'redirect_uris': client.redirect_uris, } - for client in self.clients.values() + for client in self._clients.values() ] data = { diff --git a/homeassistant/components/frontend/__init__.py b/homeassistant/components/frontend/__init__.py index 7bad8ff727d93785fbcb261d5c806188f343d947..9a32626c66a0febd1a771938c579cefe3b7b35e7 100644 --- a/homeassistant/components/frontend/__init__.py +++ b/homeassistant/components/frontend/__init__.py @@ -201,7 +201,7 @@ def add_manifest_json_key(key, val): async def async_setup(hass, config): """Set up the serving of the frontend.""" if hass.auth.active: - client = await hass.auth.async_create_client( + client = await hass.auth.async_get_or_create_client( 'Home Assistant Frontend', redirect_uris=['/'], no_secret=True, diff --git a/tests/common.py b/tests/common.py index 1b8eabaa0db4bc50acd684bd6c66f7fa644294c7..3a51cd3e0598471591ec81d6fe2376bdc0cad64d 100644 --- a/tests/common.py +++ b/tests/common.py @@ -321,7 +321,7 @@ class MockUser(auth.User): def add_to_auth_manager(self, auth_mgr): """Test helper to add entry to hass.""" ensure_auth_manager_loaded(auth_mgr) - auth_mgr._store.users[self.id] = self + auth_mgr._store._users[self.id] = self return self @@ -329,10 +329,10 @@ class MockUser(auth.User): def ensure_auth_manager_loaded(auth_mgr): """Ensure an auth manager is considered loaded.""" store = auth_mgr._store - if store.clients is None: - store.clients = {} - if store.users is None: - store.users = {} + if store._clients is None: + store._clients = {} + if store._users is None: + store._users = {} class MockModule(object): diff --git a/tests/components/auth/__init__.py b/tests/components/auth/__init__.py index f0b205ff5ce490df3e0cc2600b28ca8bc3b8be6b..21719c12569b3b7e155868bd33bcf4cfd8ba9e5b 100644 --- a/tests/components/auth/__init__.py +++ b/tests/components/auth/__init__.py @@ -34,7 +34,7 @@ async def async_setup_auth(hass, aiohttp_client, provider_configs=BASE_CONFIG, }) client = auth.Client('Test Client', CLIENT_ID, CLIENT_SECRET, redirect_uris=[CLIENT_REDIRECT_URI]) - hass.auth._store.clients[client.id] = client + hass.auth._store._clients[client.id] = client if setup_api: await async_setup_component(hass, 'api', {}) return await aiohttp_client(hass.http.app) diff --git a/tests/test_auth.py b/tests/test_auth.py index 4c0db71466e97da5e102c16774cb2b48fc351214..5b545223c15a9cb39606b2c555f87e302d5a4e18 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -191,12 +191,13 @@ async def test_saving_loading(hass, hass_storage): await flush_store(manager._store._store) store2 = auth.AuthStore(hass) - await store2.async_load() - assert len(store2.users) == 1 - assert store2.users[user.id] == user + users = await store2.async_get_users() + assert len(users) == 1 + assert users[0] == user - assert len(store2.clients) == 1 - assert store2.clients[client.id] == client + clients = await store2.async_get_clients() + assert len(clients) == 1 + assert clients[0] == client def test_access_token_expired(): @@ -224,15 +225,18 @@ def test_access_token_expired(): async def test_cannot_retrieve_expired_access_token(hass): """Test that we cannot retrieve expired access tokens.""" manager = await auth.auth_manager_from_config(hass, []) + client = await manager.async_create_client('test') user = MockUser( id='mock-user', is_owner=False, is_active=False, name='Paulus', ).add_to_auth_manager(manager) - refresh_token = await manager.async_create_refresh_token(user, 'bla') - access_token = manager.async_create_access_token(refresh_token) + refresh_token = await manager.async_create_refresh_token(user, client.id) + assert refresh_token.user.id is user.id + assert refresh_token.client_id is client.id + access_token = manager.async_create_access_token(refresh_token) assert manager.async_get_access_token(access_token.token) is access_token with patch('homeassistant.auth.dt_util.utcnow', @@ -241,3 +245,38 @@ async def test_cannot_retrieve_expired_access_token(hass): # Even with unpatched time, it should have been removed from manager assert manager.async_get_access_token(access_token.token) is None + + +async def test_get_or_create_client(hass): + """Test that get_or_create_client works.""" + manager = await auth.auth_manager_from_config(hass, []) + + client1 = await manager.async_get_or_create_client( + 'Test Client', redirect_uris=['https://test.com/1']) + assert client1.name is 'Test Client' + + client2 = await manager.async_get_or_create_client( + 'Test Client', redirect_uris=['https://test.com/1']) + assert client2.id is client1.id + + +async def test_cannot_create_refresh_token_with_invalide_client_id(hass): + """Test that we cannot create refresh token with invalid client id.""" + manager = await auth.auth_manager_from_config(hass, []) + user = MockUser( + id='mock-user', + is_owner=False, + is_active=False, + name='Paulus', + ).add_to_auth_manager(manager) + with pytest.raises(ValueError): + await manager.async_create_refresh_token(user, 'bla') + + +async def test_cannot_create_refresh_token_with_invalide_user(hass): + """Test that we cannot create refresh token with invalid client id.""" + manager = await auth.auth_manager_from_config(hass, []) + client = await manager.async_create_client('test') + user = MockUser(id='invalid-user') + with pytest.raises(ValueError): + await manager.async_create_refresh_token(user, client.id)