Skip to content
Snippets Groups Projects
Commit a738d0b4 authored by Benjamin Bertrand's avatar Benjamin Bertrand
Browse files

Reorganize and improve tests

Put under the functional directory all tests that require
access to the database. There is a "session" fixture in that
directory that is automatically invoked for each tests. It rollbacks
any transaction after the test to leave the database clean.

Other tests should go under the unit directory.

Inspired from:
- https://stackoverflow.com/a/38626139
- http://alexmic.net/flask-sqlalchemy-pytest/
parent 914e74b3
No related branches found
No related tags found
No related merge requests found
......@@ -3,88 +3,9 @@
conftest
~~~~~~~~
This module defines the main configuration for the tests.
Empty conftest.py so that pytest finds the checkout pakage.
:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.
"""
import pytest
from flask_ldap3_login import AuthenticationResponse, AuthenticationResponseStatus
from app.factory import create_app
from app.extensions import db
from app.defaults import defaults
@pytest.fixture(scope='session')
def app(request):
config = {
'TESTING': True,
'WTF_CSRF_ENABLED': False,
'SQLALCHEMY_DATABASE_URI': 'postgresql://ics:icstest@postgres_test/inventory_db_test',
'INVENTORY_LDAP_GROUPS': {
'admin': 'Inventory Admin',
'create': 'Inventory User',
}
}
app = create_app(config=config)
with app.app_context():
db.drop_all()
db.engine.execute('CREATE EXTENSION IF NOT EXISTS citext')
db.create_all()
for instance in defaults:
db.session.add(instance)
db.session.flush()
db.session.expunge_all()
db.session.commit()
yield app
db.session.remove()
db.drop_all()
@pytest.fixture
def client(request, app):
return app.test_client()
# TODO: make this work to clean the database between tests
# @pytest.fixture(autouse=True)
# def dbsession(request, monkeypatch):
# # Roll back at the end of every test
# request.addfinalizer(db.session.remove)
# # Prevent the session from closing (make it a no-op) and
# # committing (redirect to flush() instead)
# monkeypatch.setattr(db.session, 'commit', db.session.flush)
# monkeypatch.setattr(db.session, 'remove', lambda: None)
@pytest.fixture(autouse=True)
def no_ldap_connection(monkeypatch):
"""Make sure we don't make any connection to the LDAP server"""
monkeypatch.delattr('flask_ldap3_login.LDAP3LoginManager._make_connection')
@pytest.fixture(autouse=True)
def patch_ldap_authenticate(monkeypatch):
def authenticate(self, username, password):
response = AuthenticationResponse()
response.user_id = username
response.user_dn = f'cn={username},dc=esss,dc=lu,dc=se'
if username == 'admin' and password == 'adminpasswd':
response.status = AuthenticationResponseStatus.success
response.user_info = {'cn': 'Admin User', 'mail': 'admin@example.com'}
response.user_groups = [{'cn': 'Inventory Admin'}]
elif username == 'user_rw' and password == 'userrw':
response.status = AuthenticationResponseStatus.success
response.user_info = {'cn': 'User RW', 'mail': 'user_rw@example.com'}
response.user_groups = [{'cn': 'Inventory User'}]
elif username == 'user_ro' and password == 'userro':
response.status = AuthenticationResponseStatus.success
response.user_info = {'cn': 'User RO', 'mail': 'user_ro@example.com'}
response.user_groups = [{'cn': 'ESS Employees'}]
else:
response.status = AuthenticationResponseStatus.fail
return response
monkeypatch.setattr('flask_ldap3_login.LDAP3LoginManager.authenticate', authenticate)
# -*- coding: utf-8 -*-
"""
tests.functional.conftest
~~~~~~~~~~~~~~~~~~~~~~~~~
Pytest fixtures common to all functional tests.
:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.
"""
import pytest
import sqlalchemy as sa
from flask_ldap3_login import AuthenticationResponse, AuthenticationResponseStatus
from app.factory import create_app
from app.extensions import db as _db
@pytest.fixture(scope='session')
def app(request):
"""Session-wide test `Flask` application."""
config = {
'TESTING': True,
'WTF_CSRF_ENABLED': False,
'SQLALCHEMY_DATABASE_URI': 'postgresql://ics:icstest@postgres_test/inventory_db_test',
'INVENTORY_LDAP_GROUPS': {
'admin': 'Inventory Admin',
'create': 'Inventory User',
}
}
app = create_app(config=config)
ctx = app.app_context()
ctx.push()
def teardown():
ctx.pop()
request.addfinalizer(teardown)
return app
@pytest.fixture
def client(request, app):
return app.test_client()
@pytest.fixture(scope='session')
def db(app, request):
"""Session-wide test database."""
def teardown():
_db.session.remove()
_db.drop_all()
_db.app = app
_db.engine.execute('CREATE EXTENSION IF NOT EXISTS citext')
_db.create_all()
request.addfinalizer(teardown)
return _db
@pytest.fixture(autouse=True)
def session(db, request):
"""Creates a new database session for every test.
Rollback any transaction to always leave the database clean
"""
connection = db.engine.connect()
transaction = connection.begin()
options = dict(bind=connection, binds={})
session = db.create_scoped_session(options=options)
session.begin_nested()
# session is actually a scoped_session
# for the `after_transaction_end` event, we need a session instance to
# listen for, hence the `session()` call
@sa.event.listens_for(session(), 'after_transaction_end')
def resetart_savepoint(sess, trans):
if trans.nested and not trans._parent.nested:
session.expire_all()
session.begin_nested()
db.session = session
yield session
session.remove()
transaction.rollback()
connection.close()
@pytest.fixture(autouse=True)
def no_ldap_connection(monkeypatch):
"""Make sure we don't make any connection to the LDAP server"""
monkeypatch.delattr('flask_ldap3_login.LDAP3LoginManager._make_connection')
@pytest.fixture(autouse=True)
def patch_ldap_authenticate(monkeypatch):
def authenticate(self, username, password):
response = AuthenticationResponse()
response.user_id = username
response.user_dn = f'cn={username},dc=esss,dc=lu,dc=se'
if username == 'admin' and password == 'adminpasswd':
response.status = AuthenticationResponseStatus.success
response.user_info = {'cn': 'Admin User', 'mail': 'admin@example.com'}
response.user_groups = [{'cn': 'Inventory Admin'}]
elif username == 'user_rw' and password == 'userrw':
response.status = AuthenticationResponseStatus.success
response.user_info = {'cn': 'User RW', 'mail': 'user_rw@example.com'}
response.user_groups = [{'cn': 'Inventory User'}]
elif username == 'user_ro' and password == 'userro':
response.status = AuthenticationResponseStatus.success
response.user_info = {'cn': 'User RO', 'mail': 'user_ro@example.com'}
response.user_groups = [{'cn': 'ESS Employees'}]
else:
response.status = AuthenticationResponseStatus.fail
return response
monkeypatch.setattr('flask_ldap3_login.LDAP3LoginManager.authenticate', authenticate)
# -*- coding: utf-8 -*-
"""
tests.test_api
~~~~~~~~~~~~~~
tests.functional.test_api
~~~~~~~~~~~~~~~~~~~~~~~~~
This module defines API tests.
......@@ -11,6 +11,16 @@ This module defines API tests.
"""
import json
import pytest
from app import models
ENDPOINT_MODEL = {
'manufacturers': models.Manufacturer,
'models': models.Model,
'locations': models.Location,
'status': models.Status,
}
ENDPOINTS = list(ENDPOINT_MODEL.keys())
def get(client, url, token=None):
......@@ -94,36 +104,52 @@ def test_login(client):
assert 'access_token' in response.json
def test_get_locations(client, readonly_token):
response = client.get('/api/locations')
@pytest.mark.parametrize('endpoint', ENDPOINTS)
def test_get_generic_model(endpoint, session, client, readonly_token):
model = ENDPOINT_MODEL[endpoint]
names = ('Foo', 'Bar', 'Alice')
for name in names:
session.add(model(name=name))
session.commit()
response = client.get(f'/api/{endpoint}')
check_response_message(response, 'Missing Authorization Header', 401)
response = get(client, '/api/locations', 'xxxxxxxxx')
response = get(client, f'/api/{endpoint}', 'xxxxxxxxx')
check_response_message(response, 'Not enough segments', 422)
response = get(client, '/api/locations', readonly_token)
check_names(response, ('ICS lab', 'Utgård', 'Site', 'ESS'))
response = get(client, f'/api/{endpoint}', readonly_token)
check_names(response, names)
def test_create_locations_fail(client, readonly_token):
response = client.post('/api/locations')
@pytest.mark.parametrize('endpoint', ENDPOINTS)
def test_create_generic_model_fail(endpoint, client, readonly_token):
response = client.post(f'/api/{endpoint}')
check_response_message(response, 'Missing Authorization Header', 401)
response = post(client, '/api/locations', data={}, token='xxxxxxxxx')
response = post(client, f'/api/{endpoint}', data={}, token='xxxxxxxxx')
check_response_message(response, 'Not enough segments', 422)
response = post(client, '/api/locations', data={}, token=readonly_token)
response = post(client, f'/api/{endpoint}', data={}, token=readonly_token)
check_response_message(response, "User doesn't have the required group", 403)
model = ENDPOINT_MODEL[endpoint]
assert model.query.count() == 0
def test_create_locations(client, user_token):
response = post(client, '/api/locations', data={}, token=user_token)
@pytest.mark.parametrize('endpoint', ENDPOINTS)
def test_create_generic_model(endpoint, client, user_token):
response = post(client, f'/api/{endpoint}', data={}, token=user_token)
check_response_message(response, "Missing mandatory field 'name'", 422)
data = {'name': 'Foo'}
response = post(client, '/api/locations', data=data, token=user_token)
response = post(client, f'/api/{endpoint}', data=data, token=user_token)
assert response.status_code == 201
assert response.json == {'id': 5, 'name': 'Foo'}
response = post(client, '/api/locations', data=data, token=user_token)
assert {'id', 'name'} <= set(response.json.keys())
assert response.json['name'] == 'Foo'
response = post(client, f'/api/{endpoint}', data=data, token=user_token)
check_response_message(response, 'IntegrityError', 409)
response = post(client, '/api/locations', data={'name': 'foo'}, token=user_token)
response = post(client, f'/api/{endpoint}', data={'name': 'foo'}, token=user_token)
check_response_message(response, 'IntegrityError', 409)
response = post(client, '/api/locations', data={'name': 'FOO'}, token=user_token)
response = post(client, f'/api/{endpoint}', data={'name': 'FOO'}, token=user_token)
check_response_message(response, 'IntegrityError', 409)
response = get(client, '/api/locations', user_token)
check_names(response, ('ICS lab', 'Utgård', 'Site', 'ESS', 'Foo'))
data = {'name': 'Bar'}
response = post(client, f'/api/{endpoint}', data=data, token=user_token)
assert response.status_code == 201
model = ENDPOINT_MODEL[endpoint]
assert model.query.count() == 2
response = get(client, f'/api/{endpoint}', user_token)
check_names(response, ('Foo', 'Bar'))
# -*- coding: utf-8 -*-
"""
tests.test_basic
~~~~~~~~~~~~~~~~
tests.functional.test_web
~~~~~~~~~~~~~~~~~~~~~~~~~
This module defines basic web tests.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment