From 56f8f5ad7f9ed32a9371b027e1387a28d98df0ce Mon Sep 17 00:00:00 2001
From: Benjamin Bertrand <benjamin.bertrand@esss.se>
Date: Sat, 25 Nov 2017 21:39:01 +0100
Subject: [PATCH] Use factory_boy to ease testing

---
 tests/functional/__init__.py    |   0
 tests/functional/common.py      |  16 +++++
 tests/functional/conftest.py    |  15 +++-
 tests/functional/factories.py   | 117 ++++++++++++++++++++++++++++++++
 tests/functional/test_api.py    |  76 +++++++--------------
 tests/functional/test_models.py |  32 +++------
 tests/functional/test_web.py    |   5 +-
 7 files changed, 184 insertions(+), 77 deletions(-)
 create mode 100644 tests/functional/__init__.py
 create mode 100644 tests/functional/common.py
 create mode 100644 tests/functional/factories.py

diff --git a/tests/functional/__init__.py b/tests/functional/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/functional/common.py b/tests/functional/common.py
new file mode 100644
index 0000000..3f5b6c6
--- /dev/null
+++ b/tests/functional/common.py
@@ -0,0 +1,16 @@
+# -*- coding: utf-8 -*-
+"""
+tests.functional.common
+~~~~~~~~~~~~~~~~~~~~~~~
+
+Define common functions and variables used in the tests.
+
+:copyright: (c) 2017 European Spallation Source ERIC
+:license: BSD 2-Clause, see LICENSE for more details.
+
+"""
+from sqlalchemy import orm
+
+# create the global scope_session for the tests
+# See http://factoryboy.readthedocs.io/en/latest/orms.html#managing-sessions
+Session = orm.scoped_session(orm.sessionmaker())
diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py
index 71a747f..9fb3008 100644
--- a/tests/functional/conftest.py
+++ b/tests/functional/conftest.py
@@ -11,9 +11,20 @@ Pytest fixtures common to all functional tests.
 """
 import pytest
 import sqlalchemy as sa
+from pytest_factoryboy import register
 from flask_ldap3_login import AuthenticationResponse, AuthenticationResponseStatus
 from app.factory import create_app
 from app.extensions import db as _db
+from . import common, factories
+
+register(factories.ActionFactory)
+register(factories.ManufacturerFactory)
+register(factories.ModelFactory)
+register(factories.LocationFactory)
+register(factories.StatusFactory)
+register(factories.ItemFactory)
+register(factories.NetworkFactory)
+register(factories.HostFactory)
 
 
 @pytest.fixture(scope='session')
@@ -67,8 +78,8 @@ def session(db, request):
     """
     connection = db.engine.connect()
     transaction = connection.begin()
-    options = dict(bind=connection, binds={}, autoflush=False)
-    session = db.create_scoped_session(options=options)
+    session = common.Session
+    session.configure(bind=connection, autoflush=False)
     session.begin_nested()
 
     # session is actually a scoped_session
diff --git a/tests/functional/factories.py b/tests/functional/factories.py
new file mode 100644
index 0000000..ab714d6
--- /dev/null
+++ b/tests/functional/factories.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+"""
+tests.functional.factories
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+This module defines models factories.
+
+:copyright: (c) 2017 European Spallation Source ERIC
+:license: BSD 2-Clause, see LICENSE for more details.
+
+"""
+import ipaddress
+import factory
+from faker import Factory as FakerFactory
+from app import models
+from . import common
+
+
+faker = FakerFactory.create()
+
+
+class ActionFactory(factory.alchemy.SQLAlchemyModelFactory):
+    class Meta:
+        model = models.Action
+        sqlalchemy_session = common.Session
+        sqlalchemy_session_persistence = 'commit'
+
+    name = factory.Sequence(lambda n: f'action{n}')
+
+
+class ManufacturerFactory(factory.alchemy.SQLAlchemyModelFactory):
+    class Meta:
+        model = models.Manufacturer
+        sqlalchemy_session = common.Session
+        sqlalchemy_session_persistence = 'commit'
+
+    name = factory.Sequence(lambda n: f'manufacturer{n}')
+
+
+class ModelFactory(factory.alchemy.SQLAlchemyModelFactory):
+    class Meta:
+        model = models.Model
+        sqlalchemy_session = common.Session
+        sqlalchemy_session_persistence = 'commit'
+
+    name = factory.Sequence(lambda n: f'model{n}')
+
+
+class LocationFactory(factory.alchemy.SQLAlchemyModelFactory):
+    class Meta:
+        model = models.Location
+        sqlalchemy_session = common.Session
+        sqlalchemy_session_persistence = 'commit'
+
+    name = factory.Sequence(lambda n: f'location{n}')
+
+
+class StatusFactory(factory.alchemy.SQLAlchemyModelFactory):
+    class Meta:
+        model = models.Status
+        sqlalchemy_session = common.Session
+        sqlalchemy_session_persistence = 'commit'
+
+    name = factory.Sequence(lambda n: f'status{n}')
+
+
+class ItemFactory(factory.alchemy.SQLAlchemyModelFactory):
+    class Meta:
+        model = models.Item
+        sqlalchemy_session = common.Session
+        sqlalchemy_session_persistence = 'commit'
+
+    ics_id = factory.Sequence(lambda n: f'AAA{n:03}')
+    serial_number = factory.Faker('isbn10')
+    manufacturer = factory.SubFactory(ManufacturerFactory)
+    model = factory.SubFactory(ModelFactory)
+    location = factory.SubFactory(LocationFactory)
+    status = factory.SubFactory(StatusFactory)
+
+
+class NetworkFactory(factory.alchemy.SQLAlchemyModelFactory):
+    class Meta:
+        model = models.Network
+        sqlalchemy_session = common.Session
+        sqlalchemy_session_persistence = 'commit'
+
+    name = factory.Sequence(lambda n: f'network{n}')
+    prefix = factory.Faker('ipv4', network=True)
+    vlanid = factory.Sequence(lambda n: 1600 + n)
+
+    @factory.lazy_attribute
+    def first(self):
+        net = ipaddress.ip_network(self.prefix)
+        hosts = list(net.hosts())
+        return str(hosts[4])
+
+    @factory.lazy_attribute
+    def last(self):
+        net = ipaddress.ip_network(self.prefix)
+        hosts = list(net.hosts())
+        return str(hosts[-5])
+
+    @factory.lazy_attribute
+    def gateway(self):
+        net = ipaddress.ip_network(self.prefix)
+        hosts = list(net.hosts())
+        return str(hosts[-1])
+
+
+class HostFactory(factory.alchemy.SQLAlchemyModelFactory):
+    class Meta:
+        model = models.Host
+        sqlalchemy_session = common.Session
+        sqlalchemy_session_persistence = 'commit'
+
+    name = factory.Sequence(lambda n: f'host{n}')
+    network = factory.SubFactory(NetworkFactory)
diff --git a/tests/functional/test_api.py b/tests/functional/test_api.py
index 7f9862e..2b4f97d 100644
--- a/tests/functional/test_api.py
+++ b/tests/functional/test_api.py
@@ -237,15 +237,11 @@ def test_get_item_fail(client, session, readonly_token):
     check_response_message(response, "Item id 'bar' not found", 404)
 
 
-def test_get_item(client, session, readonly_token):
+def test_get_item(client, status_factory, item_factory, readonly_token):
     # Create some items
-    session.add(models.Status(name='Stock'))
-    session.flush()
-    item1 = models.Item(serial_number='123456')
-    item2 = models.Item(serial_number='234567', ics_id='AAA001', status='Stock')
-    for item in (item1, item2):
-        session.add(item)
-    session.commit()
+    status_factory(name='Stock')
+    item1 = item_factory(serial_number='123456')
+    item2 = item_factory(serial_number='234567', ics_id='AAA001', status='Stock')
 
     # we can get items by id...
     response = get(client, f'/api/items/{item1.id}', token=readonly_token)
@@ -270,7 +266,7 @@ def test_patch_item_auth_fail(client, session, readonly_token):
     check_response_message(response, "User doesn't have the required group", 403)
 
 
-def test_patch_item_fail(client, session, user_token):
+def test_patch_item_fail(client, item_factory, user_token):
     response = patch(client, '/api/items/50', data={}, token=user_token)
     check_response_message(response, 'At least one field is required', 422)
     data = {'location': 'ESS', 'foo': 'bar'}
@@ -283,10 +279,7 @@ def test_patch_item_fail(client, session, user_token):
     check_response_message(response, "Item id 'bar' not found", 404)
 
     # Create an item
-    item1 = models.Item(serial_number='234567', ics_id='AAA001')
-    session.add(item1)
-    session.commit()
-
+    item1 = item_factory(serial_number='234567', ics_id='AAA001')
     # check that we can't change the serial_number or ics_id
     response = patch(client, f'/api/items/{item1.id}', data={'serial_number': '12345'}, token=user_token)
     check_response_message(response, "Invalid field 'serial_number'", 422)
@@ -438,17 +431,12 @@ def test_get_items(client, session, readonly_token):
     check_response_message(response, 'Invalid query arguments', 422)
 
 
-def test_get_networks(client, session, readonly_token):
+def test_get_networks(client, location_factory, network_factory, readonly_token):
     # Create some networks
-    location = models.Location(name='G02')
-    session.add(location)
-    session.flush()
-    network1 = models.Network(name='network1', prefix='172.16.1.0/24', first='172.16.1.1', last='172.16.1.254')
-    network2 = models.Network(name='network2', prefix='172.16.20.0/22', first='172.16.20.11', last='172.16.20.250')
-    network3 = models.Network(name='network3', prefix='172.16.5.0/24', first='172.16.5.10', last='172.16.5.254', location_id=location.id)
-    for network in (network1, network2, network3):
-        session.add(network)
-    session.commit()
+    location = location_factory(name='G02')
+    network1 = network_factory(prefix='172.16.1.0/24', first='172.16.1.1', last='172.16.1.254')
+    network2 = network_factory(prefix='172.16.20.0/22', first='172.16.20.11', last='172.16.20.250')
+    network3 = network_factory(prefix='172.16.5.0/24', first='172.16.5.10', last='172.16.5.254', location=location)
 
     response = get(client, '/api/networks', token=readonly_token)
     assert response.status_code == 200
@@ -474,10 +462,8 @@ def test_create_network_auth_fail(client, session, user_token):
     check_response_message(response, "User doesn't have the required group", 403)
 
 
-def test_create_network(client, session, admin_token):
-    location = models.Location(name='G02')
-    session.add(location)
-    session.commit()
+def test_create_network(client, location_factory, admin_token):
+    location = location_factory(name='G02')
     # check that name, prefix, first and last are mandatory
     response = post(client, '/api/networks', data={}, token=admin_token)
     check_response_message(response, "Missing mandatory field 'name'", 422)
@@ -591,19 +577,13 @@ def test_create_network_invalid_range(client, session, admin_token):
     check_response_message(response, 'Last IP address 172.16.1.9 is less than the first address 172.16.1.10', 422)
 
 
-def test_get_hosts(client, session, readonly_token):
+def test_get_hosts(client, network_factory, host_factory, readonly_token):
     # Create some hosts
-    network1 = models.Network(name='network1', prefix='192.168.1.0/24', first='192.168.1.10', last='192.168.1.250')
-    network2 = models.Network(name='network2', prefix='192.168.2.0/24', first='192.168.2.10', last='192.168.2.250')
-    session.add(network1)
-    session.add(network2)
-    session.flush()
-    host1 = models.Host(network_id=network1.id, ip='192.168.1.10')
-    host2 = models.Host(network_id=network1.id, ip='192.168.1.11', name='hostname2')
-    host3 = models.Host(network_id=network2.id, ip='192.168.2.10')
-    for host in (host1, host2, host3):
-        session.add(host)
-    session.commit()
+    network1 = network_factory(prefix='192.168.1.0/24', first='192.168.1.10', last='192.168.1.250')
+    network2 = network_factory(prefix='192.168.2.0/24', first='192.168.2.10', last='192.168.2.250')
+    host1 = host_factory(network=network1, ip='192.168.1.10')
+    host2 = host_factory(network=network1, ip='192.168.1.11', name='hostname2')
+    host3 = host_factory(network=network2, ip='192.168.2.10')
 
     response = get(client, '/api/hosts', token=readonly_token)
     assert response.status_code == 200
@@ -617,10 +597,8 @@ def test_get_hosts(client, session, readonly_token):
     check_input_is_subset_of_response(response, (host3.to_dict(),))
 
 
-def test_create_host(client, session, user_token):
-    network = models.Network(name='network1', prefix='192.168.1.0/24', first='192.168.1.10', last='192.168.1.250')
-    session.add(network)
-    session.commit()
+def test_create_host(client, network_factory, user_token):
+    network = network_factory(prefix='192.168.1.0/24', first='192.168.1.10', last='192.168.1.250')
     # check that network_id and ip are mandatory
     response = post(client, '/api/hosts', data={}, token=user_token)
     check_response_message(response, "Missing mandatory field 'network'", 422)
@@ -654,10 +632,8 @@ def test_create_host(client, session, user_token):
 
 
 @pytest.mark.parametrize('ip', ('', 'foo', '192.168'))
-def test_create_host_invalid_ip(ip, client, session, user_token):
-    network = models.Network(name='network1', prefix='192.168.1.0/24', first='192.168.1.10', last='192.168.1.250')
-    session.add(network)
-    session.commit()
+def test_create_host_invalid_ip(ip, client, network_factory, user_token):
+    network = network_factory(prefix='192.168.1.0/24', first='192.168.1.10', last='192.168.1.250')
     # invalid IP address
     data = {'network': network.prefix,
             'ip': ip}
@@ -665,10 +641,8 @@ def test_create_host_invalid_ip(ip, client, session, user_token):
     check_response_message(response, f"'{ip}' does not appear to be an IPv4 or IPv6 address", 422)
 
 
-def test_create_host_ip_not_in_network(client, session, user_token):
-    network = models.Network(name='network1', prefix='192.168.1.0/24', first='192.168.1.10', last='192.168.1.250')
-    session.add(network)
-    session.commit()
+def test_create_host_ip_not_in_network(client, network_factory, user_token):
+    network = network_factory(prefix='192.168.1.0/24', first='192.168.1.10', last='192.168.1.250')
     # IP address not in range
     data = {'network': network.prefix,
             'ip': '192.168.2.4'}
diff --git a/tests/functional/test_models.py b/tests/functional/test_models.py
index d9b36ca..cffd358 100644
--- a/tests/functional/test_models.py
+++ b/tests/functional/test_models.py
@@ -10,16 +10,12 @@ This module defines models tests.
 
 """
 import ipaddress
-from app import models
 
 
-def test_network_ip_properties(session):
+def test_network_ip_properties(network_factory):
     # Create some networks
-    network1 = models.Network(name='network1', prefix='172.16.1.0/24', first='172.16.1.10', last='172.16.1.250')
-    network2 = models.Network(name='network2', prefix='172.16.20.0/26', first='172.16.20.11', last='172.16.20.14')
-    for network in (network1, network2):
-        session.add(network)
-    session.commit()
+    network1 = network_factory(prefix='172.16.1.0/24', first='172.16.1.10', last='172.16.1.250')
+    network2 = network_factory(prefix='172.16.20.0/26', first='172.16.20.11', last='172.16.20.14')
 
     assert network1.network_ip == ipaddress.ip_network('172.16.1.0/24')
     assert network1.first_ip == ipaddress.ip_address('172.16.1.10')
@@ -38,17 +34,13 @@ def test_network_ip_properties(session):
     assert network2.used_ips() == []
 
 
-def test_network_available_and_used_ips(session):
+def test_network_available_and_used_ips(network_factory, host_factory):
     # Create some networks and hosts
-    network1 = models.Network(name='network1', prefix='172.16.1.0/24', first='172.16.1.10', last='172.16.1.250')
-    network2 = models.Network(name='network2', prefix='172.16.20.0/26', first='172.16.20.11', last='172.16.20.14')
-    for network in (network1, network2):
-        session.add(network)
-    session.flush()
+    network1 = network_factory(prefix='172.16.1.0/24', first='172.16.1.10', last='172.16.1.250')
+    network2 = network_factory(prefix='172.16.20.0/26', first='172.16.20.11', last='172.16.20.14')
     for i in range(10, 20):
-        session.add(models.Host(network_id=network1.id, ip=f'172.16.1.{i}'))
-    session.add(models.Host(network_id=network2.id, ip='172.16.20.13'))
-    session.commit()
+        host_factory(network=network1, ip=f'172.16.1.{i}')
+    host_factory(network=network2, ip='172.16.20.13')
     # Check available and used IPs
     assert network1.used_ips() == [ipaddress.ip_address(f'172.16.1.{i}') for i in range(10, 20)]
     assert network1.available_ips() == [ipaddress.ip_address(f'172.16.1.{i}') for i in range(20, 251)]
@@ -58,9 +50,8 @@ def test_network_available_and_used_ips(session):
                                         ipaddress.ip_address('172.16.20.14')]
 
     # Add more hosts
-    session.add(models.Host(network_id=network2.id, ip='172.16.20.11'))
-    session.add(models.Host(network_id=network2.id, ip='172.16.20.14'))
-    session.commit()
+    host_factory(network=network2, ip='172.16.20.11')
+    host_factory(network=network2, ip='172.16.20.14')
     assert len(network2.used_ips()) == 3
     assert network2.used_ips() == [ipaddress.ip_address('172.16.20.11'),
                                    ipaddress.ip_address('172.16.20.13'),
@@ -68,7 +59,6 @@ def test_network_available_and_used_ips(session):
     assert network2.available_ips() == [ipaddress.ip_address('172.16.20.12')]
 
     # Add last available IP
-    session.add(models.Host(network_id=network2.id, ip='172.16.20.12'))
-    session.commit()
+    host_factory(network=network2, ip='172.16.20.12')
     assert network2.used_ips() == [ipaddress.ip_address(f'172.16.20.{i}') for i in range(11, 15)]
     assert list(network2.available_ips()) == []
diff --git a/tests/functional/test_web.py b/tests/functional/test_web.py
index 3f85e18..9233a07 100644
--- a/tests/functional/test_web.py
+++ b/tests/functional/test_web.py
@@ -73,13 +73,12 @@ def test_protected_url(url, client):
     assert response.status_code == 200
 
 
-def test_retrieve_items(logged_client, session):
+def test_retrieve_items(logged_client, item_factory):
     response = get(logged_client, '/_retrieve_items')
     assert response.json['data'] == []
     serial_numbers = ('12345', '45678')
     for sn in serial_numbers:
-        session.add(models.Item(serial_number=sn))
-    session.commit()
+        item_factory(serial_number=sn)
     response = get(logged_client, '/_retrieve_items')
     items = response.json['data']
     assert set(serial_numbers) == set(item[4] for item in items)
-- 
GitLab