From a738d0b4521a9fad7108ccff3ea7c34c04960765 Mon Sep 17 00:00:00 2001 From: Benjamin Bertrand <benjamin.bertrand@esss.se> Date: Mon, 28 Aug 2017 12:50:10 +0200 Subject: [PATCH] Reorganize and improve tests Put under the functional directory all tests that require access to the database. There is a "session" fixture in that directory that is automatically invoked for each tests. It rollbacks any transaction after the test to leave the database clean. Other tests should go under the unit directory. Inspired from: - https://stackoverflow.com/a/38626139 - http://alexmic.net/flask-sqlalchemy-pytest/ --- conftest.py | 81 +----------- tests/functional/conftest.py | 121 ++++++++++++++++++ tests/{ => functional}/test_api.py | 66 +++++++--- .../{test_basic.py => functional/test_web.py} | 4 +- 4 files changed, 170 insertions(+), 102 deletions(-) create mode 100644 tests/functional/conftest.py rename tests/{ => functional}/test_api.py (61%) rename tests/{test_basic.py => functional/test_web.py} (95%) diff --git a/conftest.py b/conftest.py index b37c41a..36af21a 100644 --- a/conftest.py +++ b/conftest.py @@ -3,88 +3,9 @@ conftest ~~~~~~~~ -This module defines the main configuration for the tests. +Empty conftest.py so that pytest finds the checkout pakage. :copyright: (c) 2017 European Spallation Source ERIC :license: BSD 2-Clause, see LICENSE for more details. """ -import pytest -from flask_ldap3_login import AuthenticationResponse, AuthenticationResponseStatus -from app.factory import create_app -from app.extensions import db -from app.defaults import defaults - - -@pytest.fixture(scope='session') -def app(request): - config = { - 'TESTING': True, - 'WTF_CSRF_ENABLED': False, - 'SQLALCHEMY_DATABASE_URI': 'postgresql://ics:icstest@postgres_test/inventory_db_test', - 'INVENTORY_LDAP_GROUPS': { - 'admin': 'Inventory Admin', - 'create': 'Inventory User', - } - } - app = create_app(config=config) - with app.app_context(): - db.drop_all() - db.engine.execute('CREATE EXTENSION IF NOT EXISTS citext') - db.create_all() - for instance in defaults: - db.session.add(instance) - db.session.flush() - db.session.expunge_all() - db.session.commit() - yield app - db.session.remove() - db.drop_all() - - -@pytest.fixture -def client(request, app): - return app.test_client() - - -# TODO: make this work to clean the database between tests -# @pytest.fixture(autouse=True) -# def dbsession(request, monkeypatch): -# # Roll back at the end of every test -# request.addfinalizer(db.session.remove) -# # Prevent the session from closing (make it a no-op) and -# # committing (redirect to flush() instead) -# monkeypatch.setattr(db.session, 'commit', db.session.flush) -# monkeypatch.setattr(db.session, 'remove', lambda: None) - - -@pytest.fixture(autouse=True) -def no_ldap_connection(monkeypatch): - """Make sure we don't make any connection to the LDAP server""" - monkeypatch.delattr('flask_ldap3_login.LDAP3LoginManager._make_connection') - - -@pytest.fixture(autouse=True) -def patch_ldap_authenticate(monkeypatch): - - def authenticate(self, username, password): - response = AuthenticationResponse() - response.user_id = username - response.user_dn = f'cn={username},dc=esss,dc=lu,dc=se' - if username == 'admin' and password == 'adminpasswd': - response.status = AuthenticationResponseStatus.success - response.user_info = {'cn': 'Admin User', 'mail': 'admin@example.com'} - response.user_groups = [{'cn': 'Inventory Admin'}] - elif username == 'user_rw' and password == 'userrw': - response.status = AuthenticationResponseStatus.success - response.user_info = {'cn': 'User RW', 'mail': 'user_rw@example.com'} - response.user_groups = [{'cn': 'Inventory User'}] - elif username == 'user_ro' and password == 'userro': - response.status = AuthenticationResponseStatus.success - response.user_info = {'cn': 'User RO', 'mail': 'user_ro@example.com'} - response.user_groups = [{'cn': 'ESS Employees'}] - else: - response.status = AuthenticationResponseStatus.fail - return response - - monkeypatch.setattr('flask_ldap3_login.LDAP3LoginManager.authenticate', authenticate) diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py new file mode 100644 index 0000000..2147fdb --- /dev/null +++ b/tests/functional/conftest.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- +""" +tests.functional.conftest +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Pytest fixtures common to all functional tests. + +:copyright: (c) 2017 European Spallation Source ERIC +:license: BSD 2-Clause, see LICENSE for more details. + +""" +import pytest +import sqlalchemy as sa +from flask_ldap3_login import AuthenticationResponse, AuthenticationResponseStatus +from app.factory import create_app +from app.extensions import db as _db + + +@pytest.fixture(scope='session') +def app(request): + """Session-wide test `Flask` application.""" + config = { + 'TESTING': True, + 'WTF_CSRF_ENABLED': False, + 'SQLALCHEMY_DATABASE_URI': 'postgresql://ics:icstest@postgres_test/inventory_db_test', + 'INVENTORY_LDAP_GROUPS': { + 'admin': 'Inventory Admin', + 'create': 'Inventory User', + } + } + app = create_app(config=config) + ctx = app.app_context() + ctx.push() + + def teardown(): + ctx.pop() + + request.addfinalizer(teardown) + return app + + +@pytest.fixture +def client(request, app): + return app.test_client() + + +@pytest.fixture(scope='session') +def db(app, request): + """Session-wide test database.""" + def teardown(): + _db.session.remove() + _db.drop_all() + + _db.app = app + _db.engine.execute('CREATE EXTENSION IF NOT EXISTS citext') + _db.create_all() + + request.addfinalizer(teardown) + return _db + + +@pytest.fixture(autouse=True) +def session(db, request): + """Creates a new database session for every test. + + Rollback any transaction to always leave the database clean + """ + connection = db.engine.connect() + transaction = connection.begin() + options = dict(bind=connection, binds={}) + session = db.create_scoped_session(options=options) + session.begin_nested() + + # session is actually a scoped_session + # for the `after_transaction_end` event, we need a session instance to + # listen for, hence the `session()` call + @sa.event.listens_for(session(), 'after_transaction_end') + def resetart_savepoint(sess, trans): + if trans.nested and not trans._parent.nested: + session.expire_all() + session.begin_nested() + + db.session = session + + yield session + + session.remove() + transaction.rollback() + connection.close() + + +@pytest.fixture(autouse=True) +def no_ldap_connection(monkeypatch): + """Make sure we don't make any connection to the LDAP server""" + monkeypatch.delattr('flask_ldap3_login.LDAP3LoginManager._make_connection') + + +@pytest.fixture(autouse=True) +def patch_ldap_authenticate(monkeypatch): + + def authenticate(self, username, password): + response = AuthenticationResponse() + response.user_id = username + response.user_dn = f'cn={username},dc=esss,dc=lu,dc=se' + if username == 'admin' and password == 'adminpasswd': + response.status = AuthenticationResponseStatus.success + response.user_info = {'cn': 'Admin User', 'mail': 'admin@example.com'} + response.user_groups = [{'cn': 'Inventory Admin'}] + elif username == 'user_rw' and password == 'userrw': + response.status = AuthenticationResponseStatus.success + response.user_info = {'cn': 'User RW', 'mail': 'user_rw@example.com'} + response.user_groups = [{'cn': 'Inventory User'}] + elif username == 'user_ro' and password == 'userro': + response.status = AuthenticationResponseStatus.success + response.user_info = {'cn': 'User RO', 'mail': 'user_ro@example.com'} + response.user_groups = [{'cn': 'ESS Employees'}] + else: + response.status = AuthenticationResponseStatus.fail + return response + + monkeypatch.setattr('flask_ldap3_login.LDAP3LoginManager.authenticate', authenticate) diff --git a/tests/test_api.py b/tests/functional/test_api.py similarity index 61% rename from tests/test_api.py rename to tests/functional/test_api.py index 5306538..db395da 100644 --- a/tests/test_api.py +++ b/tests/functional/test_api.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ -tests.test_api -~~~~~~~~~~~~~~ +tests.functional.test_api +~~~~~~~~~~~~~~~~~~~~~~~~~ This module defines API tests. @@ -11,6 +11,16 @@ This module defines API tests. """ import json import pytest +from app import models + + +ENDPOINT_MODEL = { + 'manufacturers': models.Manufacturer, + 'models': models.Model, + 'locations': models.Location, + 'status': models.Status, +} +ENDPOINTS = list(ENDPOINT_MODEL.keys()) def get(client, url, token=None): @@ -94,36 +104,52 @@ def test_login(client): assert 'access_token' in response.json -def test_get_locations(client, readonly_token): - response = client.get('/api/locations') +@pytest.mark.parametrize('endpoint', ENDPOINTS) +def test_get_generic_model(endpoint, session, client, readonly_token): + model = ENDPOINT_MODEL[endpoint] + names = ('Foo', 'Bar', 'Alice') + for name in names: + session.add(model(name=name)) + session.commit() + response = client.get(f'/api/{endpoint}') check_response_message(response, 'Missing Authorization Header', 401) - response = get(client, '/api/locations', 'xxxxxxxxx') + response = get(client, f'/api/{endpoint}', 'xxxxxxxxx') check_response_message(response, 'Not enough segments', 422) - response = get(client, '/api/locations', readonly_token) - check_names(response, ('ICS lab', 'Utgård', 'Site', 'ESS')) + response = get(client, f'/api/{endpoint}', readonly_token) + check_names(response, names) -def test_create_locations_fail(client, readonly_token): - response = client.post('/api/locations') +@pytest.mark.parametrize('endpoint', ENDPOINTS) +def test_create_generic_model_fail(endpoint, client, readonly_token): + response = client.post(f'/api/{endpoint}') check_response_message(response, 'Missing Authorization Header', 401) - response = post(client, '/api/locations', data={}, token='xxxxxxxxx') + response = post(client, f'/api/{endpoint}', data={}, token='xxxxxxxxx') check_response_message(response, 'Not enough segments', 422) - response = post(client, '/api/locations', data={}, token=readonly_token) + response = post(client, f'/api/{endpoint}', data={}, token=readonly_token) check_response_message(response, "User doesn't have the required group", 403) + model = ENDPOINT_MODEL[endpoint] + assert model.query.count() == 0 -def test_create_locations(client, user_token): - response = post(client, '/api/locations', data={}, token=user_token) +@pytest.mark.parametrize('endpoint', ENDPOINTS) +def test_create_generic_model(endpoint, client, user_token): + response = post(client, f'/api/{endpoint}', data={}, token=user_token) check_response_message(response, "Missing mandatory field 'name'", 422) data = {'name': 'Foo'} - response = post(client, '/api/locations', data=data, token=user_token) + response = post(client, f'/api/{endpoint}', data=data, token=user_token) assert response.status_code == 201 - assert response.json == {'id': 5, 'name': 'Foo'} - response = post(client, '/api/locations', data=data, token=user_token) + assert {'id', 'name'} <= set(response.json.keys()) + assert response.json['name'] == 'Foo' + response = post(client, f'/api/{endpoint}', data=data, token=user_token) check_response_message(response, 'IntegrityError', 409) - response = post(client, '/api/locations', data={'name': 'foo'}, token=user_token) + response = post(client, f'/api/{endpoint}', data={'name': 'foo'}, token=user_token) check_response_message(response, 'IntegrityError', 409) - response = post(client, '/api/locations', data={'name': 'FOO'}, token=user_token) + response = post(client, f'/api/{endpoint}', data={'name': 'FOO'}, token=user_token) check_response_message(response, 'IntegrityError', 409) - response = get(client, '/api/locations', user_token) - check_names(response, ('ICS lab', 'Utgård', 'Site', 'ESS', 'Foo')) + data = {'name': 'Bar'} + response = post(client, f'/api/{endpoint}', data=data, token=user_token) + assert response.status_code == 201 + model = ENDPOINT_MODEL[endpoint] + assert model.query.count() == 2 + response = get(client, f'/api/{endpoint}', user_token) + check_names(response, ('Foo', 'Bar')) diff --git a/tests/test_basic.py b/tests/functional/test_web.py similarity index 95% rename from tests/test_basic.py rename to tests/functional/test_web.py index ee59447..d8e8527 100644 --- a/tests/test_basic.py +++ b/tests/functional/test_web.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ -tests.test_basic -~~~~~~~~~~~~~~~~ +tests.functional.test_web +~~~~~~~~~~~~~~~~~~~~~~~~~ This module defines basic web tests. -- GitLab