Skip to content
Snippets Groups Projects
Commit 10a7d576 authored by Benjamin Bertrand's avatar Benjamin Bertrand
Browse files

Fix tests after refactoring

parent 1fbdf91e
No related branches found
No related tags found
No related merge requests found
......@@ -13,7 +13,7 @@ from flask import Blueprint, jsonify, request, current_app
from flask_jwt_extended import jwt_required
from .. import utils, models
from ..decorators import jwt_groups_accepted
from .utils import commit, create_generic_model, get_generic_model
from .utils import commit, create_generic_model, get_qrcode_model, get_generic_model
bp = Blueprint('inventory_api', __name__)
......@@ -36,10 +36,8 @@ def get_item_by_id_or_ics_id(id_):
@jwt_required
def get_items():
# TODO: add pagination
query = utils.get_query(models.Item.query, request.args)
items = query.order_by(models.Item._created)
data = [item.to_dict() for item in items]
return jsonify(data)
return get_generic_model(models.Item, request.args,
order_by=models.Item.created_at)
@bp.route('/items/<id_>')
......@@ -118,13 +116,13 @@ def patch_item(id_):
@bp.route('/actions')
@jwt_required
def get_actions():
return get_generic_model(models.Action, request.args)
return get_qrcode_model(models.Action, request.args)
@bp.route('/manufacturers')
@jwt_required
def get_manufacturers():
return get_generic_model(models.Manufacturer, request.args)
return get_qrcode_model(models.Manufacturer, request.args)
@bp.route('/manufacturers', methods=['POST'])
......@@ -137,7 +135,7 @@ def create_manufacturer():
@bp.route('/models')
@jwt_required
def get_models():
return get_generic_model(models.Model, request.args)
return get_qrcode_model(models.Model, request.args)
@bp.route('/models', methods=['POST'])
......@@ -150,7 +148,7 @@ def create_model():
@bp.route('/locations')
@jwt_required
def get_locations():
return get_generic_model(models.Location, request.args)
return get_qrcode_model(models.Location, request.args)
@bp.route('/locations', methods=['POST'])
......@@ -163,7 +161,7 @@ def create_locations():
@bp.route('/status')
@jwt_required
def get_status():
return get_generic_model(models.Status, request.args)
return get_qrcode_model(models.Status, request.args)
@bp.route('/status', methods=['POST'])
......
......@@ -9,23 +9,38 @@ This module implements the network API.
:license: BSD 2-Clause, see LICENSE for more details.
"""
from flask import Blueprint, jsonify, request
from flask import Blueprint, request
from flask_jwt_extended import jwt_required
from .. import utils, models
from .. import models
from ..decorators import jwt_groups_accepted
from .utils import create_generic_model
from .utils import get_generic_model, create_generic_model
bp = Blueprint('network_api', __name__)
@bp.route('/scopes')
@jwt_required
def get_scopes():
# TODO: add pagination
return get_generic_model(models.NetworkScope, request.args,
order_by=models.NetworkScope.name)
@bp.route('/scopes', methods=['POST'])
@jwt_required
@jwt_groups_accepted('admin')
def create_scope():
"""Create a new network scope"""
return create_generic_model(models.NetworkScope, mandatory_fields=(
'name', 'first_vlan', 'last_vlan', 'supernet'))
@bp.route('/networks')
@jwt_required
def get_networks():
# TODO: add pagination
query = utils.get_query(models.Network.query, request.args)
networks = query.order_by(models.Network.address)
data = [network.to_dict() for network in networks]
return jsonify(data)
return get_generic_model(models.Network, request.args,
order_by=models.Network.address)
@bp.route('/networks', methods=['POST'])
......@@ -41,10 +56,8 @@ def create_network():
@jwt_required
def get_interfaces():
# TODO: add pagination
query = utils.get_query(models.Interface.query, request.args)
interfaces = query.order_by(models.Interface.ip)
data = [interface.to_dict() for interface in interfaces]
return jsonify(data)
return get_generic_model(models.Interface, request.args,
order_by=models.Interface.ip)
@bp.route('/interfaces', methods=['POST'])
......
......@@ -23,7 +23,23 @@ def commit():
raise utils.CSEntryError(str(e), status_code=422)
def get_generic_model(model, args):
def get_generic_model(model, args, order_by=None):
"""Return data from model as json
:param model: model class
:param MultiDict args: args from the request
:param order_by: column to order the result by
:returns: data from model as json
"""
query = utils.get_query(model.query, request.args)
if order_by is None:
order_by = getattr(model, 'created_at')
instances = query.order_by(order_by)
data = [instance.to_dict() for instance in instances]
return jsonify(data)
def get_qrcode_model(model, args):
"""Return data from model as json
:param model: model class
......
......@@ -24,7 +24,7 @@ def login():
form = LDAPLoginForm(request.form)
if form.validate_on_submit():
login_user(form.user, remember=form.remember_me.data)
return redirect(request.args.get('next') or url_for('items.index'))
return redirect(request.args.get('next') or url_for('main.index'))
return render_template('users/login.html', form=form)
......
......@@ -25,6 +25,7 @@ register(factories.StatusFactory)
register(factories.ItemFactory)
register(factories.NetworkScopeFactory)
register(factories.NetworkFactory)
register(factories.InterfaceFactory)
register(factories.HostFactory)
......
......@@ -87,7 +87,7 @@ class NetworkScopeFactory(factory.alchemy.SQLAlchemyModelFactory):
name = factory.Sequence(lambda n: f'scope{n}')
first_vlan = factory.Sequence(lambda n: 1600 + 10 * n)
last_vlan = factory.Sequence(lambda n: 1609 + 10 * n)
subnet = factory.Faker('ipv4', network=True)
supernet = factory.Faker('ipv4', network=True)
class NetworkFactory(factory.alchemy.SQLAlchemyModelFactory):
......@@ -97,8 +97,8 @@ class NetworkFactory(factory.alchemy.SQLAlchemyModelFactory):
sqlalchemy_session_persistence = 'commit'
vlan_name = factory.Sequence(lambda n: f'vlan{n}')
address = factory.Faker('ipv4', network=True)
vlan_id = factory.Sequence(lambda n: 1600 + n)
address = factory.Faker('ipv4', network=True)
scope = factory.SubFactory(NetworkScopeFactory)
@factory.lazy_attribute
......@@ -113,11 +113,16 @@ class NetworkFactory(factory.alchemy.SQLAlchemyModelFactory):
hosts = list(net.hosts())
return str(hosts[-5])
@factory.lazy_attribute
def gateway(self):
net = ipaddress.ip_network(self.address)
hosts = list(net.hosts())
return str(hosts[-1])
class InterfaceFactory(factory.alchemy.SQLAlchemyModelFactory):
class Meta:
model = models.Interface
sqlalchemy_session = common.Session
sqlalchemy_session_persistence = 'commit'
name = factory.Sequence(lambda n: f'host{n}')
network = factory.SubFactory(NetworkFactory)
ip = factory.LazyAttributeSequence(lambda o, n: str(ipaddress.ip_address(o.network.first_ip) + n))
class HostFactory(factory.alchemy.SQLAlchemyModelFactory):
......@@ -127,4 +132,3 @@ class HostFactory(factory.alchemy.SQLAlchemyModelFactory):
sqlalchemy_session_persistence = 'commit'
name = factory.Sequence(lambda n: f'host{n}')
network = factory.SubFactory(NetworkFactory)
This diff is collapsed.
......@@ -34,13 +34,13 @@ def test_network_ip_properties(network_factory):
assert network2.used_ips() == []
def test_network_available_and_used_ips(network_factory, host_factory):
# Create some networks and hosts
def test_network_available_and_used_ips(network_factory, interface_factory):
# Create some networks and interfaces
network1 = network_factory(address='172.16.1.0/24', first_ip='172.16.1.10', last_ip='172.16.1.250')
network2 = network_factory(address='172.16.20.0/26', first_ip='172.16.20.11', last_ip='172.16.20.14')
for i in range(10, 20):
host_factory(network=network1, ip=f'172.16.1.{i}')
host_factory(network=network2, ip='172.16.20.13')
interface_factory(network=network1, ip=f'172.16.1.{i}')
interface_factory(network=network2, ip='172.16.20.13')
# Check available and used IPs
assert network1.used_ips() == [ipaddress.ip_address(f'172.16.1.{i}') for i in range(10, 20)]
assert network1.available_ips() == [ipaddress.ip_address(f'172.16.1.{i}') for i in range(20, 251)]
......@@ -49,9 +49,9 @@ def test_network_available_and_used_ips(network_factory, host_factory):
ipaddress.ip_address('172.16.20.12'),
ipaddress.ip_address('172.16.20.14')]
# Add more hosts
host_factory(network=network2, ip='172.16.20.11')
host_factory(network=network2, ip='172.16.20.14')
# Add more interfaces
interface_factory(network=network2, ip='172.16.20.11')
interface_factory(network=network2, ip='172.16.20.14')
assert len(network2.used_ips()) == 3
assert network2.used_ips() == [ipaddress.ip_address('172.16.20.11'),
ipaddress.ip_address('172.16.20.13'),
......@@ -59,6 +59,6 @@ def test_network_available_and_used_ips(network_factory, host_factory):
assert network2.available_ips() == [ipaddress.ip_address('172.16.20.12')]
# Add last available IP
host_factory(network=network2, ip='172.16.20.12')
interface_factory(network=network2, ip='172.16.20.12')
assert network2.used_ips() == [ipaddress.ip_address(f'172.16.20.{i}') for i in range(11, 15)]
assert list(network2.available_ips()) == []
......@@ -11,7 +11,6 @@ This module defines basic web tests.
"""
import json
import pytest
from app import models
def get(client, url):
......@@ -26,11 +25,11 @@ def login(client, username, password):
'username': username,
'password': password
}
return client.post('/login', data=data, follow_redirects=True)
return client.post('/users/login', data=data, follow_redirects=True)
def logout(client):
return client.get('/logout', follow_redirects=True)
return client.get('/users/logout', follow_redirects=True)
@pytest.fixture
......@@ -60,26 +59,26 @@ def test_index(logged_client):
@pytest.mark.parametrize('url', [
'/',
'/items',
'/qrcodes',
'_retrieve_items',
'/inventory/items',
'/inventory/_retrieve_items',
'/network/networks',
])
def test_protected_url(url, client):
response = client.get(url)
assert response.status_code == 302
assert '/login' in response.headers['Location']
assert '/users/login' in response.headers['Location']
login(client, 'user_ro', 'userro')
response = client.get(url)
assert response.status_code == 200
def test_retrieve_items(logged_client, item_factory):
response = get(logged_client, '/_retrieve_items')
response = get(logged_client, '/inventory/_retrieve_items')
assert response.json['data'] == []
serial_numbers = ('12345', '45678')
for sn in serial_numbers:
item_factory(serial_number=sn)
response = get(logged_client, '/_retrieve_items')
response = get(logged_client, '/inventory/_retrieve_items')
items = response.json['data']
assert set(serial_numbers) == set(item[4] for item in items)
assert len(items[0]) == 10
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment