Skip to content
Snippets Groups Projects
Commit 10a7d576 authored by Benjamin Bertrand's avatar Benjamin Bertrand
Browse files

Fix tests after refactoring

parent 1fbdf91e
No related branches found
No related tags found
No related merge requests found
...@@ -13,7 +13,7 @@ from flask import Blueprint, jsonify, request, current_app ...@@ -13,7 +13,7 @@ from flask import Blueprint, jsonify, request, current_app
from flask_jwt_extended import jwt_required from flask_jwt_extended import jwt_required
from .. import utils, models from .. import utils, models
from ..decorators import jwt_groups_accepted from ..decorators import jwt_groups_accepted
from .utils import commit, create_generic_model, get_generic_model from .utils import commit, create_generic_model, get_qrcode_model, get_generic_model
bp = Blueprint('inventory_api', __name__) bp = Blueprint('inventory_api', __name__)
...@@ -36,10 +36,8 @@ def get_item_by_id_or_ics_id(id_): ...@@ -36,10 +36,8 @@ def get_item_by_id_or_ics_id(id_):
@jwt_required @jwt_required
def get_items(): def get_items():
# TODO: add pagination # TODO: add pagination
query = utils.get_query(models.Item.query, request.args) return get_generic_model(models.Item, request.args,
items = query.order_by(models.Item._created) order_by=models.Item.created_at)
data = [item.to_dict() for item in items]
return jsonify(data)
@bp.route('/items/<id_>') @bp.route('/items/<id_>')
...@@ -118,13 +116,13 @@ def patch_item(id_): ...@@ -118,13 +116,13 @@ def patch_item(id_):
@bp.route('/actions') @bp.route('/actions')
@jwt_required @jwt_required
def get_actions(): def get_actions():
return get_generic_model(models.Action, request.args) return get_qrcode_model(models.Action, request.args)
@bp.route('/manufacturers') @bp.route('/manufacturers')
@jwt_required @jwt_required
def get_manufacturers(): def get_manufacturers():
return get_generic_model(models.Manufacturer, request.args) return get_qrcode_model(models.Manufacturer, request.args)
@bp.route('/manufacturers', methods=['POST']) @bp.route('/manufacturers', methods=['POST'])
...@@ -137,7 +135,7 @@ def create_manufacturer(): ...@@ -137,7 +135,7 @@ def create_manufacturer():
@bp.route('/models') @bp.route('/models')
@jwt_required @jwt_required
def get_models(): def get_models():
return get_generic_model(models.Model, request.args) return get_qrcode_model(models.Model, request.args)
@bp.route('/models', methods=['POST']) @bp.route('/models', methods=['POST'])
...@@ -150,7 +148,7 @@ def create_model(): ...@@ -150,7 +148,7 @@ def create_model():
@bp.route('/locations') @bp.route('/locations')
@jwt_required @jwt_required
def get_locations(): def get_locations():
return get_generic_model(models.Location, request.args) return get_qrcode_model(models.Location, request.args)
@bp.route('/locations', methods=['POST']) @bp.route('/locations', methods=['POST'])
...@@ -163,7 +161,7 @@ def create_locations(): ...@@ -163,7 +161,7 @@ def create_locations():
@bp.route('/status') @bp.route('/status')
@jwt_required @jwt_required
def get_status(): def get_status():
return get_generic_model(models.Status, request.args) return get_qrcode_model(models.Status, request.args)
@bp.route('/status', methods=['POST']) @bp.route('/status', methods=['POST'])
......
...@@ -9,23 +9,38 @@ This module implements the network API. ...@@ -9,23 +9,38 @@ This module implements the network API.
:license: BSD 2-Clause, see LICENSE for more details. :license: BSD 2-Clause, see LICENSE for more details.
""" """
from flask import Blueprint, jsonify, request from flask import Blueprint, request
from flask_jwt_extended import jwt_required from flask_jwt_extended import jwt_required
from .. import utils, models from .. import models
from ..decorators import jwt_groups_accepted from ..decorators import jwt_groups_accepted
from .utils import create_generic_model from .utils import get_generic_model, create_generic_model
bp = Blueprint('network_api', __name__) bp = Blueprint('network_api', __name__)
@bp.route('/scopes')
@jwt_required
def get_scopes():
# TODO: add pagination
return get_generic_model(models.NetworkScope, request.args,
order_by=models.NetworkScope.name)
@bp.route('/scopes', methods=['POST'])
@jwt_required
@jwt_groups_accepted('admin')
def create_scope():
"""Create a new network scope"""
return create_generic_model(models.NetworkScope, mandatory_fields=(
'name', 'first_vlan', 'last_vlan', 'supernet'))
@bp.route('/networks') @bp.route('/networks')
@jwt_required @jwt_required
def get_networks(): def get_networks():
# TODO: add pagination # TODO: add pagination
query = utils.get_query(models.Network.query, request.args) return get_generic_model(models.Network, request.args,
networks = query.order_by(models.Network.address) order_by=models.Network.address)
data = [network.to_dict() for network in networks]
return jsonify(data)
@bp.route('/networks', methods=['POST']) @bp.route('/networks', methods=['POST'])
...@@ -41,10 +56,8 @@ def create_network(): ...@@ -41,10 +56,8 @@ def create_network():
@jwt_required @jwt_required
def get_interfaces(): def get_interfaces():
# TODO: add pagination # TODO: add pagination
query = utils.get_query(models.Interface.query, request.args) return get_generic_model(models.Interface, request.args,
interfaces = query.order_by(models.Interface.ip) order_by=models.Interface.ip)
data = [interface.to_dict() for interface in interfaces]
return jsonify(data)
@bp.route('/interfaces', methods=['POST']) @bp.route('/interfaces', methods=['POST'])
......
...@@ -23,7 +23,23 @@ def commit(): ...@@ -23,7 +23,23 @@ def commit():
raise utils.CSEntryError(str(e), status_code=422) raise utils.CSEntryError(str(e), status_code=422)
def get_generic_model(model, args): def get_generic_model(model, args, order_by=None):
"""Return data from model as json
:param model: model class
:param MultiDict args: args from the request
:param order_by: column to order the result by
:returns: data from model as json
"""
query = utils.get_query(model.query, request.args)
if order_by is None:
order_by = getattr(model, 'created_at')
instances = query.order_by(order_by)
data = [instance.to_dict() for instance in instances]
return jsonify(data)
def get_qrcode_model(model, args):
"""Return data from model as json """Return data from model as json
:param model: model class :param model: model class
......
...@@ -24,7 +24,7 @@ def login(): ...@@ -24,7 +24,7 @@ def login():
form = LDAPLoginForm(request.form) form = LDAPLoginForm(request.form)
if form.validate_on_submit(): if form.validate_on_submit():
login_user(form.user, remember=form.remember_me.data) login_user(form.user, remember=form.remember_me.data)
return redirect(request.args.get('next') or url_for('items.index')) return redirect(request.args.get('next') or url_for('main.index'))
return render_template('users/login.html', form=form) return render_template('users/login.html', form=form)
......
...@@ -25,6 +25,7 @@ register(factories.StatusFactory) ...@@ -25,6 +25,7 @@ register(factories.StatusFactory)
register(factories.ItemFactory) register(factories.ItemFactory)
register(factories.NetworkScopeFactory) register(factories.NetworkScopeFactory)
register(factories.NetworkFactory) register(factories.NetworkFactory)
register(factories.InterfaceFactory)
register(factories.HostFactory) register(factories.HostFactory)
......
...@@ -87,7 +87,7 @@ class NetworkScopeFactory(factory.alchemy.SQLAlchemyModelFactory): ...@@ -87,7 +87,7 @@ class NetworkScopeFactory(factory.alchemy.SQLAlchemyModelFactory):
name = factory.Sequence(lambda n: f'scope{n}') name = factory.Sequence(lambda n: f'scope{n}')
first_vlan = factory.Sequence(lambda n: 1600 + 10 * n) first_vlan = factory.Sequence(lambda n: 1600 + 10 * n)
last_vlan = factory.Sequence(lambda n: 1609 + 10 * n) last_vlan = factory.Sequence(lambda n: 1609 + 10 * n)
subnet = factory.Faker('ipv4', network=True) supernet = factory.Faker('ipv4', network=True)
class NetworkFactory(factory.alchemy.SQLAlchemyModelFactory): class NetworkFactory(factory.alchemy.SQLAlchemyModelFactory):
...@@ -97,8 +97,8 @@ class NetworkFactory(factory.alchemy.SQLAlchemyModelFactory): ...@@ -97,8 +97,8 @@ class NetworkFactory(factory.alchemy.SQLAlchemyModelFactory):
sqlalchemy_session_persistence = 'commit' sqlalchemy_session_persistence = 'commit'
vlan_name = factory.Sequence(lambda n: f'vlan{n}') vlan_name = factory.Sequence(lambda n: f'vlan{n}')
address = factory.Faker('ipv4', network=True)
vlan_id = factory.Sequence(lambda n: 1600 + n) vlan_id = factory.Sequence(lambda n: 1600 + n)
address = factory.Faker('ipv4', network=True)
scope = factory.SubFactory(NetworkScopeFactory) scope = factory.SubFactory(NetworkScopeFactory)
@factory.lazy_attribute @factory.lazy_attribute
...@@ -113,11 +113,16 @@ class NetworkFactory(factory.alchemy.SQLAlchemyModelFactory): ...@@ -113,11 +113,16 @@ class NetworkFactory(factory.alchemy.SQLAlchemyModelFactory):
hosts = list(net.hosts()) hosts = list(net.hosts())
return str(hosts[-5]) return str(hosts[-5])
@factory.lazy_attribute
def gateway(self): class InterfaceFactory(factory.alchemy.SQLAlchemyModelFactory):
net = ipaddress.ip_network(self.address) class Meta:
hosts = list(net.hosts()) model = models.Interface
return str(hosts[-1]) sqlalchemy_session = common.Session
sqlalchemy_session_persistence = 'commit'
name = factory.Sequence(lambda n: f'host{n}')
network = factory.SubFactory(NetworkFactory)
ip = factory.LazyAttributeSequence(lambda o, n: str(ipaddress.ip_address(o.network.first_ip) + n))
class HostFactory(factory.alchemy.SQLAlchemyModelFactory): class HostFactory(factory.alchemy.SQLAlchemyModelFactory):
...@@ -127,4 +132,3 @@ class HostFactory(factory.alchemy.SQLAlchemyModelFactory): ...@@ -127,4 +132,3 @@ class HostFactory(factory.alchemy.SQLAlchemyModelFactory):
sqlalchemy_session_persistence = 'commit' sqlalchemy_session_persistence = 'commit'
name = factory.Sequence(lambda n: f'host{n}') name = factory.Sequence(lambda n: f'host{n}')
network = factory.SubFactory(NetworkFactory)
This diff is collapsed.
...@@ -34,13 +34,13 @@ def test_network_ip_properties(network_factory): ...@@ -34,13 +34,13 @@ def test_network_ip_properties(network_factory):
assert network2.used_ips() == [] assert network2.used_ips() == []
def test_network_available_and_used_ips(network_factory, host_factory): def test_network_available_and_used_ips(network_factory, interface_factory):
# Create some networks and hosts # Create some networks and interfaces
network1 = network_factory(address='172.16.1.0/24', first_ip='172.16.1.10', last_ip='172.16.1.250') network1 = network_factory(address='172.16.1.0/24', first_ip='172.16.1.10', last_ip='172.16.1.250')
network2 = network_factory(address='172.16.20.0/26', first_ip='172.16.20.11', last_ip='172.16.20.14') network2 = network_factory(address='172.16.20.0/26', first_ip='172.16.20.11', last_ip='172.16.20.14')
for i in range(10, 20): for i in range(10, 20):
host_factory(network=network1, ip=f'172.16.1.{i}') interface_factory(network=network1, ip=f'172.16.1.{i}')
host_factory(network=network2, ip='172.16.20.13') interface_factory(network=network2, ip='172.16.20.13')
# Check available and used IPs # Check available and used IPs
assert network1.used_ips() == [ipaddress.ip_address(f'172.16.1.{i}') for i in range(10, 20)] assert network1.used_ips() == [ipaddress.ip_address(f'172.16.1.{i}') for i in range(10, 20)]
assert network1.available_ips() == [ipaddress.ip_address(f'172.16.1.{i}') for i in range(20, 251)] assert network1.available_ips() == [ipaddress.ip_address(f'172.16.1.{i}') for i in range(20, 251)]
...@@ -49,9 +49,9 @@ def test_network_available_and_used_ips(network_factory, host_factory): ...@@ -49,9 +49,9 @@ def test_network_available_and_used_ips(network_factory, host_factory):
ipaddress.ip_address('172.16.20.12'), ipaddress.ip_address('172.16.20.12'),
ipaddress.ip_address('172.16.20.14')] ipaddress.ip_address('172.16.20.14')]
# Add more hosts # Add more interfaces
host_factory(network=network2, ip='172.16.20.11') interface_factory(network=network2, ip='172.16.20.11')
host_factory(network=network2, ip='172.16.20.14') interface_factory(network=network2, ip='172.16.20.14')
assert len(network2.used_ips()) == 3 assert len(network2.used_ips()) == 3
assert network2.used_ips() == [ipaddress.ip_address('172.16.20.11'), assert network2.used_ips() == [ipaddress.ip_address('172.16.20.11'),
ipaddress.ip_address('172.16.20.13'), ipaddress.ip_address('172.16.20.13'),
...@@ -59,6 +59,6 @@ def test_network_available_and_used_ips(network_factory, host_factory): ...@@ -59,6 +59,6 @@ def test_network_available_and_used_ips(network_factory, host_factory):
assert network2.available_ips() == [ipaddress.ip_address('172.16.20.12')] assert network2.available_ips() == [ipaddress.ip_address('172.16.20.12')]
# Add last available IP # Add last available IP
host_factory(network=network2, ip='172.16.20.12') interface_factory(network=network2, ip='172.16.20.12')
assert network2.used_ips() == [ipaddress.ip_address(f'172.16.20.{i}') for i in range(11, 15)] assert network2.used_ips() == [ipaddress.ip_address(f'172.16.20.{i}') for i in range(11, 15)]
assert list(network2.available_ips()) == [] assert list(network2.available_ips()) == []
...@@ -11,7 +11,6 @@ This module defines basic web tests. ...@@ -11,7 +11,6 @@ This module defines basic web tests.
""" """
import json import json
import pytest import pytest
from app import models
def get(client, url): def get(client, url):
...@@ -26,11 +25,11 @@ def login(client, username, password): ...@@ -26,11 +25,11 @@ def login(client, username, password):
'username': username, 'username': username,
'password': password 'password': password
} }
return client.post('/login', data=data, follow_redirects=True) return client.post('/users/login', data=data, follow_redirects=True)
def logout(client): def logout(client):
return client.get('/logout', follow_redirects=True) return client.get('/users/logout', follow_redirects=True)
@pytest.fixture @pytest.fixture
...@@ -60,26 +59,26 @@ def test_index(logged_client): ...@@ -60,26 +59,26 @@ def test_index(logged_client):
@pytest.mark.parametrize('url', [ @pytest.mark.parametrize('url', [
'/', '/',
'/items', '/inventory/items',
'/qrcodes', '/inventory/_retrieve_items',
'_retrieve_items', '/network/networks',
]) ])
def test_protected_url(url, client): def test_protected_url(url, client):
response = client.get(url) response = client.get(url)
assert response.status_code == 302 assert response.status_code == 302
assert '/login' in response.headers['Location'] assert '/users/login' in response.headers['Location']
login(client, 'user_ro', 'userro') login(client, 'user_ro', 'userro')
response = client.get(url) response = client.get(url)
assert response.status_code == 200 assert response.status_code == 200
def test_retrieve_items(logged_client, item_factory): def test_retrieve_items(logged_client, item_factory):
response = get(logged_client, '/_retrieve_items') response = get(logged_client, '/inventory/_retrieve_items')
assert response.json['data'] == [] assert response.json['data'] == []
serial_numbers = ('12345', '45678') serial_numbers = ('12345', '45678')
for sn in serial_numbers: for sn in serial_numbers:
item_factory(serial_number=sn) item_factory(serial_number=sn)
response = get(logged_client, '/_retrieve_items') response = get(logged_client, '/inventory/_retrieve_items')
items = response.json['data'] items = response.json['data']
assert set(serial_numbers) == set(item[4] for item in items) assert set(serial_numbers) == set(item[4] for item in items)
assert len(items[0]) == 10 assert len(items[0]) == 10
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment