From 56f8f5ad7f9ed32a9371b027e1387a28d98df0ce Mon Sep 17 00:00:00 2001 From: Benjamin Bertrand <benjamin.bertrand@esss.se> Date: Sat, 25 Nov 2017 21:39:01 +0100 Subject: [PATCH] Use factory_boy to ease testing --- tests/functional/__init__.py | 0 tests/functional/common.py | 16 +++++ tests/functional/conftest.py | 15 +++- tests/functional/factories.py | 117 ++++++++++++++++++++++++++++++++ tests/functional/test_api.py | 76 +++++++-------------- tests/functional/test_models.py | 32 +++------ tests/functional/test_web.py | 5 +- 7 files changed, 184 insertions(+), 77 deletions(-) create mode 100644 tests/functional/__init__.py create mode 100644 tests/functional/common.py create mode 100644 tests/functional/factories.py diff --git a/tests/functional/__init__.py b/tests/functional/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/functional/common.py b/tests/functional/common.py new file mode 100644 index 0000000..3f5b6c6 --- /dev/null +++ b/tests/functional/common.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +""" +tests.functional.common +~~~~~~~~~~~~~~~~~~~~~~~ + +Define common functions and variables used in the tests. + +:copyright: (c) 2017 European Spallation Source ERIC +:license: BSD 2-Clause, see LICENSE for more details. + +""" +from sqlalchemy import orm + +# create the global scope_session for the tests +# See http://factoryboy.readthedocs.io/en/latest/orms.html#managing-sessions +Session = orm.scoped_session(orm.sessionmaker()) diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py index 71a747f..9fb3008 100644 --- a/tests/functional/conftest.py +++ b/tests/functional/conftest.py @@ -11,9 +11,20 @@ Pytest fixtures common to all functional tests. """ import pytest import sqlalchemy as sa +from pytest_factoryboy import register from flask_ldap3_login import AuthenticationResponse, AuthenticationResponseStatus from app.factory import create_app from app.extensions import db as _db +from . import common, factories + +register(factories.ActionFactory) +register(factories.ManufacturerFactory) +register(factories.ModelFactory) +register(factories.LocationFactory) +register(factories.StatusFactory) +register(factories.ItemFactory) +register(factories.NetworkFactory) +register(factories.HostFactory) @pytest.fixture(scope='session') @@ -67,8 +78,8 @@ def session(db, request): """ connection = db.engine.connect() transaction = connection.begin() - options = dict(bind=connection, binds={}, autoflush=False) - session = db.create_scoped_session(options=options) + session = common.Session + session.configure(bind=connection, autoflush=False) session.begin_nested() # session is actually a scoped_session diff --git a/tests/functional/factories.py b/tests/functional/factories.py new file mode 100644 index 0000000..ab714d6 --- /dev/null +++ b/tests/functional/factories.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +""" +tests.functional.factories +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This module defines models factories. + +:copyright: (c) 2017 European Spallation Source ERIC +:license: BSD 2-Clause, see LICENSE for more details. + +""" +import ipaddress +import factory +from faker import Factory as FakerFactory +from app import models +from . import common + + +faker = FakerFactory.create() + + +class ActionFactory(factory.alchemy.SQLAlchemyModelFactory): + class Meta: + model = models.Action + sqlalchemy_session = common.Session + sqlalchemy_session_persistence = 'commit' + + name = factory.Sequence(lambda n: f'action{n}') + + +class ManufacturerFactory(factory.alchemy.SQLAlchemyModelFactory): + class Meta: + model = models.Manufacturer + sqlalchemy_session = common.Session + sqlalchemy_session_persistence = 'commit' + + name = factory.Sequence(lambda n: f'manufacturer{n}') + + +class ModelFactory(factory.alchemy.SQLAlchemyModelFactory): + class Meta: + model = models.Model + sqlalchemy_session = common.Session + sqlalchemy_session_persistence = 'commit' + + name = factory.Sequence(lambda n: f'model{n}') + + +class LocationFactory(factory.alchemy.SQLAlchemyModelFactory): + class Meta: + model = models.Location + sqlalchemy_session = common.Session + sqlalchemy_session_persistence = 'commit' + + name = factory.Sequence(lambda n: f'location{n}') + + +class StatusFactory(factory.alchemy.SQLAlchemyModelFactory): + class Meta: + model = models.Status + sqlalchemy_session = common.Session + sqlalchemy_session_persistence = 'commit' + + name = factory.Sequence(lambda n: f'status{n}') + + +class ItemFactory(factory.alchemy.SQLAlchemyModelFactory): + class Meta: + model = models.Item + sqlalchemy_session = common.Session + sqlalchemy_session_persistence = 'commit' + + ics_id = factory.Sequence(lambda n: f'AAA{n:03}') + serial_number = factory.Faker('isbn10') + manufacturer = factory.SubFactory(ManufacturerFactory) + model = factory.SubFactory(ModelFactory) + location = factory.SubFactory(LocationFactory) + status = factory.SubFactory(StatusFactory) + + +class NetworkFactory(factory.alchemy.SQLAlchemyModelFactory): + class Meta: + model = models.Network + sqlalchemy_session = common.Session + sqlalchemy_session_persistence = 'commit' + + name = factory.Sequence(lambda n: f'network{n}') + prefix = factory.Faker('ipv4', network=True) + vlanid = factory.Sequence(lambda n: 1600 + n) + + @factory.lazy_attribute + def first(self): + net = ipaddress.ip_network(self.prefix) + hosts = list(net.hosts()) + return str(hosts[4]) + + @factory.lazy_attribute + def last(self): + net = ipaddress.ip_network(self.prefix) + hosts = list(net.hosts()) + return str(hosts[-5]) + + @factory.lazy_attribute + def gateway(self): + net = ipaddress.ip_network(self.prefix) + hosts = list(net.hosts()) + return str(hosts[-1]) + + +class HostFactory(factory.alchemy.SQLAlchemyModelFactory): + class Meta: + model = models.Host + sqlalchemy_session = common.Session + sqlalchemy_session_persistence = 'commit' + + name = factory.Sequence(lambda n: f'host{n}') + network = factory.SubFactory(NetworkFactory) diff --git a/tests/functional/test_api.py b/tests/functional/test_api.py index 7f9862e..2b4f97d 100644 --- a/tests/functional/test_api.py +++ b/tests/functional/test_api.py @@ -237,15 +237,11 @@ def test_get_item_fail(client, session, readonly_token): check_response_message(response, "Item id 'bar' not found", 404) -def test_get_item(client, session, readonly_token): +def test_get_item(client, status_factory, item_factory, readonly_token): # Create some items - session.add(models.Status(name='Stock')) - session.flush() - item1 = models.Item(serial_number='123456') - item2 = models.Item(serial_number='234567', ics_id='AAA001', status='Stock') - for item in (item1, item2): - session.add(item) - session.commit() + status_factory(name='Stock') + item1 = item_factory(serial_number='123456') + item2 = item_factory(serial_number='234567', ics_id='AAA001', status='Stock') # we can get items by id... response = get(client, f'/api/items/{item1.id}', token=readonly_token) @@ -270,7 +266,7 @@ def test_patch_item_auth_fail(client, session, readonly_token): check_response_message(response, "User doesn't have the required group", 403) -def test_patch_item_fail(client, session, user_token): +def test_patch_item_fail(client, item_factory, user_token): response = patch(client, '/api/items/50', data={}, token=user_token) check_response_message(response, 'At least one field is required', 422) data = {'location': 'ESS', 'foo': 'bar'} @@ -283,10 +279,7 @@ def test_patch_item_fail(client, session, user_token): check_response_message(response, "Item id 'bar' not found", 404) # Create an item - item1 = models.Item(serial_number='234567', ics_id='AAA001') - session.add(item1) - session.commit() - + item1 = item_factory(serial_number='234567', ics_id='AAA001') # check that we can't change the serial_number or ics_id response = patch(client, f'/api/items/{item1.id}', data={'serial_number': '12345'}, token=user_token) check_response_message(response, "Invalid field 'serial_number'", 422) @@ -438,17 +431,12 @@ def test_get_items(client, session, readonly_token): check_response_message(response, 'Invalid query arguments', 422) -def test_get_networks(client, session, readonly_token): +def test_get_networks(client, location_factory, network_factory, readonly_token): # Create some networks - location = models.Location(name='G02') - session.add(location) - session.flush() - network1 = models.Network(name='network1', prefix='172.16.1.0/24', first='172.16.1.1', last='172.16.1.254') - network2 = models.Network(name='network2', prefix='172.16.20.0/22', first='172.16.20.11', last='172.16.20.250') - network3 = models.Network(name='network3', prefix='172.16.5.0/24', first='172.16.5.10', last='172.16.5.254', location_id=location.id) - for network in (network1, network2, network3): - session.add(network) - session.commit() + location = location_factory(name='G02') + network1 = network_factory(prefix='172.16.1.0/24', first='172.16.1.1', last='172.16.1.254') + network2 = network_factory(prefix='172.16.20.0/22', first='172.16.20.11', last='172.16.20.250') + network3 = network_factory(prefix='172.16.5.0/24', first='172.16.5.10', last='172.16.5.254', location=location) response = get(client, '/api/networks', token=readonly_token) assert response.status_code == 200 @@ -474,10 +462,8 @@ def test_create_network_auth_fail(client, session, user_token): check_response_message(response, "User doesn't have the required group", 403) -def test_create_network(client, session, admin_token): - location = models.Location(name='G02') - session.add(location) - session.commit() +def test_create_network(client, location_factory, admin_token): + location = location_factory(name='G02') # check that name, prefix, first and last are mandatory response = post(client, '/api/networks', data={}, token=admin_token) check_response_message(response, "Missing mandatory field 'name'", 422) @@ -591,19 +577,13 @@ def test_create_network_invalid_range(client, session, admin_token): check_response_message(response, 'Last IP address 172.16.1.9 is less than the first address 172.16.1.10', 422) -def test_get_hosts(client, session, readonly_token): +def test_get_hosts(client, network_factory, host_factory, readonly_token): # Create some hosts - network1 = models.Network(name='network1', prefix='192.168.1.0/24', first='192.168.1.10', last='192.168.1.250') - network2 = models.Network(name='network2', prefix='192.168.2.0/24', first='192.168.2.10', last='192.168.2.250') - session.add(network1) - session.add(network2) - session.flush() - host1 = models.Host(network_id=network1.id, ip='192.168.1.10') - host2 = models.Host(network_id=network1.id, ip='192.168.1.11', name='hostname2') - host3 = models.Host(network_id=network2.id, ip='192.168.2.10') - for host in (host1, host2, host3): - session.add(host) - session.commit() + network1 = network_factory(prefix='192.168.1.0/24', first='192.168.1.10', last='192.168.1.250') + network2 = network_factory(prefix='192.168.2.0/24', first='192.168.2.10', last='192.168.2.250') + host1 = host_factory(network=network1, ip='192.168.1.10') + host2 = host_factory(network=network1, ip='192.168.1.11', name='hostname2') + host3 = host_factory(network=network2, ip='192.168.2.10') response = get(client, '/api/hosts', token=readonly_token) assert response.status_code == 200 @@ -617,10 +597,8 @@ def test_get_hosts(client, session, readonly_token): check_input_is_subset_of_response(response, (host3.to_dict(),)) -def test_create_host(client, session, user_token): - network = models.Network(name='network1', prefix='192.168.1.0/24', first='192.168.1.10', last='192.168.1.250') - session.add(network) - session.commit() +def test_create_host(client, network_factory, user_token): + network = network_factory(prefix='192.168.1.0/24', first='192.168.1.10', last='192.168.1.250') # check that network_id and ip are mandatory response = post(client, '/api/hosts', data={}, token=user_token) check_response_message(response, "Missing mandatory field 'network'", 422) @@ -654,10 +632,8 @@ def test_create_host(client, session, user_token): @pytest.mark.parametrize('ip', ('', 'foo', '192.168')) -def test_create_host_invalid_ip(ip, client, session, user_token): - network = models.Network(name='network1', prefix='192.168.1.0/24', first='192.168.1.10', last='192.168.1.250') - session.add(network) - session.commit() +def test_create_host_invalid_ip(ip, client, network_factory, user_token): + network = network_factory(prefix='192.168.1.0/24', first='192.168.1.10', last='192.168.1.250') # invalid IP address data = {'network': network.prefix, 'ip': ip} @@ -665,10 +641,8 @@ def test_create_host_invalid_ip(ip, client, session, user_token): check_response_message(response, f"'{ip}' does not appear to be an IPv4 or IPv6 address", 422) -def test_create_host_ip_not_in_network(client, session, user_token): - network = models.Network(name='network1', prefix='192.168.1.0/24', first='192.168.1.10', last='192.168.1.250') - session.add(network) - session.commit() +def test_create_host_ip_not_in_network(client, network_factory, user_token): + network = network_factory(prefix='192.168.1.0/24', first='192.168.1.10', last='192.168.1.250') # IP address not in range data = {'network': network.prefix, 'ip': '192.168.2.4'} diff --git a/tests/functional/test_models.py b/tests/functional/test_models.py index d9b36ca..cffd358 100644 --- a/tests/functional/test_models.py +++ b/tests/functional/test_models.py @@ -10,16 +10,12 @@ This module defines models tests. """ import ipaddress -from app import models -def test_network_ip_properties(session): +def test_network_ip_properties(network_factory): # Create some networks - network1 = models.Network(name='network1', prefix='172.16.1.0/24', first='172.16.1.10', last='172.16.1.250') - network2 = models.Network(name='network2', prefix='172.16.20.0/26', first='172.16.20.11', last='172.16.20.14') - for network in (network1, network2): - session.add(network) - session.commit() + network1 = network_factory(prefix='172.16.1.0/24', first='172.16.1.10', last='172.16.1.250') + network2 = network_factory(prefix='172.16.20.0/26', first='172.16.20.11', last='172.16.20.14') assert network1.network_ip == ipaddress.ip_network('172.16.1.0/24') assert network1.first_ip == ipaddress.ip_address('172.16.1.10') @@ -38,17 +34,13 @@ def test_network_ip_properties(session): assert network2.used_ips() == [] -def test_network_available_and_used_ips(session): +def test_network_available_and_used_ips(network_factory, host_factory): # Create some networks and hosts - network1 = models.Network(name='network1', prefix='172.16.1.0/24', first='172.16.1.10', last='172.16.1.250') - network2 = models.Network(name='network2', prefix='172.16.20.0/26', first='172.16.20.11', last='172.16.20.14') - for network in (network1, network2): - session.add(network) - session.flush() + network1 = network_factory(prefix='172.16.1.0/24', first='172.16.1.10', last='172.16.1.250') + network2 = network_factory(prefix='172.16.20.0/26', first='172.16.20.11', last='172.16.20.14') for i in range(10, 20): - session.add(models.Host(network_id=network1.id, ip=f'172.16.1.{i}')) - session.add(models.Host(network_id=network2.id, ip='172.16.20.13')) - session.commit() + host_factory(network=network1, ip=f'172.16.1.{i}') + host_factory(network=network2, ip='172.16.20.13') # Check available and used IPs assert network1.used_ips() == [ipaddress.ip_address(f'172.16.1.{i}') for i in range(10, 20)] assert network1.available_ips() == [ipaddress.ip_address(f'172.16.1.{i}') for i in range(20, 251)] @@ -58,9 +50,8 @@ def test_network_available_and_used_ips(session): ipaddress.ip_address('172.16.20.14')] # Add more hosts - session.add(models.Host(network_id=network2.id, ip='172.16.20.11')) - session.add(models.Host(network_id=network2.id, ip='172.16.20.14')) - session.commit() + host_factory(network=network2, ip='172.16.20.11') + host_factory(network=network2, ip='172.16.20.14') assert len(network2.used_ips()) == 3 assert network2.used_ips() == [ipaddress.ip_address('172.16.20.11'), ipaddress.ip_address('172.16.20.13'), @@ -68,7 +59,6 @@ def test_network_available_and_used_ips(session): assert network2.available_ips() == [ipaddress.ip_address('172.16.20.12')] # Add last available IP - session.add(models.Host(network_id=network2.id, ip='172.16.20.12')) - session.commit() + host_factory(network=network2, ip='172.16.20.12') assert network2.used_ips() == [ipaddress.ip_address(f'172.16.20.{i}') for i in range(11, 15)] assert list(network2.available_ips()) == [] diff --git a/tests/functional/test_web.py b/tests/functional/test_web.py index 3f85e18..9233a07 100644 --- a/tests/functional/test_web.py +++ b/tests/functional/test_web.py @@ -73,13 +73,12 @@ def test_protected_url(url, client): assert response.status_code == 200 -def test_retrieve_items(logged_client, session): +def test_retrieve_items(logged_client, item_factory): response = get(logged_client, '/_retrieve_items') assert response.json['data'] == [] serial_numbers = ('12345', '45678') for sn in serial_numbers: - session.add(models.Item(serial_number=sn)) - session.commit() + item_factory(serial_number=sn) response = get(logged_client, '/_retrieve_items') items = response.json['data'] assert set(serial_numbers) == set(item[4] for item in items) -- GitLab