diff --git a/app/models.py b/app/models.py index 6093e9a247565b051d64676a0857bdfb285ca838..d0f526911b5f6b548617102d1d563262a7577f21 100644 --- a/app/models.py +++ b/app/models.py @@ -117,25 +117,26 @@ class User(db.Model, UserMixin): @property def csentry_groups(self): groups = [] - for key, value in current_app.config['CSENTRY_LDAP_GROUPS'].items(): - if value in self.groups: - groups.append(key) + for key, values in current_app.config['CSENTRY_LDAP_GROUPS'].items(): + for value in values: + if value in self.groups: + groups.append(key) return groups @property def is_admin(self): - return current_app.config['CSENTRY_LDAP_GROUPS']['admin'] in self.groups + for group in current_app.config['CSENTRY_LDAP_GROUPS']['admin']: + if group in self.groups: + return True + return False def is_member_of_one_group(self, groups): """Return True if the user is at least member of one of the given groups""" - names = [current_app.config['CSENTRY_LDAP_GROUPS'].get(group) for group in groups] + names = [] + for group in groups: + names.extend(current_app.config['CSENTRY_LDAP_GROUPS'].get(group)) return bool(set(self.groups) & set(names)) - def is_member_of_all_groups(self, groups): - """Return True if the user is member of all the given groups""" - names = [current_app.config['CSENTRY_LDAP_GROUPS'].get(group) for group in groups] - return set(names).issubset(self.groups) - def __str__(self): return self.display_name diff --git a/app/settings.py b/app/settings.py index 32f45bfbc07a2670053e39ce93d8d3e980cf1f11..97c63c3fa88d3605cc0e7aab179503f1de18b097 100644 --- a/app/settings.py +++ b/app/settings.py @@ -49,8 +49,8 @@ LDAP_GET_USER_ATTRIBUTES = ['cn', 'sAMAccountName', 'mail'] LDAP_GET_GROUP_ATTRIBUTES = ['cn'] CSENTRY_LDAP_GROUPS = { - 'admin': 'ICS Control System Infrastructure group', - 'create': 'ICS Employees', + 'admin': ['ICS Control System Infrastructure group'], + 'create': ['ICS Employees', 'ICS Consultants'], } NETWORK_DEFAULT_PREFIX = 24 diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py index 55a3c12971a6d893ce7c23dfd719ecf96f3679bc..75c6b356a6d5239d6784a6c967b15686d1a44025 100644 --- a/tests/functional/conftest.py +++ b/tests/functional/conftest.py @@ -40,8 +40,8 @@ def app(request): 'WTF_CSRF_ENABLED': False, 'SQLALCHEMY_DATABASE_URI': 'postgresql://ics:icspwd@postgres/csentry_db_test', 'CSENTRY_LDAP_GROUPS': { - 'admin': 'CSEntry Admin', - 'create': 'CSEntry User', + 'admin': ['CSEntry Admin'], + 'create': ['CSEntry User', 'CSEntry Consultant'], } } app = create_app(config=config) @@ -126,6 +126,10 @@ def patch_ldap_authenticate(monkeypatch): response.status = AuthenticationResponseStatus.success response.user_info = {'cn': 'User RW', 'mail': 'user_rw@example.com'} response.user_groups = [{'cn': 'CSEntry User'}] + elif username == 'consultant' and password == 'consultantpwd': + response.status = AuthenticationResponseStatus.success + response.user_info = {'cn': 'Consultant', 'mail': 'consultant@example.com'} + response.user_groups = [{'cn': 'CSEntry Consultant'}] elif username == 'user_ro' and password == 'userro': response.status = AuthenticationResponseStatus.success response.user_info = {'cn': 'User RO', 'mail': 'user_ro@example.com'} diff --git a/tests/functional/test_api.py b/tests/functional/test_api.py index ecc181385ade27e7a49cc2ea5b0f64b36efcbc94..c6ce317f621ef81696366a913b103b8b04b2716e 100644 --- a/tests/functional/test_api.py +++ b/tests/functional/test_api.py @@ -90,6 +90,11 @@ def user_token(client): return get_token(client, 'user_rw', 'userrw') +@pytest.fixture() +def consultant_token(client): + return get_token(client, 'consultant', 'consultantpwd') + + @pytest.fixture() def admin_token(client): return get_token(client, 'admin', 'adminpasswd') @@ -780,6 +785,12 @@ def test_create_host(client, item_factory, user_token): assert models.Host.query.count() == 2 +def test_create_host_as_consultant(client, item_factory, consultant_token): + data = {'name': 'my-hostname'} + response = post(client, f'{API_URL}/network/hosts', data=data, token=consultant_token) + assert response.status_code == 201 + + def test_get_user_profile(client, readonly_token): response = get(client, f'{API_URL}/user/profile', token=readonly_token) assert response.status_code == 200 diff --git a/tests/functional/test_models.py b/tests/functional/test_models.py index 5956fec4a4722a4735b9a8c7b52f303746a2db95..b96c16222757ab91fdfd92b5c6cbbb22913dabd7 100644 --- a/tests/functional/test_models.py +++ b/tests/functional/test_models.py @@ -22,6 +22,26 @@ def test_user_groups(user_factory): assert user.groups == groups +def test_user_is_admin(user_factory): + user = user_factory(groups=['foo', 'CSEntry User']) + assert not user.is_admin + user = user_factory(groups=['foo', 'CSEntry Admin']) + assert user.is_admin + + +def test_user_is_member_of_one_group(user_factory): + user = user_factory(groups=['one', 'two']) + assert not user.is_member_of_one_group(['create', 'admin']) + user = user_factory(groups=['one', 'CSEntry Consultant']) + assert user.is_member_of_one_group(['create']) + assert user.is_member_of_one_group(['create', 'admin']) + assert not user.is_member_of_one_group(['admin']) + user = user_factory(groups=['one', 'CSEntry Admin']) + assert not user.is_member_of_one_group(['create']) + assert user.is_member_of_one_group(['create', 'admin']) + assert user.is_member_of_one_group(['admin']) + + def test_network_ip_properties(network_factory): # Create some networks network1 = network_factory(address='172.16.1.0/24', first_ip='172.16.1.10', last_ip='172.16.1.250')