From 504d018bfdac3211ee603ef147858f836dc2d8ac Mon Sep 17 00:00:00 2001
From: Benjamin Bertrand <benjamin.bertrand@esss.se>
Date: Fri, 6 Jul 2018 10:30:28 +0200
Subject: [PATCH] Reformat using black

https://black.readthedocs.io/en/stable/
---
 README.rst | 3 +
 app/admin/views.py | 32 +-
 app/api/inventory.py | 91 +-
 app/api/network.py | 125 +-
 app/api/user.py | 42 +-
 app/api/utils.py | 46 +-
 app/commands.py | 41 +-
 app/decorators.py | 12 +-
 app/defaults.py | 30 +-
 app/extensions.py | 10 +-
 app/factory.py | 119 +-
 app/helpers.py | 1 -
 app/inventory/forms.py | 78 +-
 app/inventory/views.py | 304 ++--
 app/main/views.py | 22 +-
 app/models.py | 659 +++++----
 app/network/forms.py | 212 +--
 app/network/views.py | 441 +++---
 app/plugins.py | 8 +-
 app/settings.py | 78 +-
 app/task/views.py | 17 +-
 app/tasks.py | 115 +-
 app/tokens.py | 22 +-
 app/user/forms.py | 2 +-
 app/user/views.py | 67 +-
 app/utils.py | 49 +-
 app/validators.py | 22 +-
 docs/conf.py | 83 +-
 migrations/env.py | 33 +-
 ...33_add_stack_member_field_to_item_table.py | 31 +-
 migrations/versions/713ca10255ab_.py | 613 ++++----
 .../versions/7d0d580cdb1a_add_task_table.py | 38 +-
 ...llow_to_associate_several_items_to_one_.py | 47 +-
 ...8f135d5efde2_rename_virtual_device_type.py | 24 +-
 ...fa1_add_user_favorite_attributes_tables.py | 120 +-
 ...9442567c6dc_add_exception_to_task_table.py | 8 +-
 .../ac6b3c416b07_add_machine_type_table.py | 61 +-
 .../c0b8036078e7_add_fields_to_task_table.py | 12 +-
 .../versions/dfd4eae61224_add_domain_table.py | 77 +-
 ...70be_rename_machine_type_to_device_type.py | 62 +-
 ...a606be23b95_rename_physical_device_type.py | 24 +-
 ...05c0c835_remove_spaces_from_device_type.py | 36 +-
 tests/functional/conftest.py | 56 +-
 tests/functional/factories.py | 76 +-
 tests/functional/test_api.py | 1296 +++++++++++------
 tests/functional/test_models.py | 140 +-
 tests/functional/test_web.py | 66 +-
 47 files changed, 3292 insertions(+), 2259 deletions(-)

diff --git a/README.rst b/README.rst
index f904c54..7c420de 100644
--- a/README.rst
+++ b/README.rst
@@ -1,6 +1,9 @@
 CSEntry
 =======
 
+.. image:: https://img.shields.io/badge/code%20style-black-000000.svg
+   :target: https://github.com/ambv/black
+
 Control System Entry web server.
 
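The changes below are purely mechanical: the whole repository was run through black, which normalizes string quotes to double quotes and re-wraps statements to its default 88-column line length. A rough sketch of that conversion is shown next; it is illustrative only and not part of the patch, it assumes a black release that exposes black.format_str() and black.Mode(), and the identifiers inside the sample string are just fragments of this codebase that black parses but never imports or executes.

import black

# Illustrative only: a fragment written in the project's pre-black style,
# modeled on the kind of code this patch touches.
OLD_STYLE = (
    "bp = Blueprint('inventory_api', __name__)\n"
    "items = get_generic_model(models.Item,\n"
    "                          order_by=models.Item.created_at)\n"
)

# format_str() parses and re-prints the source with black's defaults
# (88 columns, double quotes); it does not run the code.
print(black.format_str(OLD_STYLE, mode=black.Mode()))
# Expected output, matching the style of the hunks in this patch:
#   bp = Blueprint("inventory_api", __name__)
#   items = get_generic_model(models.Item, order_by=models.Item.created_at)

These two conversions, quote normalization and collapsing hanging-indent calls that fit within 88 columns, account for the bulk of the hunks that follow.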
diff --git a/app/admin/views.py b/app/admin/views.py index d159e87..5839cab 100644 --- a/app/admin/views.py +++ b/app/admin/views.py @@ -28,14 +28,13 @@ sqla.form.Unique.__call__ = lambda x, y, z: None # Add custom model converter for CIText type # See https://github.com/flask-admin/flask-admin/issues/1196 class AppAdminModelConverter(sqla.form.AdminModelConverter): - - @converts('CIText') + @converts("CIText") def conv_CIText(self, field_args, **extra): return fields.TextAreaField(**field_args) - @converts('sqlalchemy.dialects.postgresql.base.CIDR') + @converts("sqlalchemy.dialects.postgresql.base.CIDR") def conv_PGCidr(self, field_args, **extra): - field_args['validators'].append(IPNetwork()) + field_args["validators"].append(IPNetwork()) return fields.StringField(**field_args) @@ -43,9 +42,7 @@ class AdminModelView(sqla.ModelView): model_form_converter = AppAdminModelConverter # Replace TextAreaField (default for Text) with StringField - form_overrides = { - 'name': fields.StringField, - } + form_overrides = {"name": fields.StringField} def is_accessible(self): return current_user.is_authenticated and current_user.is_admin @@ -65,16 +62,17 @@ class TokenAdmin(AdminModelView): class ItemAdmin(AdminModelView): # Replace TextAreaField (default for Text) with StringField - form_overrides = { - 'ics_id': fields.StringField, - 'serial_number': fields.StringField, - } + form_overrides = {"ics_id": fields.StringField, "serial_number": fields.StringField} form_args = { - 'ics_id': { - 'label': 'ICS id', - 'validators': [validators.Regexp(ICS_ID_RE, message='ICS id shall match [A-Z]{3}[0-9]{3}')], - 'filters': [lambda x: x or None], + "ics_id": { + "label": "ICS id", + "validators": [ + validators.Regexp( + ICS_ID_RE, message="ICS id shall match [A-Z]{3}[0-9]{3}" + ) + ], + "filters": [lambda x: x or None], } } @@ -82,9 +80,7 @@ class ItemAdmin(AdminModelView): class NetworkAdmin(AdminModelView): # Replace TextAreaField (default for Text) with StringField - form_overrides = { - 'vlan_name': fields.StringField, - } + form_overrides = {"vlan_name": fields.StringField} class TaskAdmin(AdminModelView): diff --git a/app/api/inventory.py b/app/api/inventory.py index abc2f05..1f86909 100644 --- a/app/api/inventory.py +++ b/app/api/inventory.py @@ -15,7 +15,7 @@ from .. import utils, models from ..decorators import jwt_groups_accepted from .utils import commit, create_generic_model, get_generic_model -bp = Blueprint('inventory_api', __name__) +bp = Blueprint("inventory_api", __name__) def get_item_by_id_or_ics_id(id_): @@ -32,18 +32,17 @@ def get_item_by_id_or_ics_id(id_): return item -@bp.route('/items') +@bp.route("/items") @jwt_required def get_items(): """Return items .. :quickref: Inventory; Get items """ - return get_generic_model(models.Item, - order_by=models.Item.created_at) + return get_generic_model(models.Item, order_by=models.Item.created_at) -@bp.route('/items/<id_>') +@bp.route("/items/<id_>") @jwt_required def get_item(id_): """Retrieve item by id or ICS id @@ -56,9 +55,9 @@ def get_item(id_): return jsonify(item.to_dict()) -@bp.route('/items', methods=['POST']) +@bp.route("/items", methods=["POST"]) @jwt_required -@jwt_groups_accepted('admin', 'create') +@jwt_groups_accepted("admin", "create") def create_item(): """Register a new item @@ -79,12 +78,12 @@ def create_item(): # But there are existing items (in confluence and JIRA) that we want to # import and associate after they have been created. # In that case a temporary id is automatically assigned. 
- return create_generic_model(models.Item, mandatory_fields=('serial_number',)) + return create_generic_model(models.Item, mandatory_fields=("serial_number",)) -@bp.route('/items/<id_>', methods=['PATCH']) +@bp.route("/items/<id_>", methods=["PATCH"]) @jwt_required -@jwt_groups_accepted('admin', 'create') +@jwt_groups_accepted("admin", "create") def patch_item(id_): """Patch an existing item @@ -100,24 +99,34 @@ def patch_item(id_): """ data = request.get_json() if data is None: - raise utils.CSEntryError('Body should be a JSON object') + raise utils.CSEntryError("Body should be a JSON object") if not data: - raise utils.CSEntryError('At least one field is required', status_code=422) + raise utils.CSEntryError("At least one field is required", status_code=422) for key in data.keys(): - if key not in ('ics_id', 'manufacturer', 'model', - 'location', 'status', 'parent'): + if key not in ( + "ics_id", + "manufacturer", + "model", + "location", + "status", + "parent", + ): raise utils.CSEntryError(f"Invalid field '{key}'", status_code=422) item = get_item_by_id_or_ics_id(id_) # Only allow to set ICS id if the current id is a temporary one - if item.ics_id.startswith(current_app.config['TEMPORARY_ICS_ID']): - item.ics_id = data.get('ics_id', item.ics_id) - elif 'ics_id' in data: + if item.ics_id.startswith(current_app.config["TEMPORARY_ICS_ID"]): + item.ics_id = data.get("ics_id", item.ics_id) + elif "ics_id" in data: raise utils.CSEntryError("'ics_id' can't be changed", status_code=422) - item.manufacturer = utils.convert_to_model(data.get('manufacturer', item.manufacturer), models.Manufacturer) - item.model = utils.convert_to_model(data.get('model', item.model), models.Model) - item.location = utils.convert_to_model(data.get('location', item.location), models.Location) - item.status = utils.convert_to_model(data.get('status', item.status), models.Status) - parent_ics_id = data.get('parent') + item.manufacturer = utils.convert_to_model( + data.get("manufacturer", item.manufacturer), models.Manufacturer + ) + item.model = utils.convert_to_model(data.get("model", item.model), models.Model) + item.location = utils.convert_to_model( + data.get("location", item.location), models.Location + ) + item.status = utils.convert_to_model(data.get("status", item.status), models.Status) + parent_ics_id = data.get("parent") if parent_ics_id is not None: parent = models.Item.query.filter_by(ics_id=parent_ics_id).first() if parent is not None: @@ -133,7 +142,7 @@ def patch_item(id_): return jsonify(item.to_dict()) -@bp.route('/items/<id_>/comments') +@bp.route("/items/<id_>/comments") @jwt_required def get_item_comments(id_): """Get item comments @@ -146,9 +155,9 @@ def get_item_comments(id_): return jsonify([comment.to_dict() for comment in item.comments]) -@bp.route('/items/<id_>/comments', methods=['POST']) +@bp.route("/items/<id_>/comments", methods=["POST"]) @jwt_required -@jwt_groups_accepted('admin', 'create') +@jwt_groups_accepted("admin", "create") def create_item_comment(id_): """Create a comment on item @@ -158,12 +167,12 @@ def create_item_comment(id_): :jsonparam body: comment body """ item = get_item_by_id_or_ics_id(id_) - return create_generic_model(models.ItemComment, - mandatory_fields=('body',), - item_id=item.id) + return create_generic_model( + models.ItemComment, mandatory_fields=("body",), item_id=item.id + ) -@bp.route('/actions') +@bp.route("/actions") @jwt_required def get_actions(): """Get actions @@ -173,7 +182,7 @@ def get_actions(): return get_generic_model(models.Action) 
-@bp.route('/manufacturers') +@bp.route("/manufacturers") @jwt_required def get_manufacturers(): """Get manufacturers @@ -183,9 +192,9 @@ def get_manufacturers(): return get_generic_model(models.Manufacturer) -@bp.route('/manufacturers', methods=['POST']) +@bp.route("/manufacturers", methods=["POST"]) @jwt_required -@jwt_groups_accepted('admin', 'create') +@jwt_groups_accepted("admin", "create") def create_manufacturer(): """Create a new manufacturer @@ -197,7 +206,7 @@ def create_manufacturer(): return create_generic_model(models.Manufacturer) -@bp.route('/models') +@bp.route("/models") @jwt_required def get_models(): """Get models @@ -207,9 +216,9 @@ def get_models(): return get_generic_model(models.Model) -@bp.route('/models', methods=['POST']) +@bp.route("/models", methods=["POST"]) @jwt_required -@jwt_groups_accepted('admin', 'create') +@jwt_groups_accepted("admin", "create") def create_model(): """Create a new model @@ -221,7 +230,7 @@ def create_model(): return create_generic_model(models.Model) -@bp.route('/locations') +@bp.route("/locations") @jwt_required def get_locations(): """Get locations @@ -231,9 +240,9 @@ def get_locations(): return get_generic_model(models.Location) -@bp.route('/locations', methods=['POST']) +@bp.route("/locations", methods=["POST"]) @jwt_required -@jwt_groups_accepted('admin', 'create') +@jwt_groups_accepted("admin", "create") def create_locations(): """Create a new location @@ -245,7 +254,7 @@ def create_locations(): return create_generic_model(models.Location) -@bp.route('/statuses') +@bp.route("/statuses") @jwt_required def get_status(): """Get statuses @@ -255,9 +264,9 @@ def get_status(): return get_generic_model(models.Status) -@bp.route('/statuses', methods=['POST']) +@bp.route("/statuses", methods=["POST"]) @jwt_required -@jwt_groups_accepted('admin', 'create') +@jwt_groups_accepted("admin", "create") def create_status(): """Create a new status diff --git a/app/api/network.py b/app/api/network.py index 2079f58..3c77322 100644 --- a/app/api/network.py +++ b/app/api/network.py @@ -15,23 +15,22 @@ from .. import models from ..decorators import jwt_groups_accepted from .utils import get_generic_model, create_generic_model -bp = Blueprint('network_api', __name__) +bp = Blueprint("network_api", __name__) -@bp.route('/scopes') +@bp.route("/scopes") @jwt_required def get_scopes(): """Return network scopes .. :quickref: Network; Get network scopes """ - return get_generic_model(models.NetworkScope, - order_by=models.NetworkScope.name) + return get_generic_model(models.NetworkScope, order_by=models.NetworkScope.name) -@bp.route('/scopes', methods=['POST']) +@bp.route("/scopes", methods=["POST"]) @jwt_required -@jwt_groups_accepted('admin') +@jwt_groups_accepted("admin") def create_scope(): """Create a new network scope @@ -44,24 +43,25 @@ def create_scope(): :jsonparam domain_id: primary key of the default domain :jsonparam description: (optional) description """ - return create_generic_model(models.NetworkScope, mandatory_fields=( - 'name', 'first_vlan', 'last_vlan', 'supernet', 'domain_id')) + return create_generic_model( + models.NetworkScope, + mandatory_fields=("name", "first_vlan", "last_vlan", "supernet", "domain_id"), + ) -@bp.route('/networks') +@bp.route("/networks") @jwt_required def get_networks(): """Return networks .. 
:quickref: Network; Get networks """ - return get_generic_model(models.Network, - order_by=models.Network.address) + return get_generic_model(models.Network, order_by=models.Network.address) -@bp.route('/networks', methods=['POST']) +@bp.route("/networks", methods=["POST"]) @jwt_required -@jwt_groups_accepted('admin') +@jwt_groups_accepted("admin") def create_network(): """Create a new network @@ -78,32 +78,42 @@ def create_network(): :type admin_only: bool :jsonparam description: (optional) description """ - return create_generic_model(models.Network, mandatory_fields=( - 'vlan_name', 'vlan_id', 'address', 'first_ip', 'last_ip', 'scope')) - - -@bp.route('/interfaces') + return create_generic_model( + models.Network, + mandatory_fields=( + "vlan_name", + "vlan_id", + "address", + "first_ip", + "last_ip", + "scope", + ), + ) + + +@bp.route("/interfaces") @jwt_required def get_interfaces(): """Return interfaces .. :quickref: Network; Get interfaces """ - domain = request.args.get('domain', None) + domain = request.args.get("domain", None) if domain is not None: query = models.Interface.query - query = query.join(models.Interface.network).join( - models.Network.domain).filter( - models.Domain.name == domain) + query = ( + query.join(models.Interface.network) + .join(models.Network.domain) + .filter(models.Domain.name == domain) + ) query = query.order_by(models.Interface.ip) return get_generic_model(model=None, query=query) - return get_generic_model(models.Interface, - order_by=models.Interface.ip) + return get_generic_model(models.Interface, order_by=models.Interface.ip) -@bp.route('/interfaces', methods=['POST']) +@bp.route("/interfaces", methods=["POST"]) @jwt_required -@jwt_groups_accepted('admin', 'create') +@jwt_groups_accepted("admin", "create") def create_interface(): """Create a new interface @@ -118,23 +128,24 @@ def create_interface(): # The validate_interfaces method from the Network class is called when # setting interface.network. This is why we don't pass network_id here # but network (as vlan_name string) - return create_generic_model(models.Interface, mandatory_fields=('network', 'ip', 'name')) + return create_generic_model( + models.Interface, mandatory_fields=("network", "ip", "name") + ) -@bp.route('/hosts') +@bp.route("/hosts") @jwt_required def get_hosts(): """Return hosts .. :quickref: Network; Get hosts """ - return get_generic_model(models.Host, - order_by=models.Host.name) + return get_generic_model(models.Host, order_by=models.Host.name) -@bp.route('/hosts', methods=['POST']) +@bp.route("/hosts", methods=["POST"]) @jwt_required -@jwt_groups_accepted('admin', 'create') +@jwt_groups_accepted("admin", "create") def create_host(): """Create a new host @@ -145,23 +156,22 @@ def create_host(): :jsonparam description: (optional) description :jsonparam items: (optional) list of items ICS id linked to the host """ - return create_generic_model(models.Host, mandatory_fields=('name', 'device_type')) + return create_generic_model(models.Host, mandatory_fields=("name", "device_type")) -@bp.route('/macs') +@bp.route("/macs") @jwt_required def get_macs(): """Return mac addresses .. 
:quickref: Network; Get mac addresses """ - return get_generic_model(models.Mac, - order_by=models.Mac.address) + return get_generic_model(models.Mac, order_by=models.Mac.address) -@bp.route('/macs', methods=['POST']) +@bp.route("/macs", methods=["POST"]) @jwt_required -@jwt_groups_accepted('admin', 'create') +@jwt_groups_accepted("admin", "create") def create_macs(): """Create a new mac address @@ -170,24 +180,22 @@ def create_macs(): :jsonparam address: MAC address :jsonparam item_id: (optional) linked item primary key """ - return create_generic_model(models.Mac, - mandatory_fields=('address',)) + return create_generic_model(models.Mac, mandatory_fields=("address",)) -@bp.route('/domains') +@bp.route("/domains") @jwt_required def get_domains(): """Return domains .. :quickref: Network; Get domains """ - return get_generic_model(models.Domain, - order_by=models.Domain.name) + return get_generic_model(models.Domain, order_by=models.Domain.name) -@bp.route('/domains', methods=['POST']) +@bp.route("/domains", methods=["POST"]) @jwt_required -@jwt_groups_accepted('admin') +@jwt_groups_accepted("admin") def create_domain(): """Create a new domain @@ -195,33 +203,33 @@ def create_domain(): :jsonparam name: domain name """ - return create_generic_model(models.Domain, - mandatory_fields=('name',)) + return create_generic_model(models.Domain, mandatory_fields=("name",)) -@bp.route('/cnames') +@bp.route("/cnames") @jwt_required def get_cnames(): """Return cnames .. :quickref: Network; Get cnames """ - domain = request.args.get('domain', None) + domain = request.args.get("domain", None) if domain is not None: query = models.Cname.query - query = query.join(models.Cname.interface).join( - models.Interface.network).join( - models.Network.domain).filter( - models.Domain.name == domain) + query = ( + query.join(models.Cname.interface) + .join(models.Interface.network) + .join(models.Network.domain) + .filter(models.Domain.name == domain) + ) query = query.order_by(models.Cname.name) return get_generic_model(model=None, query=query) - return get_generic_model(models.Cname, - order_by=models.Cname.name) + return get_generic_model(models.Cname, order_by=models.Cname.name) -@bp.route('/cnames', methods=['POST']) +@bp.route("/cnames", methods=["POST"]) @jwt_required -@jwt_groups_accepted('admin') +@jwt_groups_accepted("admin") def create_cname(): """Create a new cname @@ -230,5 +238,4 @@ def create_cname(): :jsonparam name: full cname :jsonparam interface_id: primary key of the associated interface """ - return create_generic_model(models.Cname, - mandatory_fields=('name', 'interface_id')) + return create_generic_model(models.Cname, mandatory_fields=("name", "interface_id")) diff --git a/app/api/user.py b/app/api/user.py index 7593787..3d8c56f 100644 --- a/app/api/user.py +++ b/app/api/user.py @@ -17,21 +17,20 @@ from ..decorators import jwt_groups_accepted from .. import utils, tokens, models from .utils import get_generic_model, create_generic_model -bp = Blueprint('user_api', __name__) +bp = Blueprint("user_api", __name__) -@bp.route('/users') +@bp.route("/users") @jwt_required def get_users(): """Return users information .. 
:quickref: User; Get users information """ - return get_generic_model(models.User, - order_by=models.User.username) + return get_generic_model(models.User, order_by=models.User.username) -@bp.route('/profile') +@bp.route("/profile") @jwt_required def get_user_profile(): """Return the current user profile @@ -42,9 +41,9 @@ def get_user_profile(): return jsonify(user.to_dict()), 200 -@bp.route('/users', methods=['POST']) +@bp.route("/users", methods=["POST"]) @jwt_required -@jwt_groups_accepted('admin') +@jwt_groups_accepted("admin") def create_user(): """Create a new user @@ -54,11 +53,12 @@ def create_user(): :jsonparam display_name: User's display name :jsonparam email: User's email """ - return create_generic_model(models.User, mandatory_fields=( - 'username', 'display_name', 'email')) + return create_generic_model( + models.User, mandatory_fields=("username", "display_name", "email") + ) -@bp.route('/login', methods=['POST']) +@bp.route("/login", methods=["POST"]) def login(): """Return a JSON Web Token @@ -69,20 +69,20 @@ def login(): """ data = request.get_json() if data is None: - raise utils.CSEntryError('Body should be a JSON object') + raise utils.CSEntryError("Body should be a JSON object") try: - username = data['username'] - password = data['password'] + username = data["username"] + password = data["password"] except KeyError: - raise utils.CSEntryError('Missing mandatory field (username or password)', status_code=422) + raise utils.CSEntryError( + "Missing mandatory field (username or password)", status_code=422 + ) response = ldap_manager.authenticate(username, password) if response.status == AuthenticationResponseStatus.success: - current_app.logger.debug(f'{username} successfully logged in') + current_app.logger.debug(f"{username} successfully logged in") user = ldap_manager._save_user( - response.user_dn, - response.user_id, - response.user_info, - response.user_groups) - payload = {'access_token': tokens.generate_access_token(identity=user.id)} + response.user_dn, response.user_id, response.user_info, response.user_groups + ) + payload = {"access_token": tokens.generate_access_token(identity=user.id)} return jsonify(payload), 200 - raise utils.CSEntryError('Invalid credentials', status_code=401) + raise utils.CSEntryError("Invalid credentials", status_code=401) diff --git a/app/api/utils.py b/app/api/utils.py index f4369e4..6197c01 100644 --- a/app/api/utils.py +++ b/app/api/utils.py @@ -32,30 +32,30 @@ def build_pagination_header(pagination, base_url, **kwargs): :param kwargs: extra query string parameters (without page and per_page) :returns: dict with X-Total-Count and Link keys """ - header = {'X-Total-Count': pagination.total} + header = {"X-Total-Count": pagination.total} links = [] if pagination.page > 1: - params = urllib.parse.urlencode({'per_page': pagination.per_page, - 'page': 1, - **kwargs}) + params = urllib.parse.urlencode( + {"per_page": pagination.per_page, "page": 1, **kwargs} + ) links.append(f'<{base_url}?{params}>; rel="first"') if pagination.has_prev: - params = urllib.parse.urlencode({'per_page': pagination.per_page, - 'page': pagination.prev_num, - **kwargs}) + params = urllib.parse.urlencode( + {"per_page": pagination.per_page, "page": pagination.prev_num, **kwargs} + ) links.append(f'<{base_url}?{params}>; rel="prev"') if pagination.has_next: - params = urllib.parse.urlencode({'per_page': pagination.per_page, - 'page': pagination.next_num, - **kwargs}) + params = urllib.parse.urlencode( + {"per_page": pagination.per_page, "page": 
pagination.next_num, **kwargs} + ) links.append(f'<{base_url}?{params}>; rel="next"') if pagination.pages > pagination.page: - params = urllib.parse.urlencode({'per_page': pagination.per_page, - 'page': pagination.pages, - **kwargs}) + params = urllib.parse.urlencode( + {"per_page": pagination.per_page, "page": pagination.pages, **kwargs} + ) links.append(f'<{base_url}?{params}>; rel="last"') if links: - header['Link'] = ', '.join(links) + header["Link"] = ", ".join(links) return header @@ -68,12 +68,12 @@ def get_generic_model(model, order_by=None, query=None): :returns: data from model as json """ kwargs = request.args.to_dict() - page = int(kwargs.pop('page', 1)) - per_page = int(kwargs.pop('per_page', 20)) + page = int(kwargs.pop("page", 1)) + per_page = int(kwargs.pop("per_page", 20)) if query is None: query = utils.get_query(model.query, **kwargs) if order_by is None: - order_by = getattr(model, 'name') + order_by = getattr(model, "name") query = query.order_by(order_by) pagination = query.paginate(page, per_page) data = [item.to_dict() for item in pagination.items] @@ -81,19 +81,21 @@ def get_generic_model(model, order_by=None, query=None): return jsonify(data), 200, header -def create_generic_model(model, mandatory_fields=('name',), **kwargs): +def create_generic_model(model, mandatory_fields=("name",), **kwargs): data = request.get_json() if data is None: - raise utils.CSEntryError('Body should be a JSON object') - current_app.logger.debug(f'Received: {data}') + raise utils.CSEntryError("Body should be a JSON object") + current_app.logger.debug(f"Received: {data}") data.update(kwargs) for mandatory_field in mandatory_fields: if mandatory_field not in data: - raise utils.CSEntryError(f"Missing mandatory field '{mandatory_field}'", status_code=422) + raise utils.CSEntryError( + f"Missing mandatory field '{mandatory_field}'", status_code=422 + ) try: instance = model(**data) except TypeError as e: - message = str(e).replace('__init__() got an ', '') + message = str(e).replace("__init__() got an ", "") raise utils.CSEntryError(message, status_code=422) except ValueError as e: raise utils.CSEntryError(str(e), status_code=422) diff --git a/app/commands.py b/app/commands.py index d590302..6b8b5e6 100644 --- a/app/commands.py +++ b/app/commands.py @@ -23,41 +23,46 @@ from . 
import utils, tokens def sync_user(connection, user): """Synchronize the user from the database with information from the LDAP server""" - search_attr = current_app.config.get('LDAP_USER_LOGIN_ATTR') - object_filter = current_app.config.get('LDAP_USER_OBJECT_FILTER') - search_filter = f'(&{object_filter}({search_attr}={user.username}))' + search_attr = current_app.config.get("LDAP_USER_LOGIN_ATTR") + object_filter = current_app.config.get("LDAP_USER_OBJECT_FILTER") + search_filter = f"(&{object_filter}({search_attr}={user.username}))" connection.search( search_base=ldap_manager.full_user_search_dn, search_filter=search_filter, - search_scope=getattr( - ldap3, current_app.config.get('LDAP_USER_SEARCH_SCOPE')), - attributes=current_app.config.get('LDAP_GET_USER_ATTRIBUTES') + search_scope=getattr(ldap3, current_app.config.get("LDAP_USER_SEARCH_SCOPE")), + attributes=current_app.config.get("LDAP_GET_USER_ATTRIBUTES"), ) if len(connection.response) == 1: ldap_user = connection.response[0] - attributes = ldap_user['attributes'] - user.display_name = utils.attribute_to_string(attributes['cn']) - user.email = utils.attribute_to_string(attributes['mail']) - groups = ldap_manager.get_user_groups(dn=ldap_user['dn'], _connection=connection) - user.groups = sorted([utils.attribute_to_string(group['cn']) for group in groups]) - current_app.logger.info(f'{user} updated') + attributes = ldap_user["attributes"] + user.display_name = utils.attribute_to_string(attributes["cn"]) + user.email = utils.attribute_to_string(attributes["mail"]) + groups = ldap_manager.get_user_groups( + dn=ldap_user["dn"], _connection=connection + ) + user.groups = sorted( + [utils.attribute_to_string(group["cn"]) for group in groups] + ) + current_app.logger.info(f"{user} updated") else: # Clear user's groups user.groups = [] # Revoke all user's tokens for token in user.tokens: db.session.delete(token) - current_app.logger.info(f'{user} disabled') + current_app.logger.info(f"{user} disabled") return user def sync_users(): """Synchronize all users from the database with information the LDAP server""" - current_app.logger.info('Synchronize database with information from the LDAP server') + current_app.logger.info( + "Synchronize database with information from the LDAP server" + ) try: connection = ldap_manager.connection except ldap3.core.exceptions.LDAPException as e: - current_app.logger.warning(f'Failed to connect to the LDAP server: {e}') + current_app.logger.warning(f"Failed to connect to the LDAP server: {e}") return for user in User.query.all(): sync_user(connection, user) @@ -74,7 +79,7 @@ def register_cli(app): db.session.commit() except sa.exc.IntegrityError as e: db.session.rollback() - app.logger.debug(f'{instance} already exists') + app.logger.debug(f"{instance} already exists") @app.cli.command() def syncusers(): @@ -95,8 +100,8 @@ def register_cli(app): @app.cli.command() def runworker(): """Run RQ worker""" - redis_url = current_app.config['REDIS_URL'] + redis_url = current_app.config["REDIS_URL"] redis_connection = redis.from_url(redis_url) with rq.Connection(redis_connection): - worker = TaskWorker(current_app.config['QUEUES']) + worker = TaskWorker(current_app.config["QUEUES"]) worker.work() diff --git a/app/decorators.py b/app/decorators.py index 90da276..634cffe 100644 --- a/app/decorators.py +++ b/app/decorators.py @@ -33,16 +33,21 @@ def jwt_groups_accepted(*groups): :param groups: accepted groups """ + def wrapper(fn): @wraps(fn) def decorated_view(*args, **kwargs): user = get_current_user() if user is None: 
- raise CSEntryError('Invalid indentity', status_code=403) + raise CSEntryError("Invalid indentity", status_code=403) if not user.is_member_of_one_group(groups): - raise CSEntryError("User doesn't have the required group", status_code=403) + raise CSEntryError( + "User doesn't have the required group", status_code=403 + ) return fn(*args, **kwargs) + return decorated_view + return wrapper @@ -64,6 +69,7 @@ def login_groups_accepted(*groups): :param groups: accepted groups """ + def wrapper(fn): @wraps(fn) def decorated_view(*args, **kwargs): @@ -72,5 +78,7 @@ def login_groups_accepted(*groups): if not current_user.is_member_of_one_group(groups): abort(403) return fn(*args, **kwargs) + return decorated_view + return wrapper diff --git a/app/defaults.py b/app/defaults.py index 4eaf8a1..ff6f2ab 100644 --- a/app/defaults.py +++ b/app/defaults.py @@ -13,20 +13,18 @@ from . import models defaults = [ - models.Action(name='Assign ICS id'), - models.Action(name='Clear all'), - models.Action(name='Clear attributes'), - models.Action(name='Fetch'), - models.Action(name='Register'), - models.Action(name='Set as parent'), - models.Action(name='Update'), - - models.DeviceType(name='PhysicalMachine'), - models.DeviceType(name='VirtualMachine'), - models.DeviceType(name='Network'), - models.DeviceType(name='MicroTCA'), - models.DeviceType(name='VME'), - models.DeviceType(name='PLC'), - - models.Tag(name='IOC', admin_only=False), + models.Action(name="Assign ICS id"), + models.Action(name="Clear all"), + models.Action(name="Clear attributes"), + models.Action(name="Fetch"), + models.Action(name="Register"), + models.Action(name="Set as parent"), + models.Action(name="Update"), + models.DeviceType(name="PhysicalMachine"), + models.DeviceType(name="VirtualMachine"), + models.DeviceType(name="Network"), + models.DeviceType(name="MicroTCA"), + models.DeviceType(name="VME"), + models.DeviceType(name="PLC"), + models.Tag(name="IOC", admin_only=False), ] diff --git a/app/extensions.py b/app/extensions.py index f0dc0cb..1f5f023 100644 --- a/app/extensions.py +++ b/app/extensions.py @@ -25,23 +25,23 @@ from flask_caching import Cache convention = { - "ix": 'ix_%(column_0_label)s', + "ix": "ix_%(column_0_label)s", "uq": "uq_%(table_name)s_%(column_0_name)s", "ck": "ck_%(table_name)s_%(constraint_name)s", "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", - "pk": "pk_%(table_name)s" + "pk": "pk_%(table_name)s", } metadata = MetaData(naming_convention=convention) -db = SQLAlchemy(metadata=metadata, session_options={'autoflush': False}) +db = SQLAlchemy(metadata=metadata, session_options={"autoflush": False}) migrate = Migrate(db=db) login_manager = LoginManager() ldap_manager = LDAP3LoginManager() bootstrap = Bootstrap() -admin = Admin(template_mode='bootstrap3') +admin = Admin(template_mode="bootstrap3") mail = Mail() jwt = JWTManager() toolbar = DebugToolbarExtension() -session_redis_store = FlaskRedis(config_prefix='SESSION_REDIS') +session_redis_store = FlaskRedis(config_prefix="SESSION_REDIS") fsession = Session() cache = Cache() diff --git a/app/factory.py b/app/factory.py index d1f7ae5..71dc66c 100644 --- a/app/factory.py +++ b/app/factory.py @@ -14,10 +14,28 @@ import rq_dashboard from flask import Flask from whitenoise import WhiteNoise from . 
import settings, models -from .extensions import (db, migrate, login_manager, ldap_manager, bootstrap, - admin, mail, jwt, toolbar, session_redis_store, fsession, cache) -from .admin.views import (AdminModelView, ItemAdmin, UserAdmin, TokenAdmin, - NetworkAdmin, TaskAdmin) +from .extensions import ( + db, + migrate, + login_manager, + ldap_manager, + bootstrap, + admin, + mail, + jwt, + toolbar, + session_redis_store, + fsession, + cache, +) +from .admin.views import ( + AdminModelView, + ItemAdmin, + UserAdmin, + TokenAdmin, + NetworkAdmin, + TaskAdmin, +) from .main.views import bp as main from .inventory.views import bp as inventory from .network.views import bp as network @@ -35,21 +53,27 @@ def create_app(config=None): app.config.from_object(rq_dashboard.default_settings) app.config.from_object(settings) - app.config.from_envvar('LOCAL_SETTINGS', silent=True) + app.config.from_envvar("LOCAL_SETTINGS", silent=True) app.config.update(config or {}) - app.jinja_env.filters['datetimeformat'] = utils.format_datetime + app.jinja_env.filters["datetimeformat"] = utils.format_datetime if not app.debug: import logging + # Send ERROR via mail from logging.handlers import SMTPHandler - mail_handler = SMTPHandler(app.config['MAIL_SERVER'], - fromaddr=app.config['EMAIL_SENDER'], - toaddrs=app.config['ADMIN_EMAILS'], - subject='CSEntry: ERROR raised', - credentials=app.config['MAIL_CREDENTIALS']) - mail_handler.setFormatter(logging.Formatter(""" + + mail_handler = SMTPHandler( + app.config["MAIL_SERVER"], + fromaddr=app.config["EMAIL_SENDER"], + toaddrs=app.config["ADMIN_EMAILS"], + subject="CSEntry: ERROR raised", + credentials=app.config["MAIL_CREDENTIALS"], + ) + mail_handler.setFormatter( + logging.Formatter( + """ Message type: %(levelname)s Location: %(pathname)s:%(lineno)d Module: %(module)s @@ -59,46 +83,59 @@ def create_app(config=None): Message: %(message)s - """)) + """ + ) + ) mail_handler.setLevel(logging.ERROR) app.logger.addHandler(mail_handler) # Log to stderr handler = logging.StreamHandler() - handler.setFormatter(logging.Formatter( - '%(asctime)s %(levelname)s: %(message)s ' - '[in %(pathname)s:%(lineno)d]' - )) + handler.setFormatter( + logging.Formatter( + "%(asctime)s %(levelname)s: %(message)s " "[in %(pathname)s:%(lineno)d]" + ) + ) # Set app logger level to DEBUG # otherwise only WARNING and above are propagated app.logger.setLevel(logging.DEBUG) handler.setLevel(logging.DEBUG) app.logger.addHandler(handler) - app.logger.info('CSEntry created!') + app.logger.info("CSEntry created!") # Remove variables that contain a password - settings_to_display = [f'{key}: {value}' for key, value in app.config.items() - if key not in ('SECRET_KEY', 'LDAP_BIND_USER_PASSWORD', - 'MAIL_CREDENTIALS', 'SQLALCHEMY_DATABASE_URI')] + settings_to_display = [ + f"{key}: {value}" + for key, value in app.config.items() + if key + not in ( + "SECRET_KEY", + "LDAP_BIND_USER_PASSWORD", + "MAIL_CREDENTIALS", + "SQLALCHEMY_DATABASE_URI", + ) + ] # The repr() of make_url hides the password - settings_to_display.append(f'SQLALCHEMY_DATABASE_URI: {sa.engine.url.make_url(app.config["SQLALCHEMY_DATABASE_URI"])!r}') - settings_string = '\n'.join(settings_to_display) - app.logger.info(f'Settings:\n{settings_string}') + settings_to_display.append( + f'SQLALCHEMY_DATABASE_URI: {sa.engine.url.make_url(app.config["SQLALCHEMY_DATABASE_URI"])!r}' + ) + settings_string = "\n".join(settings_to_display) + app.logger.info(f"Settings:\n{settings_string}") bootstrap.init_app(app) db.init_app(app) migrate.init_app(app) 
login_manager.init_app(app) - login_manager.login_view = 'user.login' + login_manager.login_view = "user.login" ldap_manager.init_app(app) mail.init_app(app) jwt.init_app(app) toolbar.init_app(app) session_redis_store.init_app(app) - app.config['SESSION_REDIS'] = session_redis_store + app.config["SESSION_REDIS"] = session_redis_store fsession.init_app(app) cache.init_app(app) admin.init_app(app) - admin.add_view(UserAdmin(models.User, db.session, endpoint='users')) + admin.add_view(UserAdmin(models.User, db.session, endpoint="users")) admin.add_view(TokenAdmin(models.Token, db.session)) admin.add_view(AdminModelView(models.Action, db.session)) admin.add_view(AdminModelView(models.Manufacturer, db.session)) @@ -109,29 +146,29 @@ def create_app(config=None): admin.add_view(AdminModelView(models.ItemComment, db.session)) admin.add_view(AdminModelView(models.Domain, db.session)) admin.add_view(AdminModelView(models.NetworkScope, db.session)) - admin.add_view(NetworkAdmin(models.Network, db.session, endpoint='networks')) + admin.add_view(NetworkAdmin(models.Network, db.session, endpoint="networks")) admin.add_view(AdminModelView(models.DeviceType, db.session)) admin.add_view(AdminModelView(models.Host, db.session)) admin.add_view(AdminModelView(models.Interface, db.session)) admin.add_view(AdminModelView(models.Mac, db.session)) admin.add_view(AdminModelView(models.Cname, db.session)) admin.add_view(AdminModelView(models.Tag, db.session)) - admin.add_view(TaskAdmin(models.Task, db.session, endpoint='tasks')) + admin.add_view(TaskAdmin(models.Task, db.session, endpoint="tasks")) app.register_blueprint(main) - app.register_blueprint(inventory, url_prefix='/inventory') - app.register_blueprint(network, url_prefix='/network') - app.register_blueprint(task, url_prefix='/task') - app.register_blueprint(user, url_prefix='/user') - app.register_blueprint(user_api, url_prefix='/api/v1/user') - app.register_blueprint(inventory_api, url_prefix='/api/v1/inventory') - app.register_blueprint(network_api, url_prefix='/api/v1/network') - app.register_blueprint(rq_dashboard.blueprint, url_prefix='/rq') - - app.wsgi_app = WhiteNoise(app.wsgi_app, root='static/') + app.register_blueprint(inventory, url_prefix="/inventory") + app.register_blueprint(network, url_prefix="/network") + app.register_blueprint(task, url_prefix="/task") + app.register_blueprint(user, url_prefix="/user") + app.register_blueprint(user_api, url_prefix="/api/v1/user") + app.register_blueprint(inventory_api, url_prefix="/api/v1/inventory") + app.register_blueprint(network_api, url_prefix="/api/v1/network") + app.register_blueprint(rq_dashboard.blueprint, url_prefix="/rq") + + app.wsgi_app = WhiteNoise(app.wsgi_app, root="static/") app.wsgi_app.add_files( - root='/opt/conda/envs/csentry/lib/python3.6/site-packages/flask_bootstrap/static/', - prefix='bootstrap/' + root="/opt/conda/envs/csentry/lib/python3.6/site-packages/flask_bootstrap/static/", + prefix="bootstrap/", ) register_cli(app) diff --git a/app/helpers.py b/app/helpers.py index d68b378..9f79ac4 100644 --- a/app/helpers.py +++ b/app/helpers.py @@ -16,7 +16,6 @@ from . import models class CSEntryForm(FlaskForm): - def __init__(self, formdata=None, obj=None, **kwargs): # Store obj for Unique validator to check if the unique object # is identical to the one being edited diff --git a/app/inventory/forms.py b/app/inventory/forms.py index 1143efc..114da89 100644 --- a/app/inventory/forms.py +++ b/app/inventory/forms.py @@ -11,51 +11,69 @@ This module defines the inventory blueprint forms. 
""" from wtforms import SelectField, StringField, IntegerField, TextAreaField, validators from ..helpers import CSEntryForm -from ..validators import (Unique, RegexpList, ICS_ID_RE, MAC_ADDRESS_RE, - NoValidateSelectField) +from ..validators import ( + Unique, + RegexpList, + ICS_ID_RE, + MAC_ADDRESS_RE, + NoValidateSelectField, +) from .. import utils, models class AttributeForm(CSEntryForm): - name = StringField('Name', validators=[validators.DataRequired()]) - description = StringField('Description') + name = StringField("Name", validators=[validators.DataRequired()]) + description = StringField("Description") class ItemForm(CSEntryForm): - ics_id = StringField('ICS id', - validators=[validators.InputRequired(), - validators.Regexp(ICS_ID_RE), - Unique(models.Item, 'ics_id')]) - serial_number = StringField('Serial number', - validators=[validators.InputRequired()]) - quantity = IntegerField('Quantity', default=1, - validators=[validators.NumberRange(min=1)]) - manufacturer_id = SelectField('Manufacturer', coerce=utils.coerce_to_str_or_none) - model_id = SelectField('Model', coerce=utils.coerce_to_str_or_none) - location_id = SelectField('Location', coerce=utils.coerce_to_str_or_none) - status_id = SelectField('Status', coerce=utils.coerce_to_str_or_none) - parent_id = SelectField('Parent', coerce=utils.coerce_to_str_or_none) - host_id = SelectField('Host', coerce=utils.coerce_to_str_or_none) + ics_id = StringField( + "ICS id", + validators=[ + validators.InputRequired(), + validators.Regexp(ICS_ID_RE), + Unique(models.Item, "ics_id"), + ], + ) + serial_number = StringField( + "Serial number", validators=[validators.InputRequired()] + ) + quantity = IntegerField( + "Quantity", default=1, validators=[validators.NumberRange(min=1)] + ) + manufacturer_id = SelectField("Manufacturer", coerce=utils.coerce_to_str_or_none) + model_id = SelectField("Model", coerce=utils.coerce_to_str_or_none) + location_id = SelectField("Location", coerce=utils.coerce_to_str_or_none) + status_id = SelectField("Status", coerce=utils.coerce_to_str_or_none) + parent_id = SelectField("Parent", coerce=utils.coerce_to_str_or_none) + host_id = SelectField("Host", coerce=utils.coerce_to_str_or_none) stack_member = NoValidateSelectField( - 'Stack member', - coerce=utils.coerce_to_str_or_none, - choices=[]) + "Stack member", coerce=utils.coerce_to_str_or_none, choices=[] + ) mac_addresses = StringField( - 'MAC addresses', - description='space separated list of MAC addresses', - validators=[validators.Optional(), - RegexpList(MAC_ADDRESS_RE, message='Invalid MAC address')]) + "MAC addresses", + description="space separated list of MAC addresses", + validators=[ + validators.Optional(), + RegexpList(MAC_ADDRESS_RE, message="Invalid MAC address"), + ], + ) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.manufacturer_id.choices = utils.get_model_choices(models.Manufacturer, allow_none=True) + self.manufacturer_id.choices = utils.get_model_choices( + models.Manufacturer, allow_none=True + ) self.model_id.choices = utils.get_model_choices(models.Model, allow_none=True) - self.location_id.choices = utils.get_model_choices(models.Location, allow_none=True) + self.location_id.choices = utils.get_model_choices( + models.Location, allow_none=True + ) self.status_id.choices = utils.get_model_choices(models.Status, allow_none=True) - self.parent_id.choices = utils.get_model_choices(models.Item, allow_none=True, attr='ics_id') + self.parent_id.choices = utils.get_model_choices( + models.Item, 
allow_none=True, attr="ics_id" + ) self.host_id.choices = utils.get_model_choices(models.Host, allow_none=True) class CommentForm(CSEntryForm): - body = TextAreaField('Enter your comment:', - validators=[validators.DataRequired()]) + body = TextAreaField("Enter your comment:", validators=[validators.DataRequired()]) diff --git a/app/inventory/views.py b/app/inventory/views.py index 898ad9d..60efbb1 100644 --- a/app/inventory/views.py +++ b/app/inventory/views.py @@ -10,27 +10,36 @@ This module implements the inventory blueprint. """ import sqlalchemy as sa -from flask import (Blueprint, render_template, jsonify, session, - request, redirect, url_for, flash, current_app) +from flask import ( + Blueprint, + render_template, + jsonify, + session, + request, + redirect, + url_for, + flash, + current_app, +) from flask_login import login_required, current_user from .forms import AttributeForm, ItemForm, CommentForm from ..extensions import db from ..decorators import login_groups_accepted from .. import utils, models -bp = Blueprint('inventory', __name__) +bp = Blueprint("inventory", __name__) -@bp.route('/_retrieve_items') +@bp.route("/_retrieve_items") @login_required def retrieve_items(): # Get the parameters from the query string sent by datatables - draw = int(request.args.get('draw', 0)) - start = int(request.args.get('start', 0)) - per_page = int(request.args.get('length', 20)) - search = request.args.get('search[value]', '') - order_column = int(request.args.get('order[0][column]', 3)) - if request.args.get('order[0][dir]') == 'desc': + draw = int(request.args.get("draw", 0)) + start = int(request.args.get("start", 0)) + per_page = int(request.args.get("length", 20)) + search = request.args.get("search[value]", "") + order_column = int(request.args.get("order[0][column]", 3)) + if request.args.get("order[0][dir]") == "desc": order_dir = sa.desc else: order_dir = sa.asc @@ -39,158 +48,168 @@ def retrieve_items(): # Construct the query query = models.Item.query if search: - if '%' not in search: - search = f'%{search}%' + if "%" not in search: + search = f"%{search}%" q1 = query.filter( sa.or_( - models.Item.ics_id.like(search), - models.Item.serial_number.like(search), + models.Item.ics_id.like(search), models.Item.serial_number.like(search) ) ) q2 = query.join(models.Item.manufacturer).filter( - models.Manufacturer.name.like(search)) + models.Manufacturer.name.like(search) + ) q3 = query.join(models.Item.model).filter( sa.or_( - models.Model.name.like(search), - models.Model.description.like(search) + models.Model.name.like(search), models.Model.description.like(search) ) ) - q4 = query.join(models.Item.location).filter( - models.Location.name.like(search)) - q5 = query.join(models.Item.status).filter( - models.Status.name.like(search)) + q4 = query.join(models.Item.location).filter(models.Location.name.like(search)) + q5 = query.join(models.Item.status).filter(models.Status.name.like(search)) q6 = query.join(models.Item.comments).filter( - models.ItemComment.body.like(search)) + models.ItemComment.body.like(search) + ) query = q1.union(q2).union(q3).union(q4).union(q5).union(q6) nb_items_filtered = query.order_by(None).count() else: nb_items_filtered = nb_items_total # Construct the order_by query - columns = ('id', - 'ics_id', - 'created_at', - 'updated_at', - 'serial_number', - 'quantity', - 'manufacturer', - 'model', - 'location', - 'status', - 'parent', - ) + columns = ( + "id", + "ics_id", + "created_at", + "updated_at", + "serial_number", + "quantity", + "manufacturer", + 
"model", + "location", + "status", + "parent", + ) query = query.order_by(order_dir(getattr(models.Item, columns[order_column]))) # Limit and offset the query if per_page != -1: query = query.limit(per_page) query = query.offset(start) data = [ - [item.id, - item.ics_id, - utils.format_field(item.created_at), - utils.format_field(item.updated_at), - item.serial_number, - item.quantity, - utils.format_field(item.manufacturer), - utils.format_field(item.model), - utils.format_field(item.location), - utils.format_field(item.status), - utils.format_field(item.parent), - ] for item in query.all() + [ + item.id, + item.ics_id, + utils.format_field(item.created_at), + utils.format_field(item.updated_at), + item.serial_number, + item.quantity, + utils.format_field(item.manufacturer), + utils.format_field(item.model), + utils.format_field(item.location), + utils.format_field(item.status), + utils.format_field(item.parent), + ] + for item in query.all() ] response = { - 'draw': draw, - 'recordsTotal': nb_items_total, - 'recordsFiltered': nb_items_filtered, - 'data': data + "draw": draw, + "recordsTotal": nb_items_total, + "recordsFiltered": nb_items_filtered, + "data": data, } return jsonify(response) -@bp.route('/items') +@bp.route("/items") @login_required def list_items(): - return render_template('inventory/items.html') + return render_template("inventory/items.html") -@bp.route('/items/create', methods=('GET', 'POST')) -@login_groups_accepted('admin', 'create') +@bp.route("/items/create", methods=("GET", "POST")) +@login_groups_accepted("admin", "create") def create_item(): # The following keys are stored in the session to easily create # several identical items - keys = ('manufacturer_id', 'model_id', 'location_id', 'status_id', 'parent_id') - settings = {key: session.get(key, '') for key in keys} + keys = ("manufacturer_id", "model_id", "location_id", "status_id", "parent_id") + settings = {key: session.get(key, "") for key in keys} form = ItemForm(request.form, **settings) if form.validate_on_submit(): for key in keys: session[key] = getattr(form, key).data - item = models.Item(ics_id=form.ics_id.data, - serial_number=form.serial_number.data, - quantity=form.quantity.data, - manufacturer_id=form.manufacturer_id.data, - model_id=form.model_id.data, - location_id=form.location_id.data, - status_id=form.status_id.data, - parent_id=form.parent_id.data, - host_id=form.host_id.data, - stack_member=form.stack_member.data) - item.macs = [models.Mac(address=address) for address in form.mac_addresses.data.split()] - current_app.logger.debug(f'Trying to create: {item!r}') + item = models.Item( + ics_id=form.ics_id.data, + serial_number=form.serial_number.data, + quantity=form.quantity.data, + manufacturer_id=form.manufacturer_id.data, + model_id=form.model_id.data, + location_id=form.location_id.data, + status_id=form.status_id.data, + parent_id=form.parent_id.data, + host_id=form.host_id.data, + stack_member=form.stack_member.data, + ) + item.macs = [ + models.Mac(address=address) for address in form.mac_addresses.data.split() + ] + current_app.logger.debug(f"Trying to create: {item!r}") db.session.add(item) try: db.session.commit() except sa.exc.IntegrityError as e: db.session.rollback() - current_app.logger.warning(f'{e}') - flash(f'{e}', 'error') + current_app.logger.warning(f"{e}") + flash(f"{e}", "error") else: - flash(f'Item {item} created!', 'success') - return redirect(url_for('inventory.create_item')) - return render_template('inventory/create_item.html', form=form) + flash(f"Item {item} 
created!", "success") + return redirect(url_for("inventory.create_item")) + return render_template("inventory/create_item.html", form=form) -@bp.route('/items/view/<ics_id>') +@bp.route("/items/view/<ics_id>") @login_required def view_item(ics_id): item = models.Item.query.filter_by(ics_id=ics_id).first_or_404() - return render_template('inventory/view_item.html', item=item) + return render_template("inventory/view_item.html", item=item) -@bp.route('/items/comment/<ics_id>', methods=('GET', 'POST')) -@login_groups_accepted('admin', 'create') +@bp.route("/items/comment/<ics_id>", methods=("GET", "POST")) +@login_groups_accepted("admin", "create") def comment_item(ics_id): item = models.Item.query.filter_by(ics_id=ics_id).first_or_404() form = CommentForm() if form.validate_on_submit(): - comment = models.ItemComment(body=form.body.data, - item_id=item.id) + comment = models.ItemComment(body=form.body.data, item_id=item.id) db.session.add(comment) db.session.commit() - return redirect(url_for('inventory.view_item', ics_id=ics_id)) - return render_template('inventory/comment_item.html', item=item, form=form) + return redirect(url_for("inventory.view_item", ics_id=ics_id)) + return render_template("inventory/comment_item.html", item=item, form=form) -@bp.route('/items/edit/<ics_id>', methods=('GET', 'POST')) -@login_groups_accepted('admin', 'create') +@bp.route("/items/edit/<ics_id>", methods=("GET", "POST")) +@login_groups_accepted("admin", "create") def edit_item(ics_id): item = models.Item.query.filter_by(ics_id=ics_id).first_or_404() - mac_addresses = ' '.join([str(mac) for mac in item.macs]) + mac_addresses = " ".join([str(mac) for mac in item.macs]) form = ItemForm(request.form, obj=item, mac_addresses=mac_addresses) if form.validate_on_submit(): # Only allow to update temporary ics_id - if item.ics_id.startswith(current_app.config['TEMPORARY_ICS_ID']): + if item.ics_id.startswith(current_app.config["TEMPORARY_ICS_ID"]): item.ics_id = form.ics_id.data item.serial_number = form.serial_number.data item.quantity = form.quantity.data # When a field is disabled, it's value is not passed to the request # We don't use request.form.get('stack_member', None) to let the coerce # function of the field properly convert the value - if 'stack_member' in request.form: + if "stack_member" in request.form: item.stack_member = form.stack_member.data else: # Field is disabled, force it to None item.stack_member = None - for key in ('manufacturer_id', 'model_id', 'location_id', 'status_id', - 'parent_id', 'host_id'): + for key in ( + "manufacturer_id", + "model_id", + "location_id", + "status_id", + "parent_id", + "host_id", + ): setattr(item, key, getattr(form, key).data) new_addresses = form.mac_addresses.data.split() # Delete the MAC addresses that have been removed @@ -203,60 +222,66 @@ def edit_item(ics_id): if address not in mac_addresses: mac = models.Mac(address=address) item.macs.append(mac) - current_app.logger.debug(f'Trying to update: {item!r}') + current_app.logger.debug(f"Trying to update: {item!r}") try: db.session.commit() except sa.exc.IntegrityError as e: db.session.rollback() - current_app.logger.warning(f'{e}') - flash(f'{e}', 'error') + current_app.logger.warning(f"{e}") + flash(f"{e}", "error") else: - flash(f'Item {item} updated!', 'success') - return redirect(url_for('inventory.view_item', ics_id=item.ics_id)) - return render_template('inventory/edit_item.html', form=form) + flash(f"Item {item} updated!", "success") + return redirect(url_for("inventory.view_item", 
ics_id=item.ics_id)) + return render_template("inventory/edit_item.html", form=form) -@bp.route('/attributes/favorites') +@bp.route("/attributes/favorites") @login_required def attributes_favorites(): - return render_template('inventory/attributes_favorites.html') + return render_template("inventory/attributes_favorites.html") -@bp.route('/_retrieve_attributes_favorites') +@bp.route("/_retrieve_attributes_favorites") @login_required def retrieve_attributes_favorites(): if current_user not in db.session: # If the current user is cached, it won't be in the sqlalchemy session # Add it to access the user favorite attributes relationship db.session.add(current_user) - data = [(favorite.base64_image(), - type(favorite).__name__, - favorite.name, - favorite.description) for favorite in current_user.favorite_attributes()] + data = [ + ( + favorite.base64_image(), + type(favorite).__name__, + favorite.name, + favorite.description, + ) + for favorite in current_user.favorite_attributes() + ] return jsonify(data=data) -@bp.route('/attributes/<kind>', methods=('GET', 'POST')) -@login_groups_accepted('admin', 'create') +@bp.route("/attributes/<kind>", methods=("GET", "POST")) +@login_groups_accepted("admin", "create") def attributes(kind): form = AttributeForm() if form.validate_on_submit(): model = getattr(models, kind) - new_model = model(name=form.name.data, - description=form.description.data or None) + new_model = model( + name=form.name.data, description=form.description.data or None + ) db.session.add(new_model) try: db.session.commit() except sa.exc.IntegrityError as e: db.session.rollback() - flash(f'{form.name.data} already exists! {kind} not created.', 'error') + flash(f"{form.name.data} already exists! {kind} not created.", "error") else: - flash(f'{kind} {new_model} created!', 'success') - return redirect(url_for('inventory.attributes', kind=kind)) - return render_template('inventory/attributes.html', kind=kind, form=form) + flash(f"{kind} {new_model} created!", "success") + return redirect(url_for("inventory.attributes", kind=kind)) + return render_template("inventory/attributes.html", kind=kind, form=form) -@bp.route('/_retrieve_attributes/<kind>') +@bp.route("/_retrieve_attributes/<kind>") @login_required def retrieve_attributes(kind): try: @@ -264,14 +289,19 @@ def retrieve_attributes(kind): except AttributeError: raise utils.CSEntryError(f"Unknown model '{kind}'", status_code=422) items = db.session.query(model).order_by(model.name) - data = [({'id': item.id, 'favorite': item.is_user_favorite()}, - item.base64_image(), - item.name, - item.description) for item in items] + data = [ + ( + {"id": item.id, "favorite": item.is_user_favorite()}, + item.base64_image(), + item.name, + item.description, + ) + for item in items + ] return jsonify(data=data) -@bp.route('/_update_favorites/<kind>', methods=['POST']) +@bp.route("/_update_favorites/<kind>", methods=["POST"]) @login_required def update_favorites(kind): """Update the current user favorite attributes @@ -288,28 +318,28 @@ def update_favorites(kind): except AttributeError: raise utils.CSEntryError(f"Unknown model '{kind}'", status_code=422) data = request.get_json() - attribute = model.query.get(data['id']) - favorite_attributes_str = utils.pluralize(f'favorite_{kind.lower()}') + attribute = model.query.get(data["id"]) + favorite_attributes_str = utils.pluralize(f"favorite_{kind.lower()}") user_favorite_attributes = getattr(current_user, favorite_attributes_str) - if data['checked']: + if data["checked"]: 
user_favorite_attributes.append(attribute) - message = 'Attribute added to the favorites' + message = "Attribute added to the favorites" else: user_favorite_attributes.remove(attribute) - message = 'Attribute removed from the favorites' + message = "Attribute removed from the favorites" db.session.commit() - data = {'message': message} + data = {"message": message} return jsonify(data=data), 201 -@bp.route('/scanner') +@bp.route("/scanner") @login_required def scanner(): """Render the scanner setup codes""" - return render_template('inventory/scanner.html') + return render_template("inventory/scanner.html") -@bp.route('/_retrieve_free_stack_members/<host_id>') +@bp.route("/_retrieve_free_stack_members/<host_id>") @login_required def retrieve_free_stack_members(host_id): """Return as json the free stack members numbers for the given host @@ -320,32 +350,30 @@ def retrieve_free_stack_members(host_id): Used to populate dynamically the stack_member field in the create and edit item forms """ - disabled_data = { - 'stack_members': [], - 'selected_member': None, - 'disabled': True, - } + disabled_data = {"stack_members": [], "selected_member": None, "disabled": True} try: host = models.Host.query.get(host_id) except sa.exc.DataError: # In case of unknown host_id or if host_id is None - current_app.logger.debug(f'Invalid host_id: {host_id}') + current_app.logger.debug(f"Invalid host_id: {host_id}") return jsonify(data=disabled_data) - if str(host.device_type) != 'Network': + if str(host.device_type) != "Network": return jsonify(data=disabled_data) members = host.free_stack_members() - selected_member = 'None' - ics_id = request.args.get('ics_id', None) + selected_member = "None" + ics_id = request.args.get("ics_id", None) item = models.Item.query.filter_by(ics_id=ics_id).first() if item is None: - current_app.logger.debug(f'Unknown ics_id: {ics_id}') + current_app.logger.debug(f"Unknown ics_id: {ics_id}") else: if item.stack_member is not None: members.append(item.stack_member) members.sort() selected_member = item.stack_member - members = ['None'] + members - data = {'stack_members': members, - 'selected_member': selected_member, - 'disabled': False} + members = ["None"] + members + data = { + "stack_members": members, + "selected_member": selected_member, + "disabled": False, + } return jsonify(data=data) diff --git a/app/main/views.py b/app/main/views.py index 3fc9bcb..668ea43 100644 --- a/app/main/views.py +++ b/app/main/views.py @@ -17,7 +17,7 @@ from flask_login import login_required, current_user from rq import push_connection, pop_connection from .. 
import utils -bp = Blueprint('main', __name__) +bp = Blueprint("main", __name__) # Allow only admin to access the RQ Dashboard @@ -31,17 +31,17 @@ def before_request(): # Declare custom error handlers for all views @bp.app_errorhandler(403) def forbidden_error(error): - return render_template('403.html'), 403 + return render_template("403.html"), 403 @bp.app_errorhandler(404) def not_found_error(error): - return render_template('404.html'), 404 + return render_template("404.html"), 404 @bp.app_errorhandler(500) def internal_error(error): - return render_template('500.html'), 500 + return render_template("500.html"), 500 @bp.app_errorhandler(utils.CSEntryError) @@ -61,18 +61,18 @@ def modified_static_file(endpoint, values): Inspired from http://flask.pocoo.org/snippets/40/ and https://gist.github.com/Ostrovski/f16779933ceee3a9d181 """ - if endpoint == 'static': - filename = values.get('filename') + if endpoint == "static": + filename = values.get("filename") if filename: # The same static folder is used for all blueprints file_path = os.path.join(current_app.static_folder, filename) - values['m'] = int(os.stat(file_path).st_mtime) + values["m"] = int(os.stat(file_path).st_mtime) def get_redis_connection(): - redis_connection = getattr(g, '_redis_connection', None) + redis_connection = getattr(g, "_redis_connection", None) if redis_connection is None: - redis_url = current_app.config['REDIS_URL'] + redis_url = current_app.config["REDIS_URL"] redis_connection = g._redis_connection = redis.from_url(redis_url) return redis_connection @@ -87,7 +87,7 @@ def pop_rq_connection(exception=None): pop_connection() -@bp.route('/') +@bp.route("/") @login_required def index(): - return render_template('index.html') + return render_template("index.html") diff --git a/app/models.py b/app/models.py index de1fb69..80ccc1c 100644 --- a/app/models.py +++ b/app/models.py @@ -26,8 +26,14 @@ from wtforms import ValidationError from rq import Queue from .extensions import db, login_manager, ldap_manager, cache from .plugins import FlaskUserPlugin -from .validators import (ICS_ID_RE, HOST_NAME_RE, VLAN_NAME_RE, MAC_ADDRESS_RE, - DEVICE_TYPE_RE, TAG_RE) +from .validators import ( + ICS_ID_RE, + HOST_NAME_RE, + VLAN_NAME_RE, + MAC_ADDRESS_RE, + DEVICE_TYPE_RE, + TAG_RE, +) from . 
import utils @@ -39,23 +45,25 @@ class utcnow(sa.sql.expression.FunctionElement): type = sa.types.DateTime() -@sa.ext.compiler.compiles(utcnow, 'postgresql') +@sa.ext.compiler.compiles(utcnow, "postgresql") def pg_utcnow(element, compiler, **kw): - return "TIMEZONE('utc', CURRENT_TIMESTAMP)" + return "TIMEZONE('utc', CURRENT_TIMESTAMP)" def temporary_ics_ids(): """Generator that returns the full list of temporary ICS ids""" - return (f'{current_app.config["TEMPORARY_ICS_ID"]}{letter}{number:0=3d}' - for letter in string.ascii_uppercase - for number in range(0, 1000)) + return ( + f'{current_app.config["TEMPORARY_ICS_ID"]}{letter}{number:0=3d}' + for letter in string.ascii_uppercase + for number in range(0, 1000) + ) def used_temporary_ics_ids(): """Return a set with the temporary ICS ids used""" temporary_items = Item.query.filter( - Item.ics_id.startswith( - current_app.config['TEMPORARY_ICS_ID'])).all() + Item.ics_id.startswith(current_app.config["TEMPORARY_ICS_ID"]) + ).all() return {item.ics_id for item in temporary_items} @@ -66,7 +74,7 @@ def get_temporary_ics_id(): if ics_id not in used_temp_ics_ids: return ics_id else: - raise ValueError('No temporary ICS id available') + raise ValueError("No temporary ICS id available") @login_manager.user_loader @@ -89,11 +97,15 @@ def save_user(dn, username, data, memberships): """ user = User.query.filter_by(username=username).first() if user is None: - user = User(username=username, - display_name=utils.attribute_to_string(data['cn']), - email=utils.attribute_to_string(data['mail'])) + user = User( + username=username, + display_name=utils.attribute_to_string(data["cn"]), + email=utils.attribute_to_string(data["mail"]), + ) # Always update the user groups to keep them up-to-date - user.groups = sorted([utils.attribute_to_string(group['cn']) for group in memberships]) + user.groups = sorted( + [utils.attribute_to_string(group["cn"]) for group in memberships] + ) db.session.add(user) db.session.commit() return user @@ -101,36 +113,53 @@ def save_user(dn, username, data, memberships): # Tables required for Many-to-Many relationships between users and favorites attributes favorite_manufacturers_table = db.Table( - 'favorite_manufacturers', - db.Column('user_id', db.Integer, db.ForeignKey('user_account.id'), primary_key=True), - db.Column('manufacturer_id', db.Integer, db.ForeignKey('manufacturer.id'), primary_key=True) + "favorite_manufacturers", + db.Column( + "user_id", db.Integer, db.ForeignKey("user_account.id"), primary_key=True + ), + db.Column( + "manufacturer_id", + db.Integer, + db.ForeignKey("manufacturer.id"), + primary_key=True, + ), ) favorite_models_table = db.Table( - 'favorite_models', - db.Column('user_id', db.Integer, db.ForeignKey('user_account.id'), primary_key=True), - db.Column('model_id', db.Integer, db.ForeignKey('model.id'), primary_key=True) + "favorite_models", + db.Column( + "user_id", db.Integer, db.ForeignKey("user_account.id"), primary_key=True + ), + db.Column("model_id", db.Integer, db.ForeignKey("model.id"), primary_key=True), ) favorite_locations_table = db.Table( - 'favorite_locations', - db.Column('user_id', db.Integer, db.ForeignKey('user_account.id'), primary_key=True), - db.Column('location_id', db.Integer, db.ForeignKey('location.id'), primary_key=True) + "favorite_locations", + db.Column( + "user_id", db.Integer, db.ForeignKey("user_account.id"), primary_key=True + ), + db.Column( + "location_id", db.Integer, db.ForeignKey("location.id"), primary_key=True + ), ) favorite_statuses_table = db.Table( - 
'favorite_statuses', - db.Column('user_id', db.Integer, db.ForeignKey('user_account.id'), primary_key=True), - db.Column('status_id', db.Integer, db.ForeignKey('status.id'), primary_key=True) + "favorite_statuses", + db.Column( + "user_id", db.Integer, db.ForeignKey("user_account.id"), primary_key=True + ), + db.Column("status_id", db.Integer, db.ForeignKey("status.id"), primary_key=True), ) favorite_actions_table = db.Table( - 'favorite_actions', - db.Column('user_id', db.Integer, db.ForeignKey('user_account.id'), primary_key=True), - db.Column('action_id', db.Integer, db.ForeignKey('action.id'), primary_key=True) + "favorite_actions", + db.Column( + "user_id", db.Integer, db.ForeignKey("user_account.id"), primary_key=True + ), + db.Column("action_id", db.Integer, db.ForeignKey("action.id"), primary_key=True), ) class User(db.Model, UserMixin): # "user" is a reserved word in postgresql # so let's use another name - __tablename__ = 'user_account' + __tablename__ = "user_account" id = db.Column(db.Integer, primary_key=True) username = db.Column(db.Text, nullable=False, unique=True) @@ -138,34 +167,39 @@ class User(db.Model, UserMixin): email = db.Column(db.Text) groups = db.Column(postgresql.ARRAY(db.Text), default=[]) tokens = db.relationship("Token", backref="user") - tasks = db.relationship('Task', backref='user') + tasks = db.relationship("Task", backref="user") # The favorites won't be accessed very often so we load them # only when necessary (lazy=True) favorite_manufacturers = db.relationship( - 'Manufacturer', + "Manufacturer", secondary=favorite_manufacturers_table, lazy=True, - backref=db.backref('favorite_users', lazy=True)) + backref=db.backref("favorite_users", lazy=True), + ) favorite_models = db.relationship( - 'Model', + "Model", secondary=favorite_models_table, lazy=True, - backref=db.backref('favorite_users', lazy=True)) + backref=db.backref("favorite_users", lazy=True), + ) favorite_locations = db.relationship( - 'Location', + "Location", secondary=favorite_locations_table, lazy=True, - backref=db.backref('favorite_users', lazy=True)) + backref=db.backref("favorite_users", lazy=True), + ) favorite_statuses = db.relationship( - 'Status', + "Status", secondary=favorite_statuses_table, lazy=True, - backref=db.backref('favorite_users', lazy=True)) + backref=db.backref("favorite_users", lazy=True), + ) favorite_actions = db.relationship( - 'Action', + "Action", secondary=favorite_actions_table, lazy=True, - backref=db.backref('favorite_users', lazy=True)) + backref=db.backref("favorite_users", lazy=True), + ) def get_id(self): """Return the user id as unicode @@ -177,7 +211,7 @@ class User(db.Model, UserMixin): @property def csentry_groups(self): groups = [] - for key, values in current_app.config['CSENTRY_LDAP_GROUPS'].items(): + for key, values in current_app.config["CSENTRY_LDAP_GROUPS"].items(): for value in values: if value in self.groups: groups.append(key) @@ -185,7 +219,7 @@ class User(db.Model, UserMixin): @property def is_admin(self): - for group in current_app.config['CSENTRY_LDAP_GROUPS']['admin']: + for group in current_app.config["CSENTRY_LDAP_GROUPS"]["admin"]: if group in self.groups: return True return False @@ -194,23 +228,32 @@ class User(db.Model, UserMixin): """Return True if the user is at least member of one of the given groups""" names = [] for group in groups: - names.extend(current_app.config['CSENTRY_LDAP_GROUPS'].get(group)) + names.extend(current_app.config["CSENTRY_LDAP_GROUPS"].get(group)) return bool(set(self.groups) & set(names)) def 
favorite_attributes(self): """Return all user's favorite attributes""" - favorites_list = [self.favorite_manufacturers, self.favorite_models, - self.favorite_locations, self.favorite_statuses, - self.favorite_actions] + favorites_list = [ + self.favorite_manufacturers, + self.favorite_models, + self.favorite_locations, + self.favorite_statuses, + self.favorite_actions, + ] return [favorite for favorites in favorites_list for favorite in favorites] def launch_task(self, name, func, *args, **kwargs): """Launch a task in the background using RQ""" q = Queue() - job = q.enqueue(f'app.tasks.{func}', *args, **kwargs) + job = q.enqueue(f"app.tasks.{func}", *args, **kwargs) # The status will be set to QUEUED or DEFERRED - task = Task(id=job.id, name=name, command=job.get_call_string(), - status=JobStatus(job.status), user=self) + task = Task( + id=job.id, + name=name, + command=job.get_call_string(), + status=JobStatus(job.status), + user=self, + ) db.session.add(task) db.session.commit() return task @@ -226,8 +269,11 @@ class User(db.Model, UserMixin): def get_tasks_in_progress(self, name): """Return all the <name> tasks not finished or failed""" - return Task.query.filter_by(name=name).filter( - ~Task.status.in_([JobStatus.FINISHED, JobStatus.FAILED])).all() + return ( + Task.query.filter_by(name=name) + .filter(~Task.status.in_([JobStatus.FINISHED, JobStatus.FAILED])) + .all() + ) def get_task_started(self, name): """Return the <name> task currently running or None""" @@ -235,8 +281,11 @@ class User(db.Model, UserMixin): def is_task_waiting(self, name): """Return True if a <name> task is waiting (queued or deferred)""" - count = Task.query.filter_by(name=name).filter( - Task.status.in_([JobStatus.DEFERRED, JobStatus.QUEUED])).count() + count = ( + Task.query.filter_by(name=name) + .filter(Task.status.in_([JobStatus.DEFERRED, JobStatus.QUEUED])) + .count() + ) return count > 0 def __str__(self): @@ -244,28 +293,27 @@ class User(db.Model, UserMixin): def to_dict(self): return { - 'id': self.id, - 'username': self.username, - 'display_name': self.display_name, - 'email': self.email, - 'groups': self.csentry_groups, + "id": self.id, + "username": self.username, + "display_name": self.display_name, + "email": self.email, + "groups": self.csentry_groups, } class Token(db.Model): """Table to store valid tokens""" + id = db.Column(db.Integer, primary_key=True) jti = db.Column(postgresql.UUID, nullable=False) token_type = db.Column(db.Text, nullable=False) - user_id = db.Column(db.Integer, db.ForeignKey('user_account.id'), nullable=False) + user_id = db.Column(db.Integer, db.ForeignKey("user_account.id"), nullable=False) issued_at = db.Column(db.DateTime, nullable=False) # expires can be set to None for tokens that never expire expires = db.Column(db.DateTime) description = db.Column(db.Text) - __table_args__ = ( - sa.UniqueConstraint(jti, user_id), - ) + __table_args__ = (sa.UniqueConstraint(jti, user_id),) def __str__(self): return self.jti @@ -284,7 +332,7 @@ class QRCodeMixin: - the table name - the name of the record """ - data = ':'.join(['CSE', self.__tablename__, self.name]) + data = ":".join(["CSE", self.__tablename__, self.name]) return qrcode.make(data, version=1, box_size=5) @cache.memoize(timeout=0) @@ -304,14 +352,14 @@ class QRCodeMixin: # The cache.memoize decorator performs a repr() on the passed in arguments # __repr__ is used as part of the cache key and shall be a uniquely identifying string # See https://flask-caching.readthedocs.io/en/latest/#memoization - return 
f'{self.__class__.__name__}(id={self.id}, name={self.name})' + return f"{self.__class__.__name__}(id={self.id}, name={self.name})" def to_dict(self): return { - 'id': self.id, - 'name': self.name, - 'description': self.description, - 'qrcode': self.base64_image(), + "id": self.id, + "name": self.name, + "description": self.description, + "qrcode": self.base64_image(), } @@ -320,19 +368,19 @@ class Action(QRCodeMixin, db.Model): class Manufacturer(QRCodeMixin, db.Model): - items = db.relationship('Item', back_populates='manufacturer') + items = db.relationship("Item", back_populates="manufacturer") class Model(QRCodeMixin, db.Model): - items = db.relationship('Item', back_populates='model') + items = db.relationship("Item", back_populates="model") class Location(QRCodeMixin, db.Model): - items = db.relationship('Item', back_populates='location') + items = db.relationship("Item", back_populates="location") class Status(QRCodeMixin, db.Model): - items = db.relationship('Item', back_populates='status') + items = db.relationship("Item", back_populates="status") class CreatedMixin: @@ -344,17 +392,21 @@ class CreatedMixin: # See http://docs.sqlalchemy.org/en/latest/orm/extensions/declarative/mixins.html @declared_attr def user_id(cls): - return db.Column(db.Integer, db.ForeignKey('user_account.id'), - nullable=False, default=utils.fetch_current_user_id) + return db.Column( + db.Integer, + db.ForeignKey("user_account.id"), + nullable=False, + default=utils.fetch_current_user_id, + ) @declared_attr def user(cls): - return db.relationship('User') + return db.relationship("User") def __init__(self, **kwargs): # Automatically convert created_at/updated_at strings # to datetime object - for key in ('created_at', 'updated_at'): + for key in ("created_at", "updated_at"): if key in kwargs: if isinstance(kwargs[key], str): kwargs[key] = utils.parse_to_utc(kwargs[key]) @@ -362,54 +414,65 @@ class CreatedMixin: def to_dict(self): return { - 'id': self.id, - 'created_at': utils.format_field(self.created_at), - 'updated_at': utils.format_field(self.updated_at), - 'user': str(self.user), + "id": self.id, + "created_at": utils.format_field(self.created_at), + "updated_at": utils.format_field(self.updated_at), + "user": str(self.user), } class Item(CreatedMixin, db.Model): __versioned__ = { - 'exclude': ['created_at', 'user_id', 'ics_id', 'serial_number', - 'manufacturer_id', 'model_id'] + "exclude": [ + "created_at", + "user_id", + "ics_id", + "serial_number", + "manufacturer_id", + "model_id", + ] } # WARNING! Inheriting id from CreatedMixin doesn't play well with # SQLAlchemy-Continuum. It has to be defined here. 
id = db.Column(db.Integer, primary_key=True) - ics_id = db.Column(db.Text, unique=True, nullable=False, - index=True, default=get_temporary_ics_id) + ics_id = db.Column( + db.Text, unique=True, nullable=False, index=True, default=get_temporary_ics_id + ) serial_number = db.Column(db.Text, nullable=False) quantity = db.Column(db.Integer, nullable=False, default=1) - manufacturer_id = db.Column(db.Integer, db.ForeignKey('manufacturer.id')) - model_id = db.Column(db.Integer, db.ForeignKey('model.id')) - location_id = db.Column(db.Integer, db.ForeignKey('location.id')) - status_id = db.Column(db.Integer, db.ForeignKey('status.id')) - parent_id = db.Column(db.Integer, db.ForeignKey('item.id')) - host_id = db.Column(db.Integer, db.ForeignKey('host.id')) + manufacturer_id = db.Column(db.Integer, db.ForeignKey("manufacturer.id")) + model_id = db.Column(db.Integer, db.ForeignKey("model.id")) + location_id = db.Column(db.Integer, db.ForeignKey("location.id")) + status_id = db.Column(db.Integer, db.ForeignKey("status.id")) + parent_id = db.Column(db.Integer, db.ForeignKey("item.id")) + host_id = db.Column(db.Integer, db.ForeignKey("host.id")) stack_member = db.Column(db.SmallInteger) - manufacturer = db.relationship('Manufacturer', back_populates='items') - model = db.relationship('Model', back_populates='items') - location = db.relationship('Location', back_populates='items') - status = db.relationship('Status', back_populates='items') - children = db.relationship('Item', backref=db.backref('parent', remote_side=[id])) - macs = db.relationship('Mac', backref='item') - comments = db.relationship('ItemComment', backref='item') + manufacturer = db.relationship("Manufacturer", back_populates="items") + model = db.relationship("Model", back_populates="items") + location = db.relationship("Location", back_populates="items") + status = db.relationship("Status", back_populates="items") + children = db.relationship("Item", backref=db.backref("parent", remote_side=[id])) + macs = db.relationship("Mac", backref="item") + comments = db.relationship("ItemComment", backref="item") __table_args__ = ( - sa.CheckConstraint('stack_member >= 0 AND stack_member <=9', name='stack_member_range'), - sa.UniqueConstraint(host_id, stack_member, name='uq_item_host_id_stack_member'), + sa.CheckConstraint( + "stack_member >= 0 AND stack_member <=9", name="stack_member_range" + ), + sa.UniqueConstraint(host_id, stack_member, name="uq_item_host_id_stack_member"), ) def __init__(self, **kwargs): # Automatically convert manufacturer/model/location/status to an # instance of their class if passed as a string - for key, cls in [('manufacturer', Manufacturer), - ('model', Model), - ('location', Location), - ('status', Status)]: + for key, cls in [ + ("manufacturer", Manufacturer), + ("model", Model), + ("location", Location), + ("status", Status), + ]: if key in kwargs: kwargs[key] = utils.convert_to_model(kwargs[key], cls) super().__init__(**kwargs) @@ -417,32 +480,34 @@ class Item(CreatedMixin, db.Model): def __str__(self): return str(self.ics_id) - @validates('ics_id') + @validates("ics_id") def validate_ics_id(self, key, string): """Ensure the ICS id field matches the required format""" if string is not None: if ICS_ID_RE.fullmatch(string) is None: - raise ValidationError('ICS id shall match [A-Z]{3}[0-9]{3}') + raise ValidationError("ICS id shall match [A-Z]{3}[0-9]{3}") return string def to_dict(self): d = super().to_dict() - d.update({ - 'ics_id': self.ics_id, - 'serial_number': self.serial_number, - 'quantity': self.quantity, 
- 'manufacturer': utils.format_field(self.manufacturer), - 'model': utils.format_field(self.model), - 'location': utils.format_field(self.location), - 'status': utils.format_field(self.status), - 'parent': utils.format_field(self.parent), - 'children': [str(child) for child in self.children], - 'macs': [str(mac) for mac in self.macs], - 'host': utils.format_field(self.host), - 'stack_member': utils.format_field(self.stack_member), - 'history': self.history(), - 'comments': [str(comment) for comment in self.comments], - }) + d.update( + { + "ics_id": self.ics_id, + "serial_number": self.serial_number, + "quantity": self.quantity, + "manufacturer": utils.format_field(self.manufacturer), + "model": utils.format_field(self.model), + "location": utils.format_field(self.location), + "status": utils.format_field(self.status), + "parent": utils.format_field(self.parent), + "children": [str(child) for child in self.children], + "macs": [str(mac) for mac in self.macs], + "host": utils.format_field(self.host), + "stack_member": utils.format_field(self.stack_member), + "history": self.history(), + "comments": [str(comment) for comment in self.comments], + } + ) return d def history(self): @@ -456,29 +521,28 @@ class Item(CreatedMixin, db.Model): parent = None else: parent = Item.query.get(version.parent_id) - versions.append({ - 'updated_at': utils.format_field(version.updated_at), - 'quantity': version.quantity, - 'location': utils.format_field(version.location), - 'status': utils.format_field(version.status), - 'parent': utils.format_field(parent), - }) + versions.append( + { + "updated_at": utils.format_field(version.updated_at), + "quantity": version.quantity, + "location": utils.format_field(version.location), + "status": utils.format_field(version.status), + "parent": utils.format_field(parent), + } + ) return versions class ItemComment(CreatedMixin, db.Model): body = db.Column(db.Text, nullable=False) - item_id = db.Column(db.Integer, db.ForeignKey('item.id'), nullable=False) + item_id = db.Column(db.Integer, db.ForeignKey("item.id"), nullable=False) def __str__(self): return self.body def to_dict(self): d = super().to_dict() - d.update({ - 'body': self.body, - 'item': str(self.item), - }) + d.update({"body": self.body, "item": str(self.item)}) return d @@ -490,25 +554,27 @@ class Network(CreatedMixin, db.Model): last_ip = db.Column(postgresql.INET, nullable=False, unique=True) description = db.Column(db.Text) admin_only = db.Column(db.Boolean, nullable=False, default=False) - scope_id = db.Column(db.Integer, db.ForeignKey('network_scope.id'), nullable=False) - domain_id = db.Column(db.Integer, db.ForeignKey('domain.id'), nullable=False) + scope_id = db.Column(db.Integer, db.ForeignKey("network_scope.id"), nullable=False) + domain_id = db.Column(db.Integer, db.ForeignKey("domain.id"), nullable=False) - interfaces = db.relationship('Interface', backref='network') + interfaces = db.relationship("Interface", backref="network") __table_args__ = ( - sa.CheckConstraint('first_ip < last_ip', name='first_ip_less_than_last_ip'), - sa.CheckConstraint('first_ip << address', name='first_ip_in_network'), - sa.CheckConstraint('last_ip << address', name='last_ip_in_network'), + sa.CheckConstraint("first_ip < last_ip", name="first_ip_less_than_last_ip"), + sa.CheckConstraint("first_ip << address", name="first_ip_in_network"), + sa.CheckConstraint("last_ip << address", name="last_ip_in_network"), ) def __init__(self, **kwargs): # Automatically convert scope to an instance of NetworkScope if it was passed # as a 
string - if 'scope' in kwargs: - kwargs['scope'] = utils.convert_to_model(kwargs['scope'], NetworkScope, 'name') + if "scope" in kwargs: + kwargs["scope"] = utils.convert_to_model( + kwargs["scope"], NetworkScope, "name" + ) # If domain_id is not passed, we set it to the network scope value - if 'domain_id' not in kwargs: - kwargs['domain_id'] = kwargs['scope'].domain_id + if "domain_id" not in kwargs: + kwargs["domain_id"] = kwargs["scope"].domain_id super().__init__(**kwargs) def __str__(self): @@ -535,8 +601,9 @@ class Network(CreatedMixin, db.Model): The range is defined by the first and last IP """ - return [addr for addr in self.network_ip.hosts() - if self.first <= addr <= self.last] + return [ + addr for addr in self.network_ip.hosts() if self.first <= addr <= self.last + ] def used_ips(self): """Return the list of IP addresses in use @@ -547,8 +614,7 @@ class Network(CreatedMixin, db.Model): def available_ips(self): """Return the list of IP addresses available""" - return [addr for addr in self.ip_range() - if addr not in self.used_ips()] + return [addr for addr in self.ip_range() if addr not in self.used_ips()] @property def gateway(self): @@ -566,24 +632,26 @@ class Network(CreatedMixin, db.Model): addr = ipaddress.ip_address(ip) net = ipaddress.ip_network(address) if addr not in net: - raise ValidationError(f'IP address {ip} is not in network {address}') + raise ValidationError(f"IP address {ip} is not in network {address}") return (addr, net) - @validates('first_ip') + @validates("first_ip") def validate_first_ip(self, key, ip): """Ensure the first IP is in the network""" self.ip_in_network(ip, self.address) return ip - @validates('last_ip') + @validates("last_ip") def validate_last_ip(self, key, ip): """Ensure the last IP is in the network and greater than first_ip""" addr, net = self.ip_in_network(ip, self.address) if addr < self.first: - raise ValidationError(f'Last IP address {ip} is less than the first address {self.first}') + raise ValidationError( + f"Last IP address {ip} is less than the first address {self.first}" + ) return ip - @validates('interfaces') + @validates("interfaces") def validate_interfaces(self, key, interface): """Ensure the interface IP is in the network range""" addr, net = self.ip_in_network(interface.ip, self.address) @@ -591,48 +659,54 @@ class Network(CreatedMixin, db.Model): user = utils.cse_current_user() if user is None or not user.is_admin: if addr < self.first or addr > self.last: - raise ValidationError(f'IP address {interface.ip} is not in range {self.first} - {self.last}') + raise ValidationError( + f"IP address {interface.ip} is not in range {self.first} - {self.last}" + ) return interface - @validates('vlan_name') + @validates("vlan_name") def validate_vlan_name(self, key, string): """Ensure the name matches the required format""" if string is None: return None if VLAN_NAME_RE.fullmatch(string) is None: - raise ValidationError('Vlan name shall match [A-Za-z0-9\-]{3,25}') + raise ValidationError("Vlan name shall match [A-Za-z0-9\-]{3,25}") return string def to_dict(self): d = super().to_dict() - d.update({ - 'vlan_name': self.vlan_name, - 'vlan_id': self.vlan_id, - 'address': self.address, - 'netmask': str(self.netmask), - 'first_ip': self.first_ip, - 'last_ip': self.last_ip, - 'description': self.description, - 'admin_only': self.admin_only, - 'scope': utils.format_field(self.scope), - 'domain': str(self.domain), - 'interfaces': [str(interface) for interface in self.interfaces], - }) + d.update( + { + "vlan_name": self.vlan_name, + 
"vlan_id": self.vlan_id, + "address": self.address, + "netmask": str(self.netmask), + "first_ip": self.first_ip, + "last_ip": self.last_ip, + "description": self.description, + "admin_only": self.admin_only, + "scope": utils.format_field(self.scope), + "domain": str(self.domain), + "interfaces": [str(interface) for interface in self.interfaces], + } + ) return d # Table required for Many-to-Many relationships between interfaces and tags interfacetags_table = db.Table( - 'interfacetags', - db.Column('tag_id', db.Integer, db.ForeignKey('tag.id'), primary_key=True), - db.Column('interface_id', db.Integer, db.ForeignKey('interface.id'), primary_key=True) + "interfacetags", + db.Column("tag_id", db.Integer, db.ForeignKey("tag.id"), primary_key=True), + db.Column( + "interface_id", db.Integer, db.ForeignKey("interface.id"), primary_key=True + ), ) class Tag(QRCodeMixin, db.Model): admin_only = db.Column(db.Boolean, nullable=False, default=False) - @validates('name') + @validates("name") def validate_name(self, key, string): """Ensure the name field matches the required format""" if string is not None: @@ -642,13 +716,13 @@ class Tag(QRCodeMixin, db.Model): class DeviceType(db.Model): - __tablename__ = 'device_type' + __tablename__ = "device_type" id = db.Column(db.Integer, primary_key=True) name = db.Column(CIText, nullable=False, unique=True) - hosts = db.relationship('Host', backref='device_type') + hosts = db.relationship("Host", backref="device_type") - @validates('name') + @validates("name") def validate_name(self, key, string): """Ensure the name field matches the required format""" if string is not None: @@ -661,28 +735,34 @@ class DeviceType(db.Model): def to_dict(self): return { - 'id': self.id, - 'name': self.name, - 'hosts': [str(host) for host in self.hosts] + "id": self.id, + "name": self.name, + "hosts": [str(host) for host in self.hosts], } class Host(CreatedMixin, db.Model): name = db.Column(db.Text, nullable=False, unique=True) description = db.Column(db.Text) - device_type_id = db.Column(db.Integer, db.ForeignKey('device_type.id'), nullable=False) + device_type_id = db.Column( + db.Integer, db.ForeignKey("device_type.id"), nullable=False + ) - interfaces = db.relationship('Interface', backref='host') - items = db.relationship('Item', backref='host') + interfaces = db.relationship("Interface", backref="host") + items = db.relationship("Item", backref="host") def __init__(self, **kwargs): # Automatically convert device_type as an instance of its class if passed as a string - if 'device_type' in kwargs: - kwargs['device_type'] = utils.convert_to_model(kwargs['device_type'], DeviceType) + if "device_type" in kwargs: + kwargs["device_type"] = utils.convert_to_model( + kwargs["device_type"], DeviceType + ) # Automatically convert items to a list of instances if passed as a list of ics_id - if 'items' in kwargs: - kwargs['items'] = [utils.convert_to_model(item, Item, filter='ics_id') - for item in kwargs['items']] + if "items" in kwargs: + kwargs["items"] = [ + utils.convert_to_model(item, Item, filter="ics_id") + for item in kwargs["items"] + ] super().__init__(**kwargs) @property @@ -695,7 +775,7 @@ class Host(CreatedMixin, db.Model): def __str__(self): return str(self.name) - @validates('name') + @validates("name") def validate_name(self, key, string): """Ensure the name matches the required format""" if string is None: @@ -703,7 +783,7 @@ class Host(CreatedMixin, db.Model): # Force the string to lowercase lower_string = string.lower() if HOST_NAME_RE.fullmatch(lower_string) is 
None: - raise ValidationError('Interface name shall match [a-z0-9\-]{2,20}') + raise ValidationError("Interface name shall match [a-z0-9\-]{2,20}") return lower_string def stack_members(self): @@ -721,51 +801,60 @@ class Host(CreatedMixin, db.Model): def to_dict(self): d = super().to_dict() - d.update({ - 'name': self.name, - 'device_type': str(self.device_type), - 'description': self.description, - 'items': [str(item) for item in self.items], - 'interfaces': [str(interface) for interface in self.interfaces], - }) + d.update( + { + "name": self.name, + "device_type": str(self.device_type), + "description": self.description, + "items": [str(item) for item in self.items], + "interfaces": [str(interface) for interface in self.interfaces], + } + ) return d class Interface(CreatedMixin, db.Model): - network_id = db.Column(db.Integer, db.ForeignKey('network.id'), nullable=False) + network_id = db.Column(db.Integer, db.ForeignKey("network.id"), nullable=False) ip = db.Column(postgresql.INET, nullable=False, unique=True) name = db.Column(db.Text, nullable=False, unique=True) - mac_id = db.Column(db.Integer, db.ForeignKey('mac.id')) - host_id = db.Column(db.Integer, db.ForeignKey('host.id')) + mac_id = db.Column(db.Integer, db.ForeignKey("mac.id")) + host_id = db.Column(db.Integer, db.ForeignKey("host.id")) # Add delete and delete-orphan options to automatically delete cnames when: # - deleting an interface # - de-associating a cname (removing it from the interface.cnames list) - cnames = db.relationship('Cname', backref='interface', - cascade='all, delete, delete-orphan') - tags = db.relationship('Tag', secondary=interfacetags_table, lazy='subquery', - backref=db.backref('interfaces', lazy=True)) + cnames = db.relationship( + "Cname", backref="interface", cascade="all, delete, delete-orphan" + ) + tags = db.relationship( + "Tag", + secondary=interfacetags_table, + lazy="subquery", + backref=db.backref("interfaces", lazy=True), + ) def __init__(self, **kwargs): # Always set self.network and not self.network_id to call validate_interfaces - network_id = kwargs.pop('network_id', None) + network_id = kwargs.pop("network_id", None) if network_id is not None: - kwargs['network'] = Network.query.get(network_id) - elif 'network' in kwargs: + kwargs["network"] = Network.query.get(network_id) + elif "network" in kwargs: # Automatically convert network to an instance of Network if it was passed # as a string - kwargs['network'] = utils.convert_to_model(kwargs['network'], Network, 'vlan_name') + kwargs["network"] = utils.convert_to_model( + kwargs["network"], Network, "vlan_name" + ) # WARNING! Setting self.network will call validate_interfaces in the Network class # For the validation to work, self.ip must be set before! 
# Ensure that ip is passed before network try: - ip = kwargs.pop('ip') + ip = kwargs.pop("ip") except KeyError: super().__init__(**kwargs) else: super().__init__(ip=ip, **kwargs) - @validates('name') + @validates("name") def validate_name(self, key, string): """Ensure the name matches the required format""" if string is None: @@ -773,7 +862,7 @@ class Interface(CreatedMixin, db.Model): # Force the string to lowercase lower_string = string.lower() if HOST_NAME_RE.fullmatch(lower_string) is None: - raise ValidationError('Interface name shall match [a-z0-9\-]{2,20}') + raise ValidationError("Interface name shall match [a-z0-9\-]{2,20}") return lower_string @property @@ -783,7 +872,7 @@ class Interface(CreatedMixin, db.Model): @property def is_ioc(self): for tag in self.tags: - if tag.name == 'IOC': + if tag.name == "IOC": return True return False @@ -791,38 +880,40 @@ class Interface(CreatedMixin, db.Model): return str(self.name) def __repr__(self): - return f'Interface(id={self.id}, network_id={self.network_id}, ip={self.ip}, name={self.name}, mac={self.mac})' + return f"Interface(id={self.id}, network_id={self.network_id}, ip={self.ip}, name={self.name}, mac={self.mac})" def to_dict(self): d = super().to_dict() - d.update({ - 'network': str(self.network), - 'ip': self.ip, - 'name': self.name, - 'mac': utils.format_field(self.mac), - 'host': utils.format_field(self.host), - 'cnames': [str(cname) for cname in self.cnames], - 'domain': str(self.network.domain), - 'tags': [str(tag) for tag in self.tags], - }) + d.update( + { + "network": str(self.network), + "ip": self.ip, + "name": self.name, + "mac": utils.format_field(self.mac), + "host": utils.format_field(self.host), + "cnames": [str(cname) for cname in self.cnames], + "domain": str(self.network.domain), + "tags": [str(tag) for tag in self.tags], + } + ) if self.host: - d['device_type'] = str(self.host.device_type) + d["device_type"] = str(self.host.device_type) else: - d['device_type'] = None + d["device_type"] = None return d class Mac(db.Model): id = db.Column(db.Integer, primary_key=True) address = db.Column(postgresql.MACADDR, nullable=False, unique=True) - item_id = db.Column(db.Integer, db.ForeignKey('item.id')) + item_id = db.Column(db.Integer, db.ForeignKey("item.id")) - interfaces = db.relationship('Interface', backref='mac') + interfaces = db.relationship("Interface", backref="mac") def __str__(self): return str(self.address) - @validates('address') + @validates("address") def validate_address(self, key, string): """Ensure the address is a valid MAC address""" if string is None: @@ -833,61 +924,62 @@ class Mac(db.Model): def to_dict(self): return { - 'id': self.id, - 'address': self.address, - 'item': utils.format_field(self.item), - 'interfaces': [str(interface) for interface in self.interfaces], + "id": self.id, + "address": self.address, + "item": utils.format_field(self.item), + "interfaces": [str(interface) for interface in self.interfaces], } class Cname(CreatedMixin, db.Model): name = db.Column(db.Text, nullable=False, unique=True) - interface_id = db.Column(db.Integer, db.ForeignKey('interface.id'), nullable=False) + interface_id = db.Column(db.Integer, db.ForeignKey("interface.id"), nullable=False) def __str__(self): return str(self.name) def to_dict(self): d = super().to_dict() - d.update({ - 'name': self.name, - 'interface': str(self.interface), - }) + d.update({"name": self.name, "interface": str(self.interface)}) return d class Domain(CreatedMixin, db.Model): name = db.Column(db.Text, nullable=False, unique=True) - 
scopes = db.relationship('NetworkScope', backref='domain') - networks = db.relationship('Network', backref='domain') + scopes = db.relationship("NetworkScope", backref="domain") + networks = db.relationship("Network", backref="domain") def __str__(self): return str(self.name) def to_dict(self): d = super().to_dict() - d.update({ - 'name': self.name, - 'scopes': [str(scope) for scope in self.scopes], - 'networks': [str(network) for network in self.networks], - }) + d.update( + { + "name": self.name, + "scopes": [str(scope) for scope in self.scopes], + "networks": [str(network) for network in self.networks], + } + ) return d class NetworkScope(CreatedMixin, db.Model): - __tablename__ = 'network_scope' + __tablename__ = "network_scope" name = db.Column(CIText, nullable=False, unique=True) first_vlan = db.Column(db.Integer, nullable=False, unique=True) last_vlan = db.Column(db.Integer, nullable=False, unique=True) supernet = db.Column(postgresql.CIDR, nullable=False, unique=True) - domain_id = db.Column(db.Integer, db.ForeignKey('domain.id'), nullable=False) + domain_id = db.Column(db.Integer, db.ForeignKey("domain.id"), nullable=False) description = db.Column(db.Text) - networks = db.relationship('Network', backref='scope') + networks = db.relationship("Network", backref="scope") __table_args__ = ( - sa.CheckConstraint('first_vlan < last_vlan', name='first_vlan_less_than_last_vlan'), + sa.CheckConstraint( + "first_vlan < last_vlan", name="first_vlan_less_than_last_vlan" + ), ) def __str__(self): @@ -917,8 +1009,7 @@ class NetworkScope(CreatedMixin, db.Model): def available_vlans(self): """Return the list of vlan ids available""" - return [vlan for vlan in self.vlan_range() - if vlan not in self.used_vlans()] + return [vlan for vlan in self.vlan_range() if vlan not in self.used_vlans()] def used_subnets(self): """Return the list of subnets in use @@ -929,20 +1020,25 @@ class NetworkScope(CreatedMixin, db.Model): def available_subnets(self, prefix): """Return the list of available subnets with the given prefix""" - return [str(subnet) for subnet in self.supernet_ip.subnets(new_prefix=prefix) - if subnet not in self.used_subnets()] + return [ + str(subnet) + for subnet in self.supernet_ip.subnets(new_prefix=prefix) + if subnet not in self.used_subnets() + ] def to_dict(self): d = super().to_dict() - d.update({ - 'name': self.name, - 'first_vlan': self.first_vlan, - 'last_vlan': self.last_vlan, - 'supernet': self.supernet, - 'description': self.description, - 'domain': str(self.domain), - 'networks': [str(network) for network in self.networks], - }) + d.update( + { + "name": self.name, + "first_vlan": self.first_vlan, + "last_vlan": self.last_vlan, + "supernet": self.supernet, + "description": self.description, + "domain": str(self.domain), + "networks": [str(network) for network in self.networks], + } + ) return d @@ -951,11 +1047,11 @@ class NetworkScope(CreatedMixin, db.Model): # not a real enum (it's a custom one) and is not # compatible with sqlalchemy class JobStatus(Enum): - QUEUED = 'queued' - FINISHED = 'finished' - FAILED = 'failed' - STARTED = 'started' - DEFERRED = 'deferred' + QUEUED = "queued" + FINISHED = "finished" + FAILED = "failed" + STARTED = "started" + DEFERRED = "deferred" class Task(db.Model): @@ -965,19 +1061,22 @@ class Task(db.Model): ended_at = db.Column(db.DateTime) name = db.Column(db.Text, nullable=False, index=True) command = db.Column(db.Text) - status = db.Column(db.Enum(JobStatus, name='job_status')) + status = db.Column(db.Enum(JobStatus, name="job_status")) 
awx_job_id = db.Column(db.Integer) exception = db.Column(db.Text) - user_id = db.Column(db.Integer, db.ForeignKey('user_account.id'), - nullable=False, default=utils.fetch_current_user_id) + user_id = db.Column( + db.Integer, + db.ForeignKey("user_account.id"), + nullable=False, + default=utils.fetch_current_user_id, + ) @property def awx_job_url(self): if self.awx_job_id is None: return None return urllib.parse.urljoin( - current_app.config['AWX_URL'], - f'/#/jobs/{self.awx_job_id}' + current_app.config["AWX_URL"], f"/#/jobs/{self.awx_job_id}" ) def __str__(self): @@ -985,16 +1084,16 @@ class Task(db.Model): def to_dict(self): return { - 'id': self.id, - 'name': self.name, - 'created_at': utils.format_field(self.created_at), - 'ended_at': utils.format_field(self.ended_at), - 'status': self.status.name, - 'awx_job_id': self.awx_job_id, - 'awx_job_url': self.awx_job_url, - 'command': self.command, - 'exception': self.exception, - 'user': str(self.user), + "id": self.id, + "name": self.name, + "created_at": utils.format_field(self.created_at), + "ended_at": utils.format_field(self.ended_at), + "status": self.status.name, + "awx_job_id": self.awx_job_id, + "awx_job_url": self.awx_job_url, + "command": self.command, + "exception": self.exception, + "user": str(self.user), } diff --git a/app/network/forms.py b/app/network/forms.py index 989f1fd..b564519 100644 --- a/app/network/forms.py +++ b/app/network/forms.py @@ -12,11 +12,25 @@ This module defines the network blueprint forms. import ipaddress from flask import current_app from flask_login import current_user -from wtforms import (SelectField, StringField, TextAreaField, IntegerField, - SelectMultipleField, BooleanField, validators) +from wtforms import ( + SelectField, + StringField, + TextAreaField, + IntegerField, + SelectMultipleField, + BooleanField, + validators, +) from ..helpers import CSEntryForm -from ..validators import (Unique, RegexpList, IPNetwork, HOST_NAME_RE, - VLAN_NAME_RE, MAC_ADDRESS_RE, NoValidateSelectField) +from ..validators import ( + Unique, + RegexpList, + IPNetwork, + HOST_NAME_RE, + VLAN_NAME_RE, + MAC_ADDRESS_RE, + NoValidateSelectField, +) from .. 
import utils, models @@ -24,86 +38,107 @@ def starts_with_hostname(form, field): """Check that interface name starts with hostname""" try: # Create / Edit interface form - host_id_field = form['host_id'] + host_id_field = form["host_id"] except KeyError: # Create host form - hostname = form['name'].data + hostname = form["name"].data else: host = models.Host.query.get(host_id_field.data) hostname = host.name if not field.data.startswith(hostname): - raise validators.ValidationError(f'Interface name shall start with the hostname "{hostname}"') + raise validators.ValidationError( + f'Interface name shall start with the hostname "{hostname}"' + ) def ip_in_network(form, field): """Check that the IP is in the network""" - network_id_field = form['network_id'] + network_id_field = form["network_id"] network = models.Network.query.get(network_id_field.data) ip = ipaddress.ip_address(field.data) if ip not in network.network_ip: - raise validators.ValidationError(f'IP address {ip} is not in network {network.address}') + raise validators.ValidationError( + f"IP address {ip} is not in network {network.address}" + ) # Admin user can create IP outside the defined range if current_user.is_authenticated and not current_user.is_admin: if ip < network.first or ip > network.last: - raise validators.ValidationError(f'IP address {ip} is not in range {network.first} - {network.last}') + raise validators.ValidationError( + f"IP address {ip} is not in range {network.first} - {network.last}" + ) class DomainForm(CSEntryForm): - name = StringField('Name', - validators=[validators.InputRequired(), - Unique(models.Domain, column='name')]) + name = StringField( + "Name", + validators=[validators.InputRequired(), Unique(models.Domain, column="name")], + ) class NetworkScopeForm(CSEntryForm): - name = StringField('Name', - description='name must be 3-25 characters long and contain only letters, numbers and dash', - validators=[validators.InputRequired(), - validators.Regexp(VLAN_NAME_RE), - Unique(models.NetworkScope, column='name')]) - description = TextAreaField('Description') - first_vlan = IntegerField('First vlan') - last_vlan = IntegerField('Last vlan') - supernet = StringField('Supernet', - validators=[validators.InputRequired(), - IPNetwork()]) - domain_id = SelectField('Default domain') + name = StringField( + "Name", + description="name must be 3-25 characters long and contain only letters, numbers and dash", + validators=[ + validators.InputRequired(), + validators.Regexp(VLAN_NAME_RE), + Unique(models.NetworkScope, column="name"), + ], + ) + description = TextAreaField("Description") + first_vlan = IntegerField("First vlan") + last_vlan = IntegerField("Last vlan") + supernet = StringField( + "Supernet", validators=[validators.InputRequired(), IPNetwork()] + ) + domain_id = SelectField("Default domain") def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.domain_id.choices = utils.get_model_choices(models.Domain, attr='name') + self.domain_id.choices = utils.get_model_choices(models.Domain, attr="name") class NetworkForm(CSEntryForm): - scope_id = SelectField('Network Scope') - vlan_name = StringField('Vlan name', - description='vlan name must be 3-25 characters long and contain only letters, numbers and dash', - validators=[validators.InputRequired(), - validators.Regexp(VLAN_NAME_RE), - Unique(models.Network, column='vlan_name')]) - vlan_id = NoValidateSelectField('Vlan id', choices=[]) - description = TextAreaField('Description') - prefix = NoValidateSelectField('Prefix', 
choices=[]) - address = NoValidateSelectField('Address', choices=[]) - first_ip = NoValidateSelectField('First IP', choices=[]) - last_ip = NoValidateSelectField('Last IP', choices=[]) - domain_id = SelectField('Domain') - admin_only = BooleanField('Admin only') + scope_id = SelectField("Network Scope") + vlan_name = StringField( + "Vlan name", + description="vlan name must be 3-25 characters long and contain only letters, numbers and dash", + validators=[ + validators.InputRequired(), + validators.Regexp(VLAN_NAME_RE), + Unique(models.Network, column="vlan_name"), + ], + ) + vlan_id = NoValidateSelectField("Vlan id", choices=[]) + description = TextAreaField("Description") + prefix = NoValidateSelectField("Prefix", choices=[]) + address = NoValidateSelectField("Address", choices=[]) + first_ip = NoValidateSelectField("First IP", choices=[]) + last_ip = NoValidateSelectField("Last IP", choices=[]) + domain_id = SelectField("Domain") + admin_only = BooleanField("Admin only") def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.scope_id.choices = utils.get_model_choices(models.NetworkScope, attr='name') - self.domain_id.choices = utils.get_model_choices(models.Domain, attr='name') + self.scope_id.choices = utils.get_model_choices( + models.NetworkScope, attr="name" + ) + self.domain_id.choices = utils.get_model_choices(models.Domain, attr="name") class HostForm(CSEntryForm): - name = StringField('Hostname', - description='hostname must be 2-20 characters long and contain only letters, numbers and dash', - validators=[validators.InputRequired(), - validators.Regexp(HOST_NAME_RE), - Unique(models.Host)], - filters=[utils.lowercase_field]) - description = TextAreaField('Description') - device_type_id = SelectField('Device Type') + name = StringField( + "Hostname", + description="hostname must be 2-20 characters long and contain only letters, numbers and dash", + validators=[ + validators.InputRequired(), + validators.Regexp(HOST_NAME_RE), + Unique(models.Host), + ], + filters=[utils.lowercase_field], + ) + description = TextAreaField("Description") + device_type_id = SelectField("Device Type") def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -111,35 +146,42 @@ class HostForm(CSEntryForm): class InterfaceForm(CSEntryForm): - host_id = SelectField('Host') - network_id = SelectField('Network') + host_id = SelectField("Host") + network_id = SelectField("Network") ip = StringField( - 'IP address', - validators=[validators.InputRequired(), - validators.IPAddress(), - ip_in_network, - Unique(models.Interface, column='ip'), - ], + "IP address", + validators=[ + validators.InputRequired(), + validators.IPAddress(), + ip_in_network, + Unique(models.Interface, column="ip"), + ], ) interface_name = StringField( - 'Interface name', - description='name must be 2-20 characters long and contain only letters, numbers and dash', - validators=[validators.InputRequired(), - validators.Regexp(HOST_NAME_RE), - Unique(models.Interface), - starts_with_hostname], - filters=[utils.lowercase_field]) - random_mac = BooleanField('Random MAC', default=False) + "Interface name", + description="name must be 2-20 characters long and contain only letters, numbers and dash", + validators=[ + validators.InputRequired(), + validators.Regexp(HOST_NAME_RE), + Unique(models.Interface), + starts_with_hostname, + ], + filters=[utils.lowercase_field], + ) + random_mac = BooleanField("Random MAC", default=False) mac_address = StringField( - 'MAC', - validators=[validators.Optional(), - 
validators.Regexp(MAC_ADDRESS_RE, message='Invalid MAC address')]) + "MAC", + validators=[ + validators.Optional(), + validators.Regexp(MAC_ADDRESS_RE, message="Invalid MAC address"), + ], + ) cnames_string = StringField( - 'Cnames', - description='space separated list of cnames (must be 2-20 characters long and contain only letters, numbers and dash)', - validators=[validators.Optional(), - RegexpList(HOST_NAME_RE)]) - tags = SelectMultipleField('Tags', coerce=utils.coerce_to_str_or_none) + "Cnames", + description="space separated list of cnames (must be 2-20 characters long and contain only letters, numbers and dash)", + validators=[validators.Optional(), RegexpList(HOST_NAME_RE)], + ) + tags = SelectMultipleField("Tags", coerce=utils.coerce_to_str_or_none) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -148,12 +190,16 @@ class InterfaceForm(CSEntryForm): network_query = models.Network.query tags_query = models.Tag.query else: - network_query = models.Network.query.filter(models.Network.admin_only.is_(False)) + network_query = models.Network.query.filter( + models.Network.admin_only.is_(False) + ) tags_query = models.Tag.query.filter(models.Tag.admin_only.is_(False)) - self.network_id.choices = utils.get_model_choices(models.Network, allow_none=False, - attr='vlan_name', query=network_query) - self.tags.choices = utils.get_model_choices(models.Tag, - attr='name', query=tags_query) + self.network_id.choices = utils.get_model_choices( + models.Network, allow_none=False, attr="vlan_name", query=network_query + ) + self.tags.choices = utils.get_model_choices( + models.Tag, attr="name", query=tags_query + ) class HostInterfaceForm(HostForm, InterfaceForm): @@ -161,10 +207,10 @@ class HostInterfaceForm(HostForm, InterfaceForm): class CreateVMForm(CSEntryForm): - cores = SelectField('Cores', default=2, coerce=int) - memory = SelectField('Memory (GB)', default=2, coerce=int) + cores = SelectField("Cores", default=2, coerce=int) + memory = SelectField("Memory (GB)", default=2, coerce=int) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.cores.choices = utils.get_choices(current_app.config['VM_CORES_CHOICES']) - self.memory.choices = utils.get_choices(current_app.config['VM_MEMORY_CHOICES']) + self.cores.choices = utils.get_choices(current_app.config["VM_CORES_CHOICES"]) + self.memory.choices = utils.get_choices(current_app.config["VM_MEMORY_CHOICES"]) diff --git a/app/network/views.py b/app/network/views.py index 168ac94..9cbbbd4 100644 --- a/app/network/views.py +++ b/app/network/views.py @@ -11,97 +11,126 @@ This module implements the network blueprint. """ import ipaddress import sqlalchemy as sa -from flask import (Blueprint, render_template, jsonify, session, - redirect, url_for, request, flash, current_app) +from flask import ( + Blueprint, + render_template, + jsonify, + session, + redirect, + url_for, + request, + flash, + current_app, +) from flask_login import login_required, current_user -from .forms import (HostForm, InterfaceForm, HostInterfaceForm, NetworkForm, - NetworkScopeForm, DomainForm, CreateVMForm) +from .forms import ( + HostForm, + InterfaceForm, + HostInterfaceForm, + NetworkForm, + NetworkScopeForm, + DomainForm, + CreateVMForm, +) from ..extensions import db from ..decorators import login_groups_accepted from .. 
import models, utils, helpers, tasks -bp = Blueprint('network', __name__) +bp = Blueprint("network", __name__) -@bp.route('/hosts') +@bp.route("/hosts") @login_required def list_hosts(): - return render_template('network/hosts.html') + return render_template("network/hosts.html") -@bp.route('/hosts/create', methods=('GET', 'POST')) -@login_groups_accepted('admin', 'create') +@bp.route("/hosts/create", methods=("GET", "POST")) +@login_groups_accepted("admin", "create") def create_host(): - kwargs = {'random_mac': True} + kwargs = {"random_mac": True} # Try to get the network_id from the session # to pre-fill the form with the same network try: - network_id = session['network_id'] + network_id = session["network_id"] except KeyError: pass else: - kwargs['network_id'] = network_id + kwargs["network_id"] = network_id form = HostInterfaceForm(request.form, **kwargs) # Remove the host_id field inherited from the InterfaceForm # It's not used in this form del form.host_id if form.validate_on_submit(): network_id = form.network_id.data - host = models.Host(name=form.name.data, - device_type_id=form.device_type_id.data, - description=form.description.data or None) + host = models.Host( + name=form.name.data, + device_type_id=form.device_type_id.data, + description=form.description.data or None, + ) # The total number of tags will always be quite small # It's more efficient to retrieve all of them in one query # and do the filtering here all_tags = models.Tag.query.all() tags = [tag for tag in all_tags if str(tag.id) in form.tags.data] - interface = models.Interface(name=form.interface_name.data, - ip=form.ip.data, - network_id=network_id, - tags=tags) - interface.cnames = [models.Cname(name=name) for name in form.cnames_string.data.split()] + interface = models.Interface( + name=form.interface_name.data, + ip=form.ip.data, + network_id=network_id, + tags=tags, + ) + interface.cnames = [ + models.Cname(name=name) for name in form.cnames_string.data.split() + ] helpers.associate_mac_to_interface(form.mac_address.data, interface) host.interfaces = [interface] - current_app.logger.debug(f'Trying to create: {host!r}') + current_app.logger.debug(f"Trying to create: {host!r}") db.session.add(host) try: db.session.commit() except sa.exc.IntegrityError as e: db.session.rollback() - current_app.logger.warning(f'{e}') - flash(f'{e}', 'error') + current_app.logger.warning(f"{e}") + flash(f"{e}", "error") else: - flash(f'Host {host} created!', 'success') + flash(f"Host {host} created!", "success") tasks.trigger_core_services_update() # Save network_id to the session to retrieve it after the redirect - session['network_id'] = network_id - return redirect(url_for('network.view_host', name=host.name)) - return render_template('network/create_host.html', form=form) + session["network_id"] = network_id + return redirect(url_for("network.view_host", name=host.name)) + return render_template("network/create_host.html", form=form) -@bp.route('/hosts/view/<name>', methods=('GET', 'POST')) +@bp.route("/hosts/view/<name>", methods=("GET", "POST")) @login_required def view_host(name): host = models.Host.query.filter_by(name=name).first_or_404() form = CreateVMForm() if host.is_ioc: - form.cores.choices = utils.get_choices(current_app.config['VIOC_CORES_CHOICES']) - form.memory.choices = utils.get_choices(current_app.config['VIOC_MEMORY_CHOICES']) + form.cores.choices = utils.get_choices(current_app.config["VIOC_CORES_CHOICES"]) + form.memory.choices = utils.get_choices( + current_app.config["VIOC_MEMORY_CHOICES"] + ) if 
form.validate_on_submit(): if not current_user.is_admin: - flash(f'Only admin users are allowed to create a VM!', 'info') - return redirect(url_for('network.view_host', name=name)) + flash(f"Only admin users are allowed to create a VM!", "info") + return redirect(url_for("network.view_host", name=name)) else: interface = host.interfaces[0] - task = tasks.trigger_vm_creation(name, interface, int(form.memory.data) * 1000, form.cores.data) - current_app.logger.info(f'Creation of {name} requested: task {task.id}') - flash(f'Creation of {name} requested! Refresh the page to update the status.', 'success') - return redirect(url_for('task.view_task', id_=task.id)) - return render_template('network/view_host.html', host=host, form=form) - - -@bp.route('/hosts/edit/<name>', methods=('GET', 'POST')) -@login_groups_accepted('admin', 'create') + task = tasks.trigger_vm_creation( + name, interface, int(form.memory.data) * 1000, form.cores.data + ) + current_app.logger.info(f"Creation of {name} requested: task {task.id}") + flash( + f"Creation of {name} requested! Refresh the page to update the status.", + "success", + ) + return redirect(url_for("task.view_task", id_=task.id)) + return render_template("network/view_host.html", host=host, form=form) + + +@bp.route("/hosts/edit/<name>", methods=("GET", "POST")) +@login_groups_accepted("admin", "create") def edit_host(name): host = models.Host.query.filter_by(name=name).first_or_404() form = HostForm(request.form, obj=host) @@ -109,68 +138,78 @@ def edit_host(name): host.name = form.name.data host.device_type_id = form.device_type_id.data host.description = form.description.data or None - current_app.logger.debug(f'Trying to update: {host!r}') + current_app.logger.debug(f"Trying to update: {host!r}") try: db.session.commit() except sa.exc.IntegrityError as e: db.session.rollback() - current_app.logger.warning(f'{e}') - flash(f'{e}', 'error') + current_app.logger.warning(f"{e}") + flash(f"{e}", "error") else: - flash(f'Host {host} updated!', 'success') + flash(f"Host {host} updated!", "success") tasks.trigger_core_services_update() - return redirect(url_for('network.view_host', name=host.name)) - return render_template('network/edit_host.html', form=form) + return redirect(url_for("network.view_host", name=host.name)) + return render_template("network/edit_host.html", form=form) -@bp.route('/interfaces/create/<hostname>', methods=('GET', 'POST')) -@login_groups_accepted('admin', 'create') +@bp.route("/interfaces/create/<hostname>", methods=("GET", "POST")) +@login_groups_accepted("admin", "create") def create_interface(hostname): host = models.Host.query.filter_by(name=hostname).first_or_404() - random_mac = host.device_type.name.startswith('Virtual') - form = InterfaceForm(request.form, host_id=host.id, interface_name=host.name, - random_mac=random_mac) + random_mac = host.device_type.name.startswith("Virtual") + form = InterfaceForm( + request.form, host_id=host.id, interface_name=host.name, random_mac=random_mac + ) if form.validate_on_submit(): # The total number of tags will always be quite small # It's more efficient to retrieve all of them in one query # and do the filtering here all_tags = models.Tag.query.all() tags = [tag for tag in all_tags if str(tag.id) in form.tags.data] - interface = models.Interface(host_id=host.id, - name=form.interface_name.data, - ip=form.ip.data, - network_id=form.network_id.data, - tags=tags) - interface.cnames = [models.Cname(name=name) for name in form.cnames_string.data.split()] + interface = models.Interface( + 
host_id=host.id, + name=form.interface_name.data, + ip=form.ip.data, + network_id=form.network_id.data, + tags=tags, + ) + interface.cnames = [ + models.Cname(name=name) for name in form.cnames_string.data.split() + ] helpers.associate_mac_to_interface(form.mac_address.data, interface) - current_app.logger.debug(f'Trying to create: {interface!r}') + current_app.logger.debug(f"Trying to create: {interface!r}") db.session.add(interface) try: db.session.commit() except sa.exc.IntegrityError as e: db.session.rollback() - current_app.logger.warning(f'{e}') - flash(f'{e}', 'error') + current_app.logger.warning(f"{e}") + flash(f"{e}", "error") else: - flash(f'Host {interface} created!', 'success') + flash(f"Host {interface} created!", "success") tasks.trigger_core_services_update() - return redirect(url_for('network.create_interface', hostname=hostname)) - return render_template('network/create_interface.html', form=form, hostname=hostname) + return redirect(url_for("network.create_interface", hostname=hostname)) + return render_template( + "network/create_interface.html", form=form, hostname=hostname + ) -@bp.route('/interfaces/edit/<name>', methods=('GET', 'POST')) -@login_groups_accepted('admin', 'create') +@bp.route("/interfaces/edit/<name>", methods=("GET", "POST")) +@login_groups_accepted("admin", "create") def edit_interface(name): interface = models.Interface.query.filter_by(name=name).first_or_404() - cnames_string = ' '.join([str(cname) for cname in interface.cnames]) + cnames_string = " ".join([str(cname) for cname in interface.cnames]) try: mac_address = interface.mac.address except AttributeError: - mac_address = '' - form = InterfaceForm(request.form, obj=interface, - interface_name=interface.name, - mac_address=mac_address, - cnames_string=cnames_string) + mac_address = "" + form = InterfaceForm( + request.form, + obj=interface, + interface_name=interface.name, + mac_address=mac_address, + cnames_string=cnames_string, + ) # Remove the random_mac field (not used when editing) del form.random_mac ips = [interface.ip] @@ -199,7 +238,7 @@ def edit_interface(name): new_cnames_string = form.cnames_string.data.split() for (index, cname) in enumerate(interface.cnames): if cname.name not in new_cnames_string: - current_app.logger.debug(f'Deleting cname: {cname}') + current_app.logger.debug(f"Deleting cname: {cname}") # Removing the cname from interface.cnames list will # delete it from the database due to the cascade # delete-orphan option defined on the model @@ -208,152 +247,164 @@ def edit_interface(name): for name in new_cnames_string: if name not in cnames_string: cname = models.Cname(name=name) - current_app.logger.debug(f'Creating cname: {cname}') + current_app.logger.debug(f"Creating cname: {cname}") interface.cnames.append(cname) all_tags = models.Tag.query.all() tags = [tag for tag in all_tags if str(tag.id) in form.tags.data] interface.tags = tags - current_app.logger.debug(f'Trying to update: {interface!r}') + current_app.logger.debug(f"Trying to update: {interface!r}") try: db.session.commit() except sa.exc.IntegrityError as e: db.session.rollback() - current_app.logger.warning(f'{e}') - flash(f'{e}', 'error') + current_app.logger.warning(f"{e}") + flash(f"{e}", "error") else: - flash(f'Interface {interface} updated!', 'success') + flash(f"Interface {interface} updated!", "success") tasks.trigger_core_services_update() - return redirect(url_for('network.view_host', name=interface.host.name)) - return render_template('network/edit_interface.html', form=form, 
hostname=interface.host.name) + return redirect(url_for("network.view_host", name=interface.host.name)) + return render_template( + "network/edit_interface.html", form=form, hostname=interface.host.name + ) -@bp.route('/interfaces/delete', methods=['POST']) -@login_groups_accepted('admin', 'create') +@bp.route("/interfaces/delete", methods=["POST"]) +@login_groups_accepted("admin", "create") def delete_interface(): - interface = models.Interface.query.get_or_404(request.form['interface_id']) + interface = models.Interface.query.get_or_404(request.form["interface_id"]) hostname = interface.host.name # Deleting the interface will also delete all # associated cnames due to the cascade delete option # defined on the model db.session.delete(interface) db.session.commit() - flash(f'Interface {interface.name} has been deleted', 'success') - return redirect(url_for('network.view_host', name=hostname)) + flash(f"Interface {interface.name} has been deleted", "success") + return redirect(url_for("network.view_host", name=hostname)) -@bp.route('/domains') +@bp.route("/domains") @login_required def list_domains(): - return render_template('network/domains.html') + return render_template("network/domains.html") -@bp.route('/domains/create', methods=('GET', 'POST')) -@login_groups_accepted('admin') +@bp.route("/domains/create", methods=("GET", "POST")) +@login_groups_accepted("admin") def create_domain(): form = DomainForm() if form.validate_on_submit(): domain = models.Domain(name=form.name.data) - current_app.logger.debug(f'Trying to create: {domain!r}') + current_app.logger.debug(f"Trying to create: {domain!r}") db.session.add(domain) try: db.session.commit() except sa.exc.IntegrityError as e: db.session.rollback() - current_app.logger.warning(f'{e}') - flash(f'{e}', 'error') + current_app.logger.warning(f"{e}") + flash(f"{e}", "error") else: - flash(f'Domain {domain} created!', 'success') - return redirect(url_for('network.create_domain')) - return render_template('network/create_domain.html', form=form) + flash(f"Domain {domain} created!", "success") + return redirect(url_for("network.create_domain")) + return render_template("network/create_domain.html", form=form) -@bp.route('/scopes') +@bp.route("/scopes") @login_required def list_scopes(): - return render_template('network/scopes.html') + return render_template("network/scopes.html") -@bp.route('/scopes/create', methods=('GET', 'POST')) -@login_groups_accepted('admin') +@bp.route("/scopes/create", methods=("GET", "POST")) +@login_groups_accepted("admin") def create_scope(): form = NetworkScopeForm() if form.validate_on_submit(): - scope = models.NetworkScope(name=form.name.data, - description=form.description.data or None, - first_vlan=form.first_vlan.data, - last_vlan=form.last_vlan.data, - supernet=form.supernet.data, - domain_id=form.domain_id.data) - current_app.logger.debug(f'Trying to create: {scope!r}') + scope = models.NetworkScope( + name=form.name.data, + description=form.description.data or None, + first_vlan=form.first_vlan.data, + last_vlan=form.last_vlan.data, + supernet=form.supernet.data, + domain_id=form.domain_id.data, + ) + current_app.logger.debug(f"Trying to create: {scope!r}") db.session.add(scope) try: db.session.commit() except sa.exc.IntegrityError as e: db.session.rollback() - current_app.logger.warning(f'{e}') - flash(f'{e}', 'error') + current_app.logger.warning(f"{e}") + flash(f"{e}", "error") else: - flash(f'Network Scope {scope} created!', 'success') - return redirect(url_for('network.create_scope')) - return 
render_template('network/create_scope.html', form=form) + flash(f"Network Scope {scope} created!", "success") + return redirect(url_for("network.create_scope")) + return render_template("network/create_scope.html", form=form) -@bp.route('/_retrieve_hosts') +@bp.route("/_retrieve_hosts") @login_required def retrieve_hosts(): - data = [(host.name, - str(host.device_type), - host.description, - interface.name, - interface.ip, - str(interface.network)) - for host in models.Host.query.all() - for interface in host.interfaces] + data = [ + ( + host.name, + str(host.device_type), + host.description, + interface.name, + interface.ip, + str(interface.network), + ) + for host in models.Host.query.all() + for interface in host.interfaces + ] return jsonify(data=data) -@bp.route('/_retrieve_first_available_ip/<int:network_id>') +@bp.route("/_retrieve_first_available_ip/<int:network_id>") @login_required def retrieve_first_available_ip(network_id): try: network = models.Network.query.get(network_id) except sa.exc.DataError: - current_app.logger.warning(f'Invalid network_id: {network_id}') - data = '' + current_app.logger.warning(f"Invalid network_id: {network_id}") + data = "" else: data = str(network.available_ips()[0]) return jsonify(data=data) -@bp.route('/networks') +@bp.route("/networks") @login_required def list_networks(): - return render_template('network/networks.html') + return render_template("network/networks.html") -@bp.route('/_retrieve_networks') +@bp.route("/_retrieve_networks") @login_required def retrieve_networks(): - data = [(str(network.scope), - network.vlan_name, - network.vlan_id, - network.description, - network.address, - network.first_ip, - network.last_ip, - str(network.domain), - network.admin_only) - for network in models.Network.query.all()] + data = [ + ( + str(network.scope), + network.vlan_name, + network.vlan_id, + network.description, + network.address, + network.first_ip, + network.last_ip, + str(network.domain), + network.admin_only, + ) + for network in models.Network.query.all() + ] return jsonify(data=data) -@bp.route('/networks/create', methods=('GET', 'POST')) -@login_groups_accepted('admin') +@bp.route("/networks/create", methods=("GET", "POST")) +@login_groups_accepted("admin") def create_network(): # Try to get the scope_id from the session # to pre-fill the form with the same network scope try: - scope_id = session['scope_id'] + scope_id = session["scope_id"] except KeyError: # No need to pass request.form when no extra keywords are given form = NetworkForm() @@ -361,80 +412,87 @@ def create_network(): form = NetworkForm(request.form, scope_id=scope_id) if form.validate_on_submit(): scope_id = form.scope_id.data - network = models.Network(scope_id=scope_id, - vlan_name=form.vlan_name.data, - vlan_id=form.vlan_id.data, - description=form.description.data or None, - address=form.address.data, - first_ip=form.first_ip.data, - last_ip=form.last_ip.data, - domain_id=form.domain_id.data, - admin_only=form.admin_only.data) - current_app.logger.debug(f'Trying to create: {network!r}') + network = models.Network( + scope_id=scope_id, + vlan_name=form.vlan_name.data, + vlan_id=form.vlan_id.data, + description=form.description.data or None, + address=form.address.data, + first_ip=form.first_ip.data, + last_ip=form.last_ip.data, + domain_id=form.domain_id.data, + admin_only=form.admin_only.data, + ) + current_app.logger.debug(f"Trying to create: {network!r}") db.session.add(network) try: db.session.commit() except sa.exc.IntegrityError as e: db.session.rollback() - 
current_app.logger.warning(f'{e}') - flash(f'{e}', 'error') + current_app.logger.warning(f"{e}") + flash(f"{e}", "error") else: - flash(f'Network {network} created!', 'success') + flash(f"Network {network} created!", "success") # Save scope_id to the session to retrieve it after the redirect - session['scope_id'] = scope_id - return redirect(url_for('network.create_network')) - return render_template('network/create_network.html', form=form) + session["scope_id"] = scope_id + return redirect(url_for("network.create_network")) + return render_template("network/create_network.html", form=form) -@bp.route('/_retrieve_scope_defaults/<int:scope_id>') +@bp.route("/_retrieve_scope_defaults/<int:scope_id>") @login_required def retrieve_scope_defaults(scope_id): try: scope = models.NetworkScope.query.get(scope_id) except sa.exc.DataError: - current_app.logger.warning(f'Invalid scope_id: {scope_id}') - data = {'vlans': [], 'prefixes': [], - 'selected_vlan': '', 'selected_prefix': '', - 'domain_id': ''} + current_app.logger.warning(f"Invalid scope_id: {scope_id}") + data = { + "vlans": [], + "prefixes": [], + "selected_vlan": "", + "selected_prefix": "", + "domain_id": "", + } else: vlans = [vlan_id for vlan_id in scope.available_vlans()] prefixes = scope.prefix_range() - default_prefix = current_app.config['NETWORK_DEFAULT_PREFIX'] + default_prefix = current_app.config["NETWORK_DEFAULT_PREFIX"] if default_prefix in prefixes: selected_prefix = default_prefix else: selected_prefix = prefixes[0] - data = {'vlans': vlans, - 'prefixes': prefixes, - 'selected_vlan': vlans[0], - 'selected_prefix': selected_prefix, - 'domain_id': scope.domain_id} + data = { + "vlans": vlans, + "prefixes": prefixes, + "selected_vlan": vlans[0], + "selected_prefix": selected_prefix, + "domain_id": scope.domain_id, + } return jsonify(data=data) -@bp.route('/_retrieve_subnets/<int:scope_id>/<int:prefix>') +@bp.route("/_retrieve_subnets/<int:scope_id>/<int:prefix>") @login_required def retrieve_subnets(scope_id, prefix): try: scope = models.NetworkScope.query.get(scope_id) except sa.exc.DataError: - current_app.logger.warning(f'Invalid scope_id: {scope_id}') - data = {'subnets': [], 'selected_subnet': ''} + current_app.logger.warning(f"Invalid scope_id: {scope_id}") + data = {"subnets": [], "selected_subnet": ""} else: subnets = [subnet for subnet in scope.available_subnets(int(prefix))] - data = {'subnets': subnets, - 'selected_subnet': subnets[0]} + data = {"subnets": subnets, "selected_subnet": subnets[0]} return jsonify(data=data) -@bp.route('/_retrieve_ips/<subnet>/<int:prefix>') +@bp.route("/_retrieve_ips/<subnet>/<int:prefix>") @login_required def retrieve_ips(subnet, prefix): try: - address = ipaddress.ip_network(f'{subnet}/{prefix}') + address = ipaddress.ip_network(f"{subnet}/{prefix}") except ValueError: - current_app.logger.warning(f'Invalid address: {subnet}/{prefix}') - data = {'ips': [], 'first': '', 'last': ''} + current_app.logger.warning(f"Invalid address: {subnet}/{prefix}") + data = {"ips": [], "first": "", "last": ""} else: hosts = [str(ip) for ip in address.hosts()] if len(hosts) > 17: @@ -443,35 +501,36 @@ def retrieve_ips(subnet, prefix): else: first = hosts[0] last = hosts[-1] - data = {'ips': hosts, - 'selected_first': first, - 'selected_last': last} + data = {"ips": hosts, "selected_first": first, "selected_last": last} return jsonify(data=data) -@bp.route('/_retrieve_scopes') +@bp.route("/_retrieve_scopes") @login_required def retrieve_scopes(): - data = [(scope.name, - scope.description, - 
scope.first_vlan, - scope.last_vlan, - scope.supernet, - str(scope.domain)) - for scope in models.NetworkScope.query.all()] + data = [ + ( + scope.name, + scope.description, + scope.first_vlan, + scope.last_vlan, + scope.supernet, + str(scope.domain), + ) + for scope in models.NetworkScope.query.all() + ] return jsonify(data=data) -@bp.route('/_retrieve_domains') +@bp.route("/_retrieve_domains") @login_required def retrieve_domains(): - data = [(domain.name,) - for domain in models.Domain.query.all()] + data = [(domain.name,) for domain in models.Domain.query.all()] return jsonify(data=data) -@bp.route('/_generate_random_mac') +@bp.route("/_generate_random_mac") @login_required def generate_random_mac(): - data = {'mac': utils.random_mac()} + data = {"mac": utils.random_mac()} return jsonify(data=data) diff --git a/app/plugins.py b/app/plugins.py index fb09e7e..d23ae9e 100644 --- a/app/plugins.py +++ b/app/plugins.py @@ -25,14 +25,12 @@ from . import utils class FlaskUserPlugin(Plugin): - def __init__(self, current_user_id_factory=None): self.current_user_id_factory = ( - utils.fetch_current_user_id if current_user_id_factory is None + utils.fetch_current_user_id + if current_user_id_factory is None else current_user_id_factory ) def transaction_args(self, uow, session): - return { - 'user_id': self.current_user_id_factory(), - } + return {"user_id": self.current_user_id_factory()} diff --git a/app/settings.py b/app/settings.py index 5ec4621..455b247 100644 --- a/app/settings.py +++ b/app/settings.py @@ -12,73 +12,75 @@ This module implements the app default settings. import os from datetime import timedelta -SQLALCHEMY_DATABASE_URI = 'postgresql://ics:icspwd@postgres/csentry_db' +SQLALCHEMY_DATABASE_URI = "postgresql://ics:icspwd@postgres/csentry_db" SQLALCHEMY_TRACK_MODIFICATIONS = False BOOTSTRAP_SERVE_LOCAL = True -SECRET_KEY = (os.environ.get('SECRET_KEY') or - b'\x0d\x11{\xd3\x13$\xeeel\xa6\xfb\x1d~\xfd\xb2\x9d\x16\x00\xfb5\xd64\xd4\xe0') -MAIL_SERVER = 'mail.esss.lu.se' +SECRET_KEY = ( + os.environ.get("SECRET_KEY") + or b"\x0d\x11{\xd3\x13$\xeeel\xa6\xfb\x1d~\xfd\xb2\x9d\x16\x00\xfb5\xd64\xd4\xe0" +) +MAIL_SERVER = "mail.esss.lu.se" MAIL_CREDENTIALS = None -ADMIN_EMAILS = ['admin@example.com'] -EMAIL_SENDER = 'noreply@esss.se' +ADMIN_EMAILS = ["admin@example.com"] +EMAIL_SENDER = "noreply@esss.se" JWT_BLACKLIST_ENABLED = True -JWT_BLACKLIST_TOKEN_CHECKS = ['access', 'refresh'] +JWT_BLACKLIST_TOKEN_CHECKS = ["access", "refresh"] JWT_ACCESS_TOKEN_EXPIRES = timedelta(hours=12) -SESSION_TYPE = 'redis' -SESSION_REDIS_URL = 'redis://redis:6379/0' -CACHE_TYPE = 'redis' -CACHE_REDIS_URL = 'redis://redis:6379/1' -REDIS_URL = 'redis://redis:6379/2' -QUEUES = ['default'] +SESSION_TYPE = "redis" +SESSION_REDIS_URL = "redis://redis:6379/0" +CACHE_TYPE = "redis" +CACHE_REDIS_URL = "redis://redis:6379/1" +REDIS_URL = "redis://redis:6379/2" +QUEUES = ["default"] -LDAP_HOST = 'esss.lu.se' -LDAP_BASE_DN = 'DC=esss,DC=lu,DC=se' -LDAP_USER_DN = 'OU=ESS Users' -LDAP_GROUP_DN = '' -LDAP_BIND_USER_DN = 'ldapuser' -LDAP_BIND_USER_PASSWORD = 'secret' -LDAP_USER_RDN_ATTR = 'cn' -LDAP_USER_LOGIN_ATTR = 'sAMAccountName' +LDAP_HOST = "esss.lu.se" +LDAP_BASE_DN = "DC=esss,DC=lu,DC=se" +LDAP_USER_DN = "OU=ESS Users" +LDAP_GROUP_DN = "" +LDAP_BIND_USER_DN = "ldapuser" +LDAP_BIND_USER_PASSWORD = "secret" +LDAP_USER_RDN_ATTR = "cn" +LDAP_USER_LOGIN_ATTR = "sAMAccountName" LDAP_ALWAYS_SEARCH_BIND = True -LDAP_USER_OBJECT_FILTER = '(samAccountType=805306368)' -LDAP_GROUP_OBJECT_FILTER = '' -LDAP_USER_SEARCH_SCOPE = 
'SUBTREE' -LDAP_GROUP_SEARCH_SCOPE = 'SUBTREE' -LDAP_GROUP_MEMBERS_ATTR = 'member' -LDAP_GET_USER_ATTRIBUTES = ['cn', 'sAMAccountName', 'mail'] -LDAP_GET_GROUP_ATTRIBUTES = ['cn'] +LDAP_USER_OBJECT_FILTER = "(samAccountType=805306368)" +LDAP_GROUP_OBJECT_FILTER = "" +LDAP_USER_SEARCH_SCOPE = "SUBTREE" +LDAP_GROUP_SEARCH_SCOPE = "SUBTREE" +LDAP_GROUP_MEMBERS_ATTR = "member" +LDAP_GET_USER_ATTRIBUTES = ["cn", "sAMAccountName", "mail"] +LDAP_GET_GROUP_ATTRIBUTES = ["cn"] CSENTRY_LDAP_GROUPS = { - 'admin': ['ICS Control System Infrastructure group'], - 'create': ['ICS Employees', 'ICS Consultants'], + "admin": ["ICS Control System Infrastructure group"], + "create": ["ICS Employees", "ICS Consultants"], } NETWORK_DEFAULT_PREFIX = 24 # ICS Ids starting with this prefix are considered temporary and can be changed # (waiting for a real label to be assigned) # WARNING: This is defined here as a global settings but should not be changed! -TEMPORARY_ICS_ID = 'ZZ' +TEMPORARY_ICS_ID = "ZZ" # CSENTRY MAC organizationally unique identifier # This is a locally administered address -MAC_OUI = '02:42:42' +MAC_OUI = "02:42:42" -DOCUMENTATION_URL = 'http://ics-infrastructure.pages.esss.lu.se/csentry/index.html' +DOCUMENTATION_URL = "http://ics-infrastructure.pages.esss.lu.se/csentry/index.html" CSENTRY_STAGING = False -AWX_URL = 'https://torn.tn.esss.lu.se' +AWX_URL = "https://torn.tn.esss.lu.se" # AWX job templates -AWX_CORE_SERVICES_UPDATE = 'ics-ans-core @ DHCP test' -AWX_CREATE_VM = 'ics-ans-deploy-proxmox-vm' -AWX_CREATE_VIOC = 'ics-ans-deploy-vioc' +AWX_CORE_SERVICES_UPDATE = "ics-ans-core @ DHCP test" +AWX_CREATE_VM = "ics-ans-deploy-proxmox-vm" +AWX_CREATE_VIOC = "ics-ans-deploy-vioc" AWX_JOB_ENABLED = False AWX_VM_CREATION_ENABLED = False VM_CORES_CHOICES = [1, 2, 4, 6, 8, 24] VM_MEMORY_CHOICES = [2, 4, 8, 16, 32, 128] -VM_DEFAULT_DNS = '172.16.6.21' +VM_DEFAULT_DNS = "172.16.6.21" VIOC_CORES_CHOICES = [1, 3, 6] VIOC_MEMORY_CHOICES = [2, 4, 8] diff --git a/app/task/views.py b/app/task/views.py index 4642213..e660e38 100644 --- a/app/task/views.py +++ b/app/task/views.py @@ -9,30 +9,29 @@ This module implements the task blueprint. :license: BSD 2-Clause, see LICENSE for more details. """ -from flask import (Blueprint, render_template, jsonify, - request) +from flask import Blueprint, render_template, jsonify, request from flask_login import login_required, current_user from .. import models -bp = Blueprint('task', __name__) +bp = Blueprint("task", __name__) -@bp.route('/tasks') +@bp.route("/tasks") @login_required def list_tasks(): - return render_template('task/tasks.html') + return render_template("task/tasks.html") -@bp.route('/tasks/view/<id_>') +@bp.route("/tasks/view/<id_>") @login_required def view_task(id_): task = models.Task.query.get_or_404(id_) - return render_template('task/view_task.html', task=task) + return render_template("task/view_task.html", task=task) -@bp.route('/_retrieve_tasks') +@bp.route("/_retrieve_tasks") @login_required def retrieve_tasks(): - all = request.args.get('all', 'false') == 'true' + all = request.args.get("all", "false") == "true" data = [task.to_dict() for task in current_user.get_tasks(all=all)] return jsonify(data=data) diff --git a/app/tasks.py b/app/tasks.py index 7e3368a..3c8c69e 100644 --- a/app/tasks.py +++ b/app/tasks.py @@ -34,7 +34,8 @@ class TaskWorker(Worker): if task.awx_job_id is None: # No AWX job was triggered. An exception occured before. Save it. 
task.exception = self._get_safe_exception_string( - traceback.format_exception(*exc_info)) + traceback.format_exception(*exc_info) + ) db.session.commit() def update_task_attributes(self, job, attributes): @@ -47,10 +48,10 @@ class TaskWorker(Worker): if task is not None: break else: - self.log.warning('task not found...') + self.log.warning("task not found...") time.sleep(1) else: - self.log.error(f'Task {job.id} not found! Task attribute not updated!') + self.log.error(f"Task {job.id} not found! Task attribute not updated!") return for name, value in attributes.items(): setattr(task, name, value) @@ -60,56 +61,56 @@ class TaskWorker(Worker): # This could be achieved by passing a custom exception handler # when initializing the worker. As we already subclass it, it's # easier to override the default handler in case of failure - self.update_task_attributes(job, { - 'ended_at': job.ended_at, - 'status': models.JobStatus.FAILED, - }) + self.update_task_attributes( + job, {"ended_at": job.ended_at, "status": models.JobStatus.FAILED} + ) self.save_exception(job, *exc_info) super().move_to_failed_queue(job, *exc_info) def handle_job_success(self, job, queue, started_job_registry): - self.update_task_attributes(job, { - 'ended_at': job.ended_at, - 'status': models.JobStatus.FINISHED, - }) + self.update_task_attributes( + job, {"ended_at": job.ended_at, "status": models.JobStatus.FINISHED} + ) super().handle_job_success(job, queue, started_job_registry) def prepare_job_execution(self, job): - self.update_task_attributes(job, { - 'status': models.JobStatus.STARTED, - }) + self.update_task_attributes(job, {"status": models.JobStatus.STARTED}) super().prepare_job_execution(job) def trigger_vm_creation(name, interface, memory, cores): """Trigger a job to create a virtual machine or virtual IOC""" extra_vars = [ - f'vmname={name}', - f'memory={memory}', - f'cores={cores}', - f'vcpus={cores}', - f'vlan_name={interface.network.vlan_name}', - f'vlan_id={interface.network.vlan_id}', - f'mac={interface.mac.address}', + f"vmname={name}", + f"memory={memory}", + f"cores={cores}", + f"vcpus={cores}", + f"vlan_name={interface.network.vlan_name}", + f"vlan_id={interface.network.vlan_id}", + f"mac={interface.mac.address}", ] if interface.is_ioc: - task_name = 'trigger_vioc_creation' - job_template = current_app.config['AWX_CREATE_VIOC'] + task_name = "trigger_vioc_creation" + job_template = current_app.config["AWX_CREATE_VIOC"] else: - task_name = 'trigger_vm_creation' - job_template = current_app.config['AWX_CREATE_VM'] - extra_vars.extend([ - f'ip_address={interface.ip}', - f'domain={interface.network.domain.name}', - f'dns={current_app.config["VM_DEFAULT_DNS"]}', - f'netmask={interface.network.netmask}', - f'gateway={interface.network.gateway}', - ]) - current_app.logger.info(f'Launch new job to create the {name} VM: {job_template} with {extra_vars}') + task_name = "trigger_vm_creation" + job_template = current_app.config["AWX_CREATE_VM"] + extra_vars.extend( + [ + f"ip_address={interface.ip}", + f"domain={interface.network.domain.name}", + f'dns={current_app.config["VM_DEFAULT_DNS"]}', + f"netmask={interface.network.netmask}", + f"gateway={interface.network.gateway}", + ] + ) + current_app.logger.info( + f"Launch new job to create the {name} VM: {job_template} with {extra_vars}" + ) user = utils.cse_current_user() task = user.launch_task( task_name, - func='launch_job_template', + func="launch_job_template", job_template=job_template, extra_vars=extra_vars, timeout=500, @@ -125,43 +126,41 @@ def 
trigger_core_services_update(): We can have one running job + one in queue to apply the latest changes. Make sure that we don't have more than one in queue. """ - job_template = current_app.config['AWX_CORE_SERVICES_UPDATE'] + job_template = current_app.config["AWX_CORE_SERVICES_UPDATE"] user = utils.cse_current_user() - if user.is_task_waiting('trigger_core_services_update'): - current_app.logger.info('Already one "trigger_core_services_update" task waiting. No need to trigger a new one.') + if user.is_task_waiting("trigger_core_services_update"): + current_app.logger.info( + 'Already one "trigger_core_services_update" task waiting. No need to trigger a new one.' + ) return None - kwargs = { - 'func': 'launch_job_template', - 'job_template': job_template - } - started = user.get_task_started('trigger_core_services_update') + kwargs = {"func": "launch_job_template", "job_template": job_template} + started = user.get_task_started("trigger_core_services_update") if started: # There is already one running task. Trigger a new one when it's done. - kwargs['depends_on'] = started.id - current_app.logger.info(f'Launch new job to update core services: {job_template}') - task = user.launch_task( - 'trigger_core_services_update', - **kwargs, - ) + kwargs["depends_on"] = started.id + current_app.logger.info(f"Launch new job to update core services: {job_template}") + task = user.launch_task("trigger_core_services_update", **kwargs) return task def launch_job_template(job_template, **kwargs): rq_job = get_current_job() - if (job_template in (current_app.config['AWX_CREATE_VIOC'], current_app.config['AWX_CREATE_VM']) and - not current_app.config.get('AWX_VM_CREATION_ENABLED', False)): - current_app.logger.info('AWX VM creation is disabled. Not sending any request.') - return 'AWX VM creation not triggered' - if not current_app.config.get('AWX_JOB_ENABLED', False): - current_app.logger.info('AWX job is disabled. Not sending any request.') - return 'AWX job not triggered' + if job_template in ( + current_app.config["AWX_CREATE_VIOC"], + current_app.config["AWX_CREATE_VM"], + ) and not current_app.config.get("AWX_VM_CREATION_ENABLED", False): + current_app.logger.info("AWX VM creation is disabled. Not sending any request.") + return "AWX VM creation not triggered" + if not current_app.config.get("AWX_JOB_ENABLED", False): + current_app.logger.info("AWX job is disabled. Not sending any request.") + return "AWX job not triggered" # Launch the AWX job - resource = tower_cli.get_resource('job') + resource = tower_cli.get_resource("job") result = resource.launch(job_template=job_template, **kwargs) # Save the AWX job id in the task task = models.Task.query.get(rq_job.id) - task.awx_job_id = result['id'] + task.awx_job_id = result["id"] db.session.commit() # Monitor the job until done - result = resource.monitor(pk=result['id']) + result = resource.monitor(pk=result["id"]) return result diff --git a/app/tokens.py b/app/tokens.py index d646622..282cd8f 100644 --- a/app/tokens.py +++ b/app/tokens.py @@ -34,7 +34,7 @@ def is_token_in_blacklist(decoded_token): All created tokens are added to the database. If a token is not found in the database, it is considered blacklisted / revoked. 
""" - jti = decoded_token['jti'] + jti = decoded_token["jti"] try: models.Token.query.filter_by(jti=jti).one() except sa.orm.exc.NoResultFound: @@ -51,14 +51,14 @@ def generate_access_token(identity, fresh=False, expires_delta=None, description def save_token(encoded_token, description=None): """Add a new token to the database""" - identity_claim = current_app.config['JWT_IDENTITY_CLAIM'] + identity_claim = current_app.config["JWT_IDENTITY_CLAIM"] decoded_token = decode_token(encoded_token) - jti = decoded_token['jti'] - token_type = decoded_token['type'] + jti = decoded_token["jti"] + token_type = decoded_token["type"] user_id = int(decoded_token[identity_claim]) - iat = datetime.fromtimestamp(decoded_token['iat']) + iat = datetime.fromtimestamp(decoded_token["iat"]) try: - expires = datetime.fromtimestamp(decoded_token['exp']) + expires = datetime.fromtimestamp(decoded_token["exp"]) except KeyError: expires = None db_token = models.Token( @@ -81,16 +81,20 @@ def revoke_token(token_id, user_id): """ token = models.Token.query.get(token_id) if token is None: - raise utils.CSEntryError(f'Could not find the token {token_id}', status_code=404) + raise utils.CSEntryError( + f"Could not find the token {token_id}", status_code=404 + ) if token.user_id != user_id: - raise utils.CSEntryError(f"Token {token_id} doesn't belong to user {user_id}", status_code=401) + raise utils.CSEntryError( + f"Token {token_id} doesn't belong to user {user_id}", status_code=401 + ) db.session.delete(token) db.session.commit() def prune_database(): """Delete tokens that have expired from the database""" - current_app.logger.info('Delete expired tokens') + current_app.logger.info("Delete expired tokens") now = datetime.now() expired = models.Token.query.filter(models.Token.expires < now).all() for token in expired: diff --git a/app/user/forms.py b/app/user/forms.py index 38e5047..95ca383 100644 --- a/app/user/forms.py +++ b/app/user/forms.py @@ -14,4 +14,4 @@ from wtforms import StringField, validators class TokenForm(FlaskForm): - description = StringField('description', validators=[validators.DataRequired()]) + description = StringField("description", validators=[validators.DataRequired()]) diff --git a/app/user/views.py b/app/user/views.py index 9440dfc..c85f5b0 100644 --- a/app/user/views.py +++ b/app/user/views.py @@ -9,8 +9,16 @@ This module implements the user blueprint. :license: BSD 2-Clause, see LICENSE for more details. """ -from flask import (Blueprint, render_template, request, redirect, url_for, - flash, current_app, session) +from flask import ( + Blueprint, + render_template, + request, + redirect, + url_for, + flash, + current_app, + session, +) from flask_login import login_user, logout_user, login_required, current_user from flask_ldap3_login.forms import LDAPLoginForm from .forms import TokenForm @@ -18,28 +26,28 @@ from ..extensions import cache, db from ..models import load_user from .. 
import tokens, utils -bp = Blueprint('user', __name__) +bp = Blueprint("user", __name__) -@bp.route('/login', methods=['GET', 'POST']) +@bp.route("/login", methods=["GET", "POST"]) def login(): form = LDAPLoginForm(request.form) if form.validate_on_submit(): login_user(form.user, remember=form.remember_me.data) - return redirect(request.args.get('next') or url_for('main.index')) - return render_template('user/login.html', form=form) + return redirect(request.args.get("next") or url_for("main.index")) + return render_template("user/login.html", form=form) -@bp.route('/logout') +@bp.route("/logout") @login_required def logout(): # Don't forget to remove the user from the cache cache.delete_memoized(load_user, str(current_user.id)) logout_user() - return redirect(url_for('user.login')) + return redirect(url_for("user.login")) -@bp.route('/profile', methods=['GET', 'POST']) +@bp.route("/profile", methods=["GET", "POST"]) @login_required def profile(): if current_user not in db.session: @@ -47,32 +55,39 @@ def profile(): # Add it to access the user.tokens relationship in the template db.session.add(current_user) # Try to get the generated token from the session - token = session.pop('generated_token', None) + token = session.pop("generated_token", None) form = TokenForm(request.form) if form.validate_on_submit(): - token = tokens.generate_access_token(identity=current_user.id, - expires_delta=False, - description=form.description.data) + token = tokens.generate_access_token( + identity=current_user.id, + expires_delta=False, + description=form.description.data, + ) # Save token to the session to retrieve it after the redirect - session['generated_token'] = token - flash('Make sure to copy your new personal access token now. You won’t be able to see it again!', 'success') - return redirect(url_for('user.profile')) - return render_template('user/profile.html', - form=form, - user=current_user, - generated_token=token) + session["generated_token"] = token + flash( + "Make sure to copy your new personal access token now. You won’t be able to see it again!", + "success", + ) + return redirect(url_for("user.profile")) + return render_template( + "user/profile.html", form=form, user=current_user, generated_token=token + ) -@bp.route('/tokens/revoke', methods=['POST']) +@bp.route("/tokens/revoke", methods=["POST"]) @login_required def revoke_token(): - token_id = request.form['token_id'] - jti = request.form['jti'] + token_id = request.form["token_id"] + jti = request.form["jti"] try: tokens.revoke_token(token_id, current_user.id) except utils.CSEntryError as e: current_app.logger.warning(e) - flash(f'Could not revoke the token {jti}. Please contact an administrator.', 'error') + flash( + f"Could not revoke the token {jti}. 
Please contact an administrator.", + "error", + ) else: - flash(f'Token {jti} has been revoked', 'success') - return redirect(url_for('user.profile')) + flash(f"Token {jti} has been revoked", "success") + return redirect(url_for("user.profile")) diff --git a/app/utils.py b/app/utils.py index bcff931..4493ae5 100644 --- a/app/utils.py +++ b/app/utils.py @@ -47,6 +47,7 @@ class CSEntryError(Exception): Exception used to pass useful information to the client side (API or AJAX) """ + status_code = 400 def __init__(self, message, status_code=None, payload=None): @@ -58,14 +59,14 @@ class CSEntryError(Exception): def to_dict(self): rv = dict(self.payload or ()) - rv['message'] = self.message + rv["message"] = self.message return rv def __str__(self): return str(self.to_dict()) -def image_to_base64(img, format='PNG'): +def image_to_base64(img, format="PNG"): """Convert a Pillow image to a base64 string :param img: Pillow image @@ -74,7 +75,7 @@ def image_to_base64(img, format='PNG'): """ buf = io.BytesIO() img.save(buf, format=format) - return base64.b64encode(buf.getvalue()).decode('ascii') + return base64.b64encode(buf.getvalue()).decode("ascii") def format_field(field): @@ -82,11 +83,11 @@ def format_field(field): if field is None: return None if isinstance(field, datetime.datetime): - return field.strftime('%Y-%m-%d %H:%M') + return field.strftime("%Y-%m-%d %H:%M") return str(field) -def convert_to_model(item, model, filter='name'): +def convert_to_model(item, model, filter="name"): """Convert item to an instance of model Allow to convert a string to an instance of model @@ -100,7 +101,7 @@ def convert_to_model(item, model, filter='name'): kwarg = {filter: item} instance = model.query.filter_by(**kwarg).first() if instance is None: - raise CSEntryError(f'{item} is not a valid {model.__name__.lower()}') + raise CSEntryError(f"{item} is not a valid {model.__name__.lower()}") return instance return item @@ -127,22 +128,24 @@ def get_choices(iterable, allow_blank=False, allow_null=False): """Return a list of (value, label)""" choices = [] if allow_blank: - choices = [('', '')] + choices = [("", "")] if allow_null: - choices.append(('null', 'not set')) + choices.append(("null", "not set")) choices.extend([(val, val) for val in iterable]) return choices -def get_model_choices(model, allow_none=False, attr='name', query=None): +def get_model_choices(model, allow_none=False, attr="name", query=None): """Return a list of (value, label)""" choices = [] if allow_none: - choices = [(None, '')] + choices = [(None, "")] if query is None: query = model.query query = query.order_by(getattr(model, attr)) - choices.extend([(str(instance.id), getattr(instance, attr)) for instance in query.all()]) + choices.extend( + [(str(instance.id), getattr(instance, attr)) for instance in query.all()] + ) return choices @@ -157,7 +160,7 @@ def get_query(query, **kwargs): try: query = query.filter_by(**kwargs) except (sa.exc.InvalidRequestError, AttributeError) as e: - raise CSEntryError('Invalid query arguments', status_code=422) + raise CSEntryError("Invalid query arguments", status_code=422) return query @@ -176,7 +179,7 @@ def lowercase_field(value): # To pass wtforms validation, the value returned must be part of choices def coerce_to_str_or_none(value): """Convert '', None and 'None' to None""" - if value in ('', 'None') or value is None: + if value in ("", "None") or value is None: return None return str(value) @@ -194,11 +197,13 @@ def parse_to_utc(string): def random_mac(): """Return a random MAC address""" - octets 
= [random.randint(0x00, 0xFF), - random.randint(0x00, 0xFF), - random.randint(0x00, 0xFF)] - octets = [f'{nb:02x}' for nb in octets] - return ':'.join((current_app.config['MAC_OUI'], *octets)) + octets = [ + random.randint(0x00, 0xFF), + random.randint(0x00, 0xFF), + random.randint(0x00, 0xFF), + ] + octets = [f"{nb:02x}" for nb in octets] + return ":".join((current_app.config["MAC_OUI"], *octets)) def pluralize(singular): @@ -206,13 +211,13 @@ def pluralize(singular): Used to pluralize API endpoints (not any given english word) """ - if not singular.endswith('s'): - return singular + 's' + if not singular.endswith("s"): + return singular + "s" else: - return singular + 'es' + return singular + "es" -def format_datetime(value, format='%Y-%m-%d %H:%M'): +def format_datetime(value, format="%Y-%m-%d %H:%M"): """Format a datetime to string Function used as a jinja2 filter diff --git a/app/validators.py b/app/validators.py index 1228f50..d5d78ac 100644 --- a/app/validators.py +++ b/app/validators.py @@ -14,11 +14,11 @@ import re import sqlalchemy as sa from wtforms import ValidationError, SelectField -ICS_ID_RE = re.compile('[A-Z]{3}[0-9]{3}') -HOST_NAME_RE = re.compile('^[a-z0-9\-]{2,20}$') -VLAN_NAME_RE = re.compile('^[A-Za-z0-9\-]{3,25}$') -MAC_ADDRESS_RE = re.compile('^(?:[0-9a-fA-F]{2}[:-]?){5}[0-9a-fA-F]{2}$') -DEVICE_TYPE_RE = re.compile('^[A-Za-z0-9]{3,25}$') +ICS_ID_RE = re.compile("[A-Z]{3}[0-9]{3}") +HOST_NAME_RE = re.compile("^[a-z0-9\-]{2,20}$") +VLAN_NAME_RE = re.compile("^[A-Za-z0-9\-]{3,25}$") +MAC_ADDRESS_RE = re.compile("^(?:[0-9a-fA-F]{2}[:-]?){5}[0-9a-fA-F]{2}$") +DEVICE_TYPE_RE = re.compile("^[A-Za-z0-9]{3,25}$") TAG_RE = DEVICE_TYPE_RE @@ -39,6 +39,7 @@ class IPNetwork: :param message: the error message to raise in case of a validation error """ + def __init__(self, message=None): self.message = message @@ -47,7 +48,7 @@ class IPNetwork: ipaddress.ip_network(field.data, strict=True) except (ipaddress.AddressValueError, ipaddress.NetmaskValueError, ValueError): if self.message is None: - self.message = field.gettext('Invalid IP network.') + self.message = field.gettext("Invalid IP network.") raise ValidationError(self.message) @@ -61,7 +62,7 @@ class Unique(object): :param message: the error message """ - def __init__(self, model, column='name', message=None): + def __init__(self, model, column="name", message=None): self.model = model self.column = column self.message = message @@ -73,9 +74,9 @@ class Unique(object): try: kwargs = {self.column: field.data} obj = self.model.query.filter_by(**kwargs).one() - if not hasattr(form, '_obj') or not form._obj == obj: + if not hasattr(form, "_obj") or not form._obj == obj: if self.message is None: - self.message = field.gettext('Already exists.') + self.message = field.gettext("Already exists.") raise ValidationError(self.message) except sa.orm.exc.NoResultFound: pass @@ -87,10 +88,11 @@ class RegexpList: :param regex: the regular expression to use :param message: the error message """ + def __init__(self, regex, message=None): self.regex = regex if message is None: - message = 'Invalid input.' + message = "Invalid input." 
self.message = message def __call__(self, form, field): diff --git a/docs/conf.py b/docs/conf.py index 07beee3..aca2f36 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,7 +19,8 @@ # import os import sys -sys.path.insert(0, os.path.abspath('..')) + +sys.path.insert(0, os.path.abspath("..")) # -- General configuration ------------------------------------------------ @@ -32,30 +33,30 @@ sys.path.insert(0, os.path.abspath('..')) # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'sphinx.ext.viewcode', - 'sphinxcontrib.httpdomain', - 'sphinxcontrib.autohttp.flask', - 'sphinxcontrib.autohttp.flaskqref', + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.viewcode", + "sphinxcontrib.httpdomain", + "sphinxcontrib.autohttp.flask", + "sphinxcontrib.autohttp.flaskqref", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = 'CSEntry' -copyright = '2018, Benjamin Bertrand' -author = 'Benjamin Bertrand' +project = "CSEntry" +copyright = "2018, Benjamin Bertrand" +author = "Benjamin Bertrand" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -64,10 +65,10 @@ author = 'Benjamin Bertrand' # The short X.Y version. try: # CI_COMMIT_TAG is defined by GitLab Runner when building tags - version = os.environ['CI_COMMIT_TAG'] + version = os.environ["CI_COMMIT_TAG"] except KeyError: # dev mode - version = os.popen('git describe').read().strip() + version = os.popen("git describe").read().strip() # The full version, including alpha/beta/rc tags. release = version @@ -81,10 +82,10 @@ language = None # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This patterns also effect to html_static_path and html_extra_path -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = False @@ -95,21 +96,18 @@ todo_include_todos = False # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'alabaster' +html_theme = "alabaster" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. # -html_theme_options = { - 'description': 'Control System Entry', - 'fixed_sidebar': True, -} +html_theme_options = {"description": "Control System Entry", "fixed_sidebar": True} # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". 
-html_static_path = ['_static'] +html_static_path = ["_static"] # Custom sidebar templates, must be a dictionary that maps document names # to template names. @@ -117,11 +115,11 @@ html_static_path = ['_static'] # This is required for the alabaster theme # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars html_sidebars = { - '**': [ - 'about.html', - 'navigation.html', - 'relations.html', # needs 'show_related': True theme option to display - 'searchbox.html', + "**": [ + "about.html", + "navigation.html", + "relations.html", # needs 'show_related': True theme option to display + "searchbox.html", ] } @@ -129,7 +127,7 @@ html_sidebars = { # -- Options for HTMLHelp output ------------------------------------------ # Output file base name for HTML help builder. -htmlhelp_basename = 'CSEntrydoc' +htmlhelp_basename = "CSEntrydoc" # -- Options for LaTeX output --------------------------------------------- @@ -138,15 +136,12 @@ latex_elements = { # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -156,8 +151,7 @@ latex_elements = { # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'CSEntry.tex', 'CSEntry Documentation', - 'Benjamin Bertrand', 'manual'), + (master_doc, "CSEntry.tex", "CSEntry Documentation", "Benjamin Bertrand", "manual") ] @@ -165,10 +159,7 @@ latex_documents = [ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'csentry', 'CSEntry Documentation', - [author], 1) -] +man_pages = [(master_doc, "csentry", "CSEntry Documentation", [author], 1)] # -- Options for Texinfo output ------------------------------------------- @@ -177,14 +168,20 @@ man_pages = [ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'CSEntry', 'CSEntry Documentation', - author, 'CSEntry', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "CSEntry", + "CSEntry Documentation", + author, + "CSEntry", + "One line description of project.", + "Miscellaneous", + ) ] # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { - 'python': ('https://docs.python.org/3/', None), - 'csentry-api': ('http://ics-infrastructure.pages.esss.lu.se/csentry-api/', None), + "python": ("https://docs.python.org/3/", None), + "csentry-api": ("http://ics-infrastructure.pages.esss.lu.se/csentry-api/", None), } diff --git a/migrations/env.py b/migrations/env.py index 23663ff..e956c4b 100755 --- a/migrations/env.py +++ b/migrations/env.py @@ -11,16 +11,18 @@ config = context.config # Interpret the config file for Python logging. # This line sets up loggers basically. 
fileConfig(config.config_file_name) -logger = logging.getLogger('alembic.env') +logger = logging.getLogger("alembic.env") # add your model's MetaData object here # for 'autogenerate' support # from myapp import mymodel # target_metadata = mymodel.Base.metadata from flask import current_app -config.set_main_option('sqlalchemy.url', - current_app.config.get('SQLALCHEMY_DATABASE_URI')) -target_metadata = current_app.extensions['migrate'].db.metadata + +config.set_main_option( + "sqlalchemy.url", current_app.config.get("SQLALCHEMY_DATABASE_URI") +) +target_metadata = current_app.extensions["migrate"].db.metadata # other values from the config, defined by the needs of env.py, # can be acquired: @@ -59,21 +61,25 @@ def run_migrations_online(): # when there are no changes to the schema # reference: http://alembic.zzzcomputing.com/en/latest/cookbook.html def process_revision_directives(context, revision, directives): - if getattr(config.cmd_opts, 'autogenerate', False): + if getattr(config.cmd_opts, "autogenerate", False): script = directives[0] if script.upgrade_ops.is_empty(): directives[:] = [] - logger.info('No changes in schema detected.') + logger.info("No changes in schema detected.") - engine = engine_from_config(config.get_section(config.config_ini_section), - prefix='sqlalchemy.', - poolclass=pool.NullPool) + engine = engine_from_config( + config.get_section(config.config_ini_section), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) connection = engine.connect() - context.configure(connection=connection, - target_metadata=target_metadata, - process_revision_directives=process_revision_directives, - **current_app.extensions['migrate'].configure_args) + context.configure( + connection=connection, + target_metadata=target_metadata, + process_revision_directives=process_revision_directives, + **current_app.extensions["migrate"].configure_args, + ) try: with context.begin_transaction(): @@ -81,6 +87,7 @@ def run_migrations_online(): finally: connection.close() + if context.is_offline_mode(): run_migrations_offline() else: diff --git a/migrations/versions/573560351033_add_stack_member_field_to_item_table.py b/migrations/versions/573560351033_add_stack_member_field_to_item_table.py index e3a1b33..c2dab79 100644 --- a/migrations/versions/573560351033_add_stack_member_field_to_item_table.py +++ b/migrations/versions/573560351033_add_stack_member_field_to_item_table.py @@ -10,25 +10,32 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. 
-revision = '573560351033' -down_revision = '7ffb5fbbd0f0' +revision = "573560351033" +down_revision = "7ffb5fbbd0f0" branch_labels = None depends_on = None def upgrade(): - op.add_column('item', sa.Column('stack_member', sa.SmallInteger(), nullable=True)) - op.create_unique_constraint(op.f('uq_item_host_id_stack_member'), 'item', ['host_id', 'stack_member']) + op.add_column("item", sa.Column("stack_member", sa.SmallInteger(), nullable=True)) + op.create_unique_constraint( + op.f("uq_item_host_id_stack_member"), "item", ["host_id", "stack_member"] + ) op.create_check_constraint( - op.f('ck_item_stack_member_range'), - 'item', - 'stack_member >= 0 AND stack_member <=9' + op.f("ck_item_stack_member_range"), + "item", + "stack_member >= 0 AND stack_member <=9", + ) + op.add_column( + "item_version", + sa.Column( + "stack_member", sa.SmallInteger(), autoincrement=False, nullable=True + ), ) - op.add_column('item_version', sa.Column('stack_member', sa.SmallInteger(), autoincrement=False, nullable=True)) def downgrade(): - op.drop_column('item_version', 'stack_member') - op.drop_constraint(op.f('uq_item_host_id_stack_member'), 'item', type_='unique') - op.drop_constraint(op.f('ck_item_stack_member_range'), 'item', type_='check') - op.drop_column('item', 'stack_member') + op.drop_column("item_version", "stack_member") + op.drop_constraint(op.f("uq_item_host_id_stack_member"), "item", type_="unique") + op.drop_constraint(op.f("ck_item_stack_member_range"), "item", type_="check") + op.drop_column("item", "stack_member") diff --git a/migrations/versions/713ca10255ab_.py b/migrations/versions/713ca10255ab_.py index a9c83ab..fca5834 100644 --- a/migrations/versions/713ca10255ab_.py +++ b/migrations/versions/713ca10255ab_.py @@ -11,7 +11,7 @@ from sqlalchemy.dialects import postgresql import citext # revision identifiers, used by Alembic. 
-revision = '713ca10255ab' +revision = "713ca10255ab" down_revision = None branch_labels = None depends_on = None @@ -19,254 +19,369 @@ depends_on = None def upgrade(): connection = op.get_bind() - connection.execute('CREATE EXTENSION IF NOT EXISTS citext') - op.create_table('action', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('name', citext.CIText(), nullable=False), - sa.Column('description', sa.Text(), nullable=True), - sa.PrimaryKeyConstraint('id', name=op.f('pk_action')), - sa.UniqueConstraint('name', name=op.f('uq_action_name')) - ) - op.create_table('item_version', - sa.Column('updated_at', sa.DateTime(), autoincrement=False, nullable=True), - sa.Column('id', sa.Integer(), autoincrement=False, nullable=False), - sa.Column('quantity', sa.Integer(), autoincrement=False, nullable=True), - sa.Column('location_id', sa.Integer(), autoincrement=False, nullable=True), - sa.Column('status_id', sa.Integer(), autoincrement=False, nullable=True), - sa.Column('parent_id', sa.Integer(), autoincrement=False, nullable=True), - sa.Column('transaction_id', sa.BigInteger(), autoincrement=False, nullable=False), - sa.Column('end_transaction_id', sa.BigInteger(), nullable=True), - sa.Column('operation_type', sa.SmallInteger(), nullable=False), - sa.PrimaryKeyConstraint('id', 'transaction_id', name=op.f('pk_item_version')) - ) - op.create_index(op.f('ix_item_version_end_transaction_id'), 'item_version', ['end_transaction_id'], unique=False) - op.create_index(op.f('ix_item_version_operation_type'), 'item_version', ['operation_type'], unique=False) - op.create_index(op.f('ix_item_version_transaction_id'), 'item_version', ['transaction_id'], unique=False) - op.create_table('location', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('name', citext.CIText(), nullable=False), - sa.Column('description', sa.Text(), nullable=True), - sa.PrimaryKeyConstraint('id', name=op.f('pk_location')), - sa.UniqueConstraint('name', name=op.f('uq_location_name')) - ) - op.create_table('manufacturer', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('name', citext.CIText(), nullable=False), - sa.Column('description', sa.Text(), nullable=True), - sa.PrimaryKeyConstraint('id', name=op.f('pk_manufacturer')), - sa.UniqueConstraint('name', name=op.f('uq_manufacturer_name')) - ) - op.create_table('model', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('name', citext.CIText(), nullable=False), - sa.Column('description', sa.Text(), nullable=True), - sa.PrimaryKeyConstraint('id', name=op.f('pk_model')), - sa.UniqueConstraint('name', name=op.f('uq_model_name')) - ) - op.create_table('status', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('name', citext.CIText(), nullable=False), - sa.Column('description', sa.Text(), nullable=True), - sa.PrimaryKeyConstraint('id', name=op.f('pk_status')), - sa.UniqueConstraint('name', name=op.f('uq_status_name')) - ) - op.create_table('tag', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('name', citext.CIText(), nullable=False), - sa.Column('description', sa.Text(), nullable=True), - sa.Column('admin_only', sa.Boolean(), nullable=False), - sa.PrimaryKeyConstraint('id', name=op.f('pk_tag')), - sa.UniqueConstraint('name', name=op.f('uq_tag_name')) - ) - op.create_table('user_account', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('username', sa.Text(), nullable=False), - sa.Column('display_name', sa.Text(), nullable=False), - sa.Column('email', sa.Text(), nullable=True), - sa.Column('groups', 
postgresql.ARRAY(sa.Text()), nullable=True), - sa.PrimaryKeyConstraint('id', name=op.f('pk_user_account')), - sa.UniqueConstraint('username', name=op.f('uq_user_account_username')) - ) - op.create_table('item', - sa.Column('created_at', sa.DateTime(), nullable=True), - sa.Column('updated_at', sa.DateTime(), nullable=True), - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('ics_id', sa.Text(), nullable=False), - sa.Column('serial_number', sa.Text(), nullable=False), - sa.Column('quantity', sa.Integer(), nullable=False), - sa.Column('manufacturer_id', sa.Integer(), nullable=True), - sa.Column('model_id', sa.Integer(), nullable=True), - sa.Column('location_id', sa.Integer(), nullable=True), - sa.Column('status_id', sa.Integer(), nullable=True), - sa.Column('parent_id', sa.Integer(), nullable=True), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['location_id'], ['location.id'], name=op.f('fk_item_location_id_location')), - sa.ForeignKeyConstraint(['manufacturer_id'], ['manufacturer.id'], name=op.f('fk_item_manufacturer_id_manufacturer')), - sa.ForeignKeyConstraint(['model_id'], ['model.id'], name=op.f('fk_item_model_id_model')), - sa.ForeignKeyConstraint(['parent_id'], ['item.id'], name=op.f('fk_item_parent_id_item')), - sa.ForeignKeyConstraint(['status_id'], ['status.id'], name=op.f('fk_item_status_id_status')), - sa.ForeignKeyConstraint(['user_id'], ['user_account.id'], name=op.f('fk_item_user_id_user_account')), - sa.PrimaryKeyConstraint('id', name=op.f('pk_item')) - ) - op.create_index(op.f('ix_item_ics_id'), 'item', ['ics_id'], unique=True) - op.create_table('network_scope', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('created_at', sa.DateTime(), nullable=True), - sa.Column('updated_at', sa.DateTime(), nullable=True), - sa.Column('name', citext.CIText(), nullable=False), - sa.Column('first_vlan', sa.Integer(), nullable=False), - sa.Column('last_vlan', sa.Integer(), nullable=False), - sa.Column('supernet', postgresql.CIDR(), nullable=False), - sa.Column('description', sa.Text(), nullable=True), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.CheckConstraint('first_vlan < last_vlan', name=op.f('ck_network_scope_first_vlan_less_than_last_vlan')), - sa.ForeignKeyConstraint(['user_id'], ['user_account.id'], name=op.f('fk_network_scope_user_id_user_account')), - sa.PrimaryKeyConstraint('id', name=op.f('pk_network_scope')), - sa.UniqueConstraint('first_vlan', name=op.f('uq_network_scope_first_vlan')), - sa.UniqueConstraint('last_vlan', name=op.f('uq_network_scope_last_vlan')), - sa.UniqueConstraint('name', name=op.f('uq_network_scope_name')), - sa.UniqueConstraint('supernet', name=op.f('uq_network_scope_supernet')) - ) - op.create_table('token', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('jti', postgresql.UUID(), nullable=False), - sa.Column('token_type', sa.Text(), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('issued_at', sa.DateTime(), nullable=False), - sa.Column('expires', sa.DateTime(), nullable=True), - sa.Column('description', sa.Text(), nullable=True), - sa.ForeignKeyConstraint(['user_id'], ['user_account.id'], name=op.f('fk_token_user_id_user_account')), - sa.PrimaryKeyConstraint('id', name=op.f('pk_token')), - sa.UniqueConstraint('jti', 'user_id', name=op.f('uq_token_jti')) - ) - op.create_table('transaction', - sa.Column('issued_at', sa.DateTime(), nullable=True), - sa.Column('id', sa.BigInteger(), nullable=False), - sa.Column('remote_addr', 
sa.String(length=50), nullable=True), - sa.Column('user_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['user_id'], ['user_account.id'], name=op.f('fk_transaction_user_id_user_account')), - sa.PrimaryKeyConstraint('id', name=op.f('pk_transaction')) - ) - op.create_index(op.f('ix_transaction_user_id'), 'transaction', ['user_id'], unique=False) - op.create_table('host', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('created_at', sa.DateTime(), nullable=True), - sa.Column('updated_at', sa.DateTime(), nullable=True), - sa.Column('name', sa.Text(), nullable=False), - sa.Column('type', sa.Text(), nullable=True), - sa.Column('description', sa.Text(), nullable=True), - sa.Column('item_id', sa.Integer(), nullable=True), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['item_id'], ['item.id'], name=op.f('fk_host_item_id_item')), - sa.ForeignKeyConstraint(['user_id'], ['user_account.id'], name=op.f('fk_host_user_id_user_account')), - sa.PrimaryKeyConstraint('id', name=op.f('pk_host')), - sa.UniqueConstraint('name', name=op.f('uq_host_name')) - ) - op.create_table('item_comment', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('created_at', sa.DateTime(), nullable=True), - sa.Column('updated_at', sa.DateTime(), nullable=True), - sa.Column('body', sa.Text(), nullable=False), - sa.Column('item_id', sa.Integer(), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['item_id'], ['item.id'], name=op.f('fk_item_comment_item_id_item')), - sa.ForeignKeyConstraint(['user_id'], ['user_account.id'], name=op.f('fk_item_comment_user_id_user_account')), - sa.PrimaryKeyConstraint('id', name=op.f('pk_item_comment')) - ) - op.create_table('mac', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('address', postgresql.MACADDR(), nullable=False), - sa.Column('item_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['item_id'], ['item.id'], name=op.f('fk_mac_item_id_item')), - sa.PrimaryKeyConstraint('id', name=op.f('pk_mac')), - sa.UniqueConstraint('address', name=op.f('uq_mac_address')) - ) - op.create_table('network', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('created_at', sa.DateTime(), nullable=True), - sa.Column('updated_at', sa.DateTime(), nullable=True), - sa.Column('vlan_name', citext.CIText(), nullable=False), - sa.Column('vlan_id', sa.Integer(), nullable=False), - sa.Column('address', postgresql.CIDR(), nullable=False), - sa.Column('first_ip', postgresql.INET(), nullable=False), - sa.Column('last_ip', postgresql.INET(), nullable=False), - sa.Column('description', sa.Text(), nullable=True), - sa.Column('admin_only', sa.Boolean(), nullable=False), - sa.Column('scope_id', sa.Integer(), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.CheckConstraint('first_ip < last_ip', name=op.f('ck_network_first_ip_less_than_last_ip')), - sa.CheckConstraint('first_ip << address', name=op.f('ck_network_first_ip_in_network')), - sa.CheckConstraint('last_ip << address', name=op.f('ck_network_last_ip_in_network')), - sa.ForeignKeyConstraint(['scope_id'], ['network_scope.id'], name=op.f('fk_network_scope_id_network_scope')), - sa.ForeignKeyConstraint(['user_id'], ['user_account.id'], name=op.f('fk_network_user_id_user_account')), - sa.PrimaryKeyConstraint('id', name=op.f('pk_network')), - sa.UniqueConstraint('address', name=op.f('uq_network_address')), - sa.UniqueConstraint('first_ip', name=op.f('uq_network_first_ip')), - 
sa.UniqueConstraint('last_ip', name=op.f('uq_network_last_ip')), - sa.UniqueConstraint('vlan_id', name=op.f('uq_network_vlan_id')), - sa.UniqueConstraint('vlan_name', name=op.f('uq_network_vlan_name')) - ) - op.create_table('interface', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('created_at', sa.DateTime(), nullable=True), - sa.Column('updated_at', sa.DateTime(), nullable=True), - sa.Column('network_id', sa.Integer(), nullable=False), - sa.Column('ip', postgresql.INET(), nullable=False), - sa.Column('name', sa.Text(), nullable=False), - sa.Column('mac_id', sa.Integer(), nullable=True), - sa.Column('host_id', sa.Integer(), nullable=True), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['host_id'], ['host.id'], name=op.f('fk_interface_host_id_host')), - sa.ForeignKeyConstraint(['mac_id'], ['mac.id'], name=op.f('fk_interface_mac_id_mac')), - sa.ForeignKeyConstraint(['network_id'], ['network.id'], name=op.f('fk_interface_network_id_network')), - sa.ForeignKeyConstraint(['user_id'], ['user_account.id'], name=op.f('fk_interface_user_id_user_account')), - sa.PrimaryKeyConstraint('id', name=op.f('pk_interface')), - sa.UniqueConstraint('ip', name=op.f('uq_interface_ip')), - sa.UniqueConstraint('name', name=op.f('uq_interface_name')) - ) - op.create_table('cname', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('created_at', sa.DateTime(), nullable=True), - sa.Column('updated_at', sa.DateTime(), nullable=True), - sa.Column('name', sa.Text(), nullable=False), - sa.Column('interface_id', sa.Integer(), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['interface_id'], ['interface.id'], name=op.f('fk_cname_interface_id_interface')), - sa.ForeignKeyConstraint(['user_id'], ['user_account.id'], name=op.f('fk_cname_user_id_user_account')), - sa.PrimaryKeyConstraint('id', name=op.f('pk_cname')), - sa.UniqueConstraint('name', name=op.f('uq_cname_name')) - ) - op.create_table('interfacetags', - sa.Column('tag_id', sa.Integer(), nullable=False), - sa.Column('interface_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['interface_id'], ['interface.id'], name=op.f('fk_interfacetags_interface_id_interface')), - sa.ForeignKeyConstraint(['tag_id'], ['tag.id'], name=op.f('fk_interfacetags_tag_id_tag')), - sa.PrimaryKeyConstraint('tag_id', 'interface_id', name=op.f('pk_interfacetags')) - ) + connection.execute("CREATE EXTENSION IF NOT EXISTS citext") + op.create_table( + "action", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", citext.CIText(), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.PrimaryKeyConstraint("id", name=op.f("pk_action")), + sa.UniqueConstraint("name", name=op.f("uq_action_name")), + ) + op.create_table( + "item_version", + sa.Column("updated_at", sa.DateTime(), autoincrement=False, nullable=True), + sa.Column("id", sa.Integer(), autoincrement=False, nullable=False), + sa.Column("quantity", sa.Integer(), autoincrement=False, nullable=True), + sa.Column("location_id", sa.Integer(), autoincrement=False, nullable=True), + sa.Column("status_id", sa.Integer(), autoincrement=False, nullable=True), + sa.Column("parent_id", sa.Integer(), autoincrement=False, nullable=True), + sa.Column( + "transaction_id", sa.BigInteger(), autoincrement=False, nullable=False + ), + sa.Column("end_transaction_id", sa.BigInteger(), nullable=True), + sa.Column("operation_type", sa.SmallInteger(), nullable=False), + sa.PrimaryKeyConstraint("id", 
"transaction_id", name=op.f("pk_item_version")), + ) + op.create_index( + op.f("ix_item_version_end_transaction_id"), + "item_version", + ["end_transaction_id"], + unique=False, + ) + op.create_index( + op.f("ix_item_version_operation_type"), + "item_version", + ["operation_type"], + unique=False, + ) + op.create_index( + op.f("ix_item_version_transaction_id"), + "item_version", + ["transaction_id"], + unique=False, + ) + op.create_table( + "location", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", citext.CIText(), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.PrimaryKeyConstraint("id", name=op.f("pk_location")), + sa.UniqueConstraint("name", name=op.f("uq_location_name")), + ) + op.create_table( + "manufacturer", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", citext.CIText(), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.PrimaryKeyConstraint("id", name=op.f("pk_manufacturer")), + sa.UniqueConstraint("name", name=op.f("uq_manufacturer_name")), + ) + op.create_table( + "model", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", citext.CIText(), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.PrimaryKeyConstraint("id", name=op.f("pk_model")), + sa.UniqueConstraint("name", name=op.f("uq_model_name")), + ) + op.create_table( + "status", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", citext.CIText(), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.PrimaryKeyConstraint("id", name=op.f("pk_status")), + sa.UniqueConstraint("name", name=op.f("uq_status_name")), + ) + op.create_table( + "tag", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", citext.CIText(), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("admin_only", sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint("id", name=op.f("pk_tag")), + sa.UniqueConstraint("name", name=op.f("uq_tag_name")), + ) + op.create_table( + "user_account", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("username", sa.Text(), nullable=False), + sa.Column("display_name", sa.Text(), nullable=False), + sa.Column("email", sa.Text(), nullable=True), + sa.Column("groups", postgresql.ARRAY(sa.Text()), nullable=True), + sa.PrimaryKeyConstraint("id", name=op.f("pk_user_account")), + sa.UniqueConstraint("username", name=op.f("uq_user_account_username")), + ) + op.create_table( + "item", + sa.Column("created_at", sa.DateTime(), nullable=True), + sa.Column("updated_at", sa.DateTime(), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("ics_id", sa.Text(), nullable=False), + sa.Column("serial_number", sa.Text(), nullable=False), + sa.Column("quantity", sa.Integer(), nullable=False), + sa.Column("manufacturer_id", sa.Integer(), nullable=True), + sa.Column("model_id", sa.Integer(), nullable=True), + sa.Column("location_id", sa.Integer(), nullable=True), + sa.Column("status_id", sa.Integer(), nullable=True), + sa.Column("parent_id", sa.Integer(), nullable=True), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["location_id"], ["location.id"], name=op.f("fk_item_location_id_location") + ), + sa.ForeignKeyConstraint( + ["manufacturer_id"], + ["manufacturer.id"], + name=op.f("fk_item_manufacturer_id_manufacturer"), + ), + sa.ForeignKeyConstraint( + ["model_id"], ["model.id"], name=op.f("fk_item_model_id_model") + ), + sa.ForeignKeyConstraint( + 
["parent_id"], ["item.id"], name=op.f("fk_item_parent_id_item") + ), + sa.ForeignKeyConstraint( + ["status_id"], ["status.id"], name=op.f("fk_item_status_id_status") + ), + sa.ForeignKeyConstraint( + ["user_id"], ["user_account.id"], name=op.f("fk_item_user_id_user_account") + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_item")), + ) + op.create_index(op.f("ix_item_ics_id"), "item", ["ics_id"], unique=True) + op.create_table( + "network_scope", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=True), + sa.Column("updated_at", sa.DateTime(), nullable=True), + sa.Column("name", citext.CIText(), nullable=False), + sa.Column("first_vlan", sa.Integer(), nullable=False), + sa.Column("last_vlan", sa.Integer(), nullable=False), + sa.Column("supernet", postgresql.CIDR(), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.CheckConstraint( + "first_vlan < last_vlan", + name=op.f("ck_network_scope_first_vlan_less_than_last_vlan"), + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user_account.id"], + name=op.f("fk_network_scope_user_id_user_account"), + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_network_scope")), + sa.UniqueConstraint("first_vlan", name=op.f("uq_network_scope_first_vlan")), + sa.UniqueConstraint("last_vlan", name=op.f("uq_network_scope_last_vlan")), + sa.UniqueConstraint("name", name=op.f("uq_network_scope_name")), + sa.UniqueConstraint("supernet", name=op.f("uq_network_scope_supernet")), + ) + op.create_table( + "token", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("jti", postgresql.UUID(), nullable=False), + sa.Column("token_type", sa.Text(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("issued_at", sa.DateTime(), nullable=False), + sa.Column("expires", sa.DateTime(), nullable=True), + sa.Column("description", sa.Text(), nullable=True), + sa.ForeignKeyConstraint( + ["user_id"], ["user_account.id"], name=op.f("fk_token_user_id_user_account") + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_token")), + sa.UniqueConstraint("jti", "user_id", name=op.f("uq_token_jti")), + ) + op.create_table( + "transaction", + sa.Column("issued_at", sa.DateTime(), nullable=True), + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("remote_addr", sa.String(length=50), nullable=True), + sa.Column("user_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["user_id"], + ["user_account.id"], + name=op.f("fk_transaction_user_id_user_account"), + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_transaction")), + ) + op.create_index( + op.f("ix_transaction_user_id"), "transaction", ["user_id"], unique=False + ) + op.create_table( + "host", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=True), + sa.Column("updated_at", sa.DateTime(), nullable=True), + sa.Column("name", sa.Text(), nullable=False), + sa.Column("type", sa.Text(), nullable=True), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("item_id", sa.Integer(), nullable=True), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["item_id"], ["item.id"], name=op.f("fk_host_item_id_item") + ), + sa.ForeignKeyConstraint( + ["user_id"], ["user_account.id"], name=op.f("fk_host_user_id_user_account") + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_host")), + sa.UniqueConstraint("name", name=op.f("uq_host_name")), + ) + op.create_table( + 
"item_comment", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=True), + sa.Column("updated_at", sa.DateTime(), nullable=True), + sa.Column("body", sa.Text(), nullable=False), + sa.Column("item_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["item_id"], ["item.id"], name=op.f("fk_item_comment_item_id_item") + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user_account.id"], + name=op.f("fk_item_comment_user_id_user_account"), + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_item_comment")), + ) + op.create_table( + "mac", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("address", postgresql.MACADDR(), nullable=False), + sa.Column("item_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["item_id"], ["item.id"], name=op.f("fk_mac_item_id_item") + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_mac")), + sa.UniqueConstraint("address", name=op.f("uq_mac_address")), + ) + op.create_table( + "network", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=True), + sa.Column("updated_at", sa.DateTime(), nullable=True), + sa.Column("vlan_name", citext.CIText(), nullable=False), + sa.Column("vlan_id", sa.Integer(), nullable=False), + sa.Column("address", postgresql.CIDR(), nullable=False), + sa.Column("first_ip", postgresql.INET(), nullable=False), + sa.Column("last_ip", postgresql.INET(), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("admin_only", sa.Boolean(), nullable=False), + sa.Column("scope_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.CheckConstraint( + "first_ip < last_ip", name=op.f("ck_network_first_ip_less_than_last_ip") + ), + sa.CheckConstraint( + "first_ip << address", name=op.f("ck_network_first_ip_in_network") + ), + sa.CheckConstraint( + "last_ip << address", name=op.f("ck_network_last_ip_in_network") + ), + sa.ForeignKeyConstraint( + ["scope_id"], + ["network_scope.id"], + name=op.f("fk_network_scope_id_network_scope"), + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user_account.id"], + name=op.f("fk_network_user_id_user_account"), + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_network")), + sa.UniqueConstraint("address", name=op.f("uq_network_address")), + sa.UniqueConstraint("first_ip", name=op.f("uq_network_first_ip")), + sa.UniqueConstraint("last_ip", name=op.f("uq_network_last_ip")), + sa.UniqueConstraint("vlan_id", name=op.f("uq_network_vlan_id")), + sa.UniqueConstraint("vlan_name", name=op.f("uq_network_vlan_name")), + ) + op.create_table( + "interface", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=True), + sa.Column("updated_at", sa.DateTime(), nullable=True), + sa.Column("network_id", sa.Integer(), nullable=False), + sa.Column("ip", postgresql.INET(), nullable=False), + sa.Column("name", sa.Text(), nullable=False), + sa.Column("mac_id", sa.Integer(), nullable=True), + sa.Column("host_id", sa.Integer(), nullable=True), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["host_id"], ["host.id"], name=op.f("fk_interface_host_id_host") + ), + sa.ForeignKeyConstraint( + ["mac_id"], ["mac.id"], name=op.f("fk_interface_mac_id_mac") + ), + sa.ForeignKeyConstraint( + ["network_id"], ["network.id"], name=op.f("fk_interface_network_id_network") + ), + sa.ForeignKeyConstraint( + ["user_id"], + 
["user_account.id"], + name=op.f("fk_interface_user_id_user_account"), + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_interface")), + sa.UniqueConstraint("ip", name=op.f("uq_interface_ip")), + sa.UniqueConstraint("name", name=op.f("uq_interface_name")), + ) + op.create_table( + "cname", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=True), + sa.Column("updated_at", sa.DateTime(), nullable=True), + sa.Column("name", sa.Text(), nullable=False), + sa.Column("interface_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["interface_id"], + ["interface.id"], + name=op.f("fk_cname_interface_id_interface"), + ), + sa.ForeignKeyConstraint( + ["user_id"], ["user_account.id"], name=op.f("fk_cname_user_id_user_account") + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_cname")), + sa.UniqueConstraint("name", name=op.f("uq_cname_name")), + ) + op.create_table( + "interfacetags", + sa.Column("tag_id", sa.Integer(), nullable=False), + sa.Column("interface_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["interface_id"], + ["interface.id"], + name=op.f("fk_interfacetags_interface_id_interface"), + ), + sa.ForeignKeyConstraint( + ["tag_id"], ["tag.id"], name=op.f("fk_interfacetags_tag_id_tag") + ), + sa.PrimaryKeyConstraint( + "tag_id", "interface_id", name=op.f("pk_interfacetags") + ), + ) def downgrade(): - op.drop_table('interfacetags') - op.drop_table('cname') - op.drop_table('interface') - op.drop_table('network') - op.drop_table('mac') - op.drop_table('item_comment') - op.drop_table('host') - op.drop_index(op.f('ix_transaction_user_id'), table_name='transaction') - op.drop_table('transaction') - op.drop_table('token') - op.drop_table('network_scope') - op.drop_index(op.f('ix_item_ics_id'), table_name='item') - op.drop_table('item') - op.drop_table('user_account') - op.drop_table('tag') - op.drop_table('status') - op.drop_table('model') - op.drop_table('manufacturer') - op.drop_table('location') - op.drop_index(op.f('ix_item_version_transaction_id'), table_name='item_version') - op.drop_index(op.f('ix_item_version_operation_type'), table_name='item_version') - op.drop_index(op.f('ix_item_version_end_transaction_id'), table_name='item_version') - op.drop_table('item_version') - op.drop_table('action') + op.drop_table("interfacetags") + op.drop_table("cname") + op.drop_table("interface") + op.drop_table("network") + op.drop_table("mac") + op.drop_table("item_comment") + op.drop_table("host") + op.drop_index(op.f("ix_transaction_user_id"), table_name="transaction") + op.drop_table("transaction") + op.drop_table("token") + op.drop_table("network_scope") + op.drop_index(op.f("ix_item_ics_id"), table_name="item") + op.drop_table("item") + op.drop_table("user_account") + op.drop_table("tag") + op.drop_table("status") + op.drop_table("model") + op.drop_table("manufacturer") + op.drop_table("location") + op.drop_index(op.f("ix_item_version_transaction_id"), table_name="item_version") + op.drop_index(op.f("ix_item_version_operation_type"), table_name="item_version") + op.drop_index(op.f("ix_item_version_end_transaction_id"), table_name="item_version") + op.drop_table("item_version") + op.drop_table("action") diff --git a/migrations/versions/7d0d580cdb1a_add_task_table.py b/migrations/versions/7d0d580cdb1a_add_task_table.py index 33cdd67..2f6278e 100644 --- a/migrations/versions/7d0d580cdb1a_add_task_table.py +++ b/migrations/versions/7d0d580cdb1a_add_task_table.py @@ 
-10,28 +10,36 @@ import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. -revision = '7d0d580cdb1a' -down_revision = 'f5a605c0c835' +revision = "7d0d580cdb1a" +down_revision = "f5a605c0c835" branch_labels = None depends_on = None def upgrade(): op.create_table( - 'task', - sa.Column('id', postgresql.UUID(), nullable=False), - sa.Column('created_at', sa.DateTime(), nullable=True), - sa.Column('name', sa.Text(), nullable=False), - sa.Column('command', sa.Text(), nullable=True), - sa.Column('status', sa.Enum('QUEUED', 'FINISHED', 'FAILED', 'STARTED', 'DEFERRED', name='job_status'), nullable=True), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['user_id'], ['user_account.id'], name=op.f('fk_task_user_id_user_account')), - sa.PrimaryKeyConstraint('id', name=op.f('pk_task')) + "task", + sa.Column("id", postgresql.UUID(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=True), + sa.Column("name", sa.Text(), nullable=False), + sa.Column("command", sa.Text(), nullable=True), + sa.Column( + "status", + sa.Enum( + "QUEUED", "FINISHED", "FAILED", "STARTED", "DEFERRED", name="job_status" + ), + nullable=True, + ), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], ["user_account.id"], name=op.f("fk_task_user_id_user_account") + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_task")), ) - op.create_index(op.f('ix_task_name'), 'task', ['name'], unique=False) + op.create_index(op.f("ix_task_name"), "task", ["name"], unique=False) def downgrade(): - op.drop_index(op.f('ix_task_name'), table_name='task') - op.drop_table('task') - op.execute('DROP TYPE job_status') + op.drop_index(op.f("ix_task_name"), table_name="task") + op.drop_table("task") + op.execute("DROP TYPE job_status") diff --git a/migrations/versions/7ffb5fbbd0f0_allow_to_associate_several_items_to_one_.py b/migrations/versions/7ffb5fbbd0f0_allow_to_associate_several_items_to_one_.py index a19dea2..da9577d 100644 --- a/migrations/versions/7ffb5fbbd0f0_allow_to_associate_several_items_to_one_.py +++ b/migrations/versions/7ffb5fbbd0f0_allow_to_associate_several_items_to_one_.py @@ -10,39 +10,50 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. 
-revision = '7ffb5fbbd0f0' -down_revision = 'e07c7bc870be' +revision = "7ffb5fbbd0f0" +down_revision = "e07c7bc870be" branch_labels = None depends_on = None def upgrade(): - op.add_column('item', sa.Column('host_id', sa.Integer(), nullable=True)) - op.create_foreign_key(op.f('fk_item_host_id_host'), 'item', 'host', ['host_id'], ['id']) - op.add_column('item_version', sa.Column('host_id', sa.Integer(), autoincrement=False, nullable=True)) + op.add_column("item", sa.Column("host_id", sa.Integer(), nullable=True)) + op.create_foreign_key( + op.f("fk_item_host_id_host"), "item", "host", ["host_id"], ["id"] + ) + op.add_column( + "item_version", + sa.Column("host_id", sa.Integer(), autoincrement=False, nullable=True), + ) # Fill the item host_id based on the old host item_id value conn = op.get_bind() - res = conn.execute('SELECT id, item_id FROM host WHERE item_id IS NOT NULL') + res = conn.execute("SELECT id, item_id FROM host WHERE item_id IS NOT NULL") results = res.fetchall() - item = sa.sql.table('item', sa.sql.column('id'), sa.sql.column('host_id')) + item = sa.sql.table("item", sa.sql.column("id"), sa.sql.column("host_id")) for result in results: - op.execute(item.update().where(item.c.id == result[1]).values(host_id=result[0])) + op.execute( + item.update().where(item.c.id == result[1]).values(host_id=result[0]) + ) # We can drop the item_id column now - op.drop_constraint('fk_host_item_id_item', 'host', type_='foreignkey') - op.drop_column('host', 'item_id') + op.drop_constraint("fk_host_item_id_item", "host", type_="foreignkey") + op.drop_column("host", "item_id") def downgrade(): - op.add_column('host', sa.Column('item_id', sa.INTEGER(), autoincrement=False, nullable=True)) - op.create_foreign_key('fk_host_item_id_item', 'host', 'item', ['item_id'], ['id']) + op.add_column( + "host", sa.Column("item_id", sa.INTEGER(), autoincrement=False, nullable=True) + ) + op.create_foreign_key("fk_host_item_id_item", "host", "item", ["item_id"], ["id"]) # Fill the host item_id based on the item host_id value conn = op.get_bind() - res = conn.execute('SELECT id, host_id FROM item WHERE host_id IS NOT NULL') + res = conn.execute("SELECT id, host_id FROM item WHERE host_id IS NOT NULL") results = res.fetchall() - host = sa.sql.table('host', sa.sql.column('id'), sa.sql.column('item_id')) + host = sa.sql.table("host", sa.sql.column("id"), sa.sql.column("item_id")) for result in results: - op.execute(host.update().where(host.c.id == result[1]).values(item_id=result[0])) + op.execute( + host.update().where(host.c.id == result[1]).values(item_id=result[0]) + ) # Drop the unused columns - op.drop_column('item_version', 'host_id') - op.drop_constraint(op.f('fk_item_host_id_host'), 'item', type_='foreignkey') - op.drop_column('item', 'host_id') + op.drop_column("item_version", "host_id") + op.drop_constraint(op.f("fk_item_host_id_host"), "item", type_="foreignkey") + op.drop_column("item", "host_id") diff --git a/migrations/versions/8f135d5efde2_rename_virtual_device_type.py b/migrations/versions/8f135d5efde2_rename_virtual_device_type.py index 8d18bd5..ef1bcde 100644 --- a/migrations/versions/8f135d5efde2_rename_virtual_device_type.py +++ b/migrations/versions/8f135d5efde2_rename_virtual_device_type.py @@ -10,17 +10,29 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. 
-revision = '8f135d5efde2' -down_revision = '573560351033' +revision = "8f135d5efde2" +down_revision = "573560351033" branch_labels = None depends_on = None def upgrade(): - device_type = sa.sql.table('device_type', sa.sql.column('id'), sa.sql.column('name')) - op.execute(device_type.update().where(device_type.c.name == 'Virtual').values(name='Virtual Machine')) + device_type = sa.sql.table( + "device_type", sa.sql.column("id"), sa.sql.column("name") + ) + op.execute( + device_type.update() + .where(device_type.c.name == "Virtual") + .values(name="Virtual Machine") + ) def downgrade(): - device_type = sa.sql.table('device_type', sa.sql.column('id'), sa.sql.column('name')) - op.execute(device_type.update().where(device_type.c.name == 'Virtual Machine').values(name='Virtual')) + device_type = sa.sql.table( + "device_type", sa.sql.column("id"), sa.sql.column("name") + ) + op.execute( + device_type.update() + .where(device_type.c.name == "Virtual Machine") + .values(name="Virtual") + ) diff --git a/migrations/versions/a73eeb144fa1_add_user_favorite_attributes_tables.py b/migrations/versions/a73eeb144fa1_add_user_favorite_attributes_tables.py index 1594017..7c888d0 100644 --- a/migrations/versions/a73eeb144fa1_add_user_favorite_attributes_tables.py +++ b/migrations/versions/a73eeb144fa1_add_user_favorite_attributes_tables.py @@ -10,58 +10,104 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. -revision = 'a73eeb144fa1' -down_revision = 'ac6b3c416b07' +revision = "a73eeb144fa1" +down_revision = "ac6b3c416b07" branch_labels = None depends_on = None def upgrade(): op.create_table( - 'favorite_actions', - sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('action_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['action_id'], ['action.id'], name=op.f('fk_favorite_actions_action_id_action')), - sa.ForeignKeyConstraint(['user_id'], ['user_account.id'], name=op.f('fk_favorite_actions_user_id_user_account')), - sa.PrimaryKeyConstraint('user_id', 'action_id', name=op.f('pk_favorite_actions')) + "favorite_actions", + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("action_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["action_id"], + ["action.id"], + name=op.f("fk_favorite_actions_action_id_action"), + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user_account.id"], + name=op.f("fk_favorite_actions_user_id_user_account"), + ), + sa.PrimaryKeyConstraint( + "user_id", "action_id", name=op.f("pk_favorite_actions") + ), ) op.create_table( - 'favorite_locations', - sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('location_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['location_id'], ['location.id'], name=op.f('fk_favorite_locations_location_id_location')), - sa.ForeignKeyConstraint(['user_id'], ['user_account.id'], name=op.f('fk_favorite_locations_user_id_user_account')), - sa.PrimaryKeyConstraint('user_id', 'location_id', name=op.f('pk_favorite_locations')) + "favorite_locations", + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("location_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["location_id"], + ["location.id"], + name=op.f("fk_favorite_locations_location_id_location"), + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user_account.id"], + name=op.f("fk_favorite_locations_user_id_user_account"), + ), + sa.PrimaryKeyConstraint( + "user_id", "location_id", name=op.f("pk_favorite_locations") + ), ) op.create_table( - 'favorite_manufacturers', - 
sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('manufacturer_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['manufacturer_id'], ['manufacturer.id'], name=op.f('fk_favorite_manufacturers_manufacturer_id_manufacturer')), - sa.ForeignKeyConstraint(['user_id'], ['user_account.id'], name=op.f('fk_favorite_manufacturers_user_id_user_account')), - sa.PrimaryKeyConstraint('user_id', 'manufacturer_id', name=op.f('pk_favorite_manufacturers')) + "favorite_manufacturers", + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("manufacturer_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["manufacturer_id"], + ["manufacturer.id"], + name=op.f("fk_favorite_manufacturers_manufacturer_id_manufacturer"), + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user_account.id"], + name=op.f("fk_favorite_manufacturers_user_id_user_account"), + ), + sa.PrimaryKeyConstraint( + "user_id", "manufacturer_id", name=op.f("pk_favorite_manufacturers") + ), ) op.create_table( - 'favorite_models', - sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('model_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['model_id'], ['model.id'], name=op.f('fk_favorite_models_model_id_model')), - sa.ForeignKeyConstraint(['user_id'], ['user_account.id'], name=op.f('fk_favorite_models_user_id_user_account')), - sa.PrimaryKeyConstraint('user_id', 'model_id', name=op.f('pk_favorite_models')) + "favorite_models", + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("model_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["model_id"], ["model.id"], name=op.f("fk_favorite_models_model_id_model") + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user_account.id"], + name=op.f("fk_favorite_models_user_id_user_account"), + ), + sa.PrimaryKeyConstraint("user_id", "model_id", name=op.f("pk_favorite_models")), ) op.create_table( - 'favorite_statuses', - sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('status_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['status_id'], ['status.id'], name=op.f('fk_favorite_statuses_status_id_status')), - sa.ForeignKeyConstraint(['user_id'], ['user_account.id'], name=op.f('fk_favorite_statuses_user_id_user_account')), - sa.PrimaryKeyConstraint('user_id', 'status_id', name=op.f('pk_favorite_statuses')) + "favorite_statuses", + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("status_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["status_id"], + ["status.id"], + name=op.f("fk_favorite_statuses_status_id_status"), + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user_account.id"], + name=op.f("fk_favorite_statuses_user_id_user_account"), + ), + sa.PrimaryKeyConstraint( + "user_id", "status_id", name=op.f("pk_favorite_statuses") + ), ) def downgrade(): - op.drop_table('favorite_statuses') - op.drop_table('favorite_models') - op.drop_table('favorite_manufacturers') - op.drop_table('favorite_locations') - op.drop_table('favorite_actions') + op.drop_table("favorite_statuses") + op.drop_table("favorite_models") + op.drop_table("favorite_manufacturers") + op.drop_table("favorite_locations") + op.drop_table("favorite_actions") diff --git a/migrations/versions/a9442567c6dc_add_exception_to_task_table.py b/migrations/versions/a9442567c6dc_add_exception_to_task_table.py index c3d8379..6ccca67 100644 --- a/migrations/versions/a9442567c6dc_add_exception_to_task_table.py +++ b/migrations/versions/a9442567c6dc_add_exception_to_task_table.py @@ -10,15 
+10,15 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. -revision = 'a9442567c6dc' -down_revision = 'c0b8036078e7' +revision = "a9442567c6dc" +down_revision = "c0b8036078e7" branch_labels = None depends_on = None def upgrade(): - op.add_column('task', sa.Column('exception', sa.Text(), nullable=True)) + op.add_column("task", sa.Column("exception", sa.Text(), nullable=True)) def downgrade(): - op.drop_column('task', 'exception') + op.drop_column("task", "exception") diff --git a/migrations/versions/ac6b3c416b07_add_machine_type_table.py b/migrations/versions/ac6b3c416b07_add_machine_type_table.py index 6adb797..926fcae 100644 --- a/migrations/versions/ac6b3c416b07_add_machine_type_table.py +++ b/migrations/versions/ac6b3c416b07_add_machine_type_table.py @@ -11,42 +11,53 @@ import citext # revision identifiers, used by Alembic. -revision = 'ac6b3c416b07' -down_revision = 'dfd4eae61224' +revision = "ac6b3c416b07" +down_revision = "dfd4eae61224" branch_labels = None depends_on = None def upgrade(): machine_type = op.create_table( - 'machine_type', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('name', citext.CIText(), nullable=False), - sa.PrimaryKeyConstraint('id', name=op.f('pk_machine_type')), - sa.UniqueConstraint('name', name=op.f('uq_machine_type_name')) + "machine_type", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", citext.CIText(), nullable=False), + sa.PrimaryKeyConstraint("id", name=op.f("pk_machine_type")), + sa.UniqueConstraint("name", name=op.f("uq_machine_type_name")), ) # WARNING! If the database is not empty, we can't set the machine_type_id to nullable=False before adding a value! - op.add_column('host', sa.Column('machine_type_id', sa.Integer(), nullable=True)) - op.create_foreign_key(op.f('fk_host_machine_type_id_machine_type'), 'host', 'machine_type', ['machine_type_id'], ['id']) + op.add_column("host", sa.Column("machine_type_id", sa.Integer(), nullable=True)) + op.create_foreign_key( + op.f("fk_host_machine_type_id_machine_type"), + "host", + "machine_type", + ["machine_type_id"], + ["id"], + ) # Create the Physical and Virtual machine types - op.execute(machine_type.insert().values([ - {'id': 1, 'name': 'Physical'}, - {'id': 2, 'name': 'Virtual'}, - ])) + op.execute( + machine_type.insert().values( + [{"id": 1, "name": "Physical"}, {"id": 2, "name": "Virtual"}] + ) + ) # Fill the host machine_type_id based on the value from the type column - host = sa.sql.table('host', sa.sql.column('machine_type_id'), sa.sql.column('type')) - op.execute(host.update().where(host.c.type == 'Physical').values(machine_type_id=1)) - op.execute(host.update().where(host.c.type == 'Virtual').values(machine_type_id=2)) - op.drop_column('host', 'type') + host = sa.sql.table("host", sa.sql.column("machine_type_id"), sa.sql.column("type")) + op.execute(host.update().where(host.c.type == "Physical").values(machine_type_id=1)) + op.execute(host.update().where(host.c.type == "Virtual").values(machine_type_id=2)) + op.drop_column("host", "type") # Add the nullable=False constraint - op.alter_column('host', 'machine_type_id', nullable=False) + op.alter_column("host", "machine_type_id", nullable=False) def downgrade(): - op.add_column('host', sa.Column('type', sa.TEXT(), autoincrement=False, nullable=True)) - host = sa.sql.table('host', sa.sql.column('machine_type_id'), sa.sql.column('type')) - op.execute(host.update().where(host.c.machine_type_id == 1).values(type='Physical')) - op.execute(host.update().where(host.c.machine_type_id == 
2).values(type='Virtual')) - op.drop_constraint(op.f('fk_host_machine_type_id_machine_type'), 'host', type_='foreignkey') - op.drop_column('host', 'machine_type_id') - op.drop_table('machine_type') + op.add_column( + "host", sa.Column("type", sa.TEXT(), autoincrement=False, nullable=True) + ) + host = sa.sql.table("host", sa.sql.column("machine_type_id"), sa.sql.column("type")) + op.execute(host.update().where(host.c.machine_type_id == 1).values(type="Physical")) + op.execute(host.update().where(host.c.machine_type_id == 2).values(type="Virtual")) + op.drop_constraint( + op.f("fk_host_machine_type_id_machine_type"), "host", type_="foreignkey" + ) + op.drop_column("host", "machine_type_id") + op.drop_table("machine_type") diff --git a/migrations/versions/c0b8036078e7_add_fields_to_task_table.py b/migrations/versions/c0b8036078e7_add_fields_to_task_table.py index 1ac53cb..f90a2eb 100644 --- a/migrations/versions/c0b8036078e7_add_fields_to_task_table.py +++ b/migrations/versions/c0b8036078e7_add_fields_to_task_table.py @@ -10,17 +10,17 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. -revision = 'c0b8036078e7' -down_revision = '7d0d580cdb1a' +revision = "c0b8036078e7" +down_revision = "7d0d580cdb1a" branch_labels = None depends_on = None def upgrade(): - op.add_column('task', sa.Column('awx_job_id', sa.Integer(), nullable=True)) - op.add_column('task', sa.Column('ended_at', sa.DateTime(), nullable=True)) + op.add_column("task", sa.Column("awx_job_id", sa.Integer(), nullable=True)) + op.add_column("task", sa.Column("ended_at", sa.DateTime(), nullable=True)) def downgrade(): - op.drop_column('task', 'ended_at') - op.drop_column('task', 'awx_job_id') + op.drop_column("task", "ended_at") + op.drop_column("task", "awx_job_id") diff --git a/migrations/versions/dfd4eae61224_add_domain_table.py b/migrations/versions/dfd4eae61224_add_domain_table.py index d85c102..58c8044 100644 --- a/migrations/versions/dfd4eae61224_add_domain_table.py +++ b/migrations/versions/dfd4eae61224_add_domain_table.py @@ -10,52 +10,75 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. -revision = 'dfd4eae61224' -down_revision = '713ca10255ab' +revision = "dfd4eae61224" +down_revision = "713ca10255ab" branch_labels = None depends_on = None def upgrade(): domain = op.create_table( - 'domain', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('created_at', sa.DateTime(), nullable=True), - sa.Column('updated_at', sa.DateTime(), nullable=True), - sa.Column('name', sa.Text(), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['user_id'], ['user_account.id'], name=op.f('fk_domain_user_id_user_account')), - sa.PrimaryKeyConstraint('id', name=op.f('pk_domain')), - sa.UniqueConstraint('name', name=op.f('uq_domain_name')) + "domain", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=True), + sa.Column("updated_at", sa.DateTime(), nullable=True), + sa.Column("name", sa.Text(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["user_account.id"], + name=op.f("fk_domain_user_id_user_account"), + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_domain")), + sa.UniqueConstraint("name", name=op.f("uq_domain_name")), ) # WARNING! If the database is not emppty, we can't set the domain_id to nullable=False before adding a value! 
- op.add_column('network', sa.Column('domain_id', sa.Integer(), nullable=True)) - op.create_foreign_key(op.f('fk_network_domain_id_domain'), 'network', 'domain', ['domain_id'], ['id']) - op.add_column('network_scope', sa.Column('domain_id', sa.Integer(), nullable=True)) - op.create_foreign_key(op.f('fk_network_scope_domain_id_domain'), 'network_scope', 'domain', ['domain_id'], ['id']) + op.add_column("network", sa.Column("domain_id", sa.Integer(), nullable=True)) + op.create_foreign_key( + op.f("fk_network_domain_id_domain"), "network", "domain", ["domain_id"], ["id"] + ) + op.add_column("network_scope", sa.Column("domain_id", sa.Integer(), nullable=True)) + op.create_foreign_key( + op.f("fk_network_scope_domain_id_domain"), + "network_scope", + "domain", + ["domain_id"], + ["id"], + ) # Try to get a user_id (required to create a domain) conn = op.get_bind() - res = conn.execute('SELECT id FROM user_account LIMIT 1') + res = conn.execute("SELECT id FROM user_account LIMIT 1") results = res.fetchall() # If no user was found, then the database is empty - no need to add a default value if results: user_id = results[0][0] # Create a default domain - op.execute(domain.insert().values(id=1, user_id=user_id, name='example.org', - created_at=sa.func.now(), updated_at=sa.func.now())) + op.execute( + domain.insert().values( + id=1, + user_id=user_id, + name="example.org", + created_at=sa.func.now(), + updated_at=sa.func.now(), + ) + ) # Add default domain_id value to network_scope and network - network_scope = sa.sql.table('network_scope', sa.sql.column('domain_id')) + network_scope = sa.sql.table("network_scope", sa.sql.column("domain_id")) op.execute(network_scope.update().values(domain_id=1)) - network = sa.sql.table('network', sa.sql.column('domain_id')) + network = sa.sql.table("network", sa.sql.column("domain_id")) op.execute(network.update().values(domain_id=1)) # Add the nullable=False constraint - op.alter_column('network', 'domain_id', nullable=False) - op.alter_column('network_scope', 'domain_id', nullable=False) + op.alter_column("network", "domain_id", nullable=False) + op.alter_column("network_scope", "domain_id", nullable=False) def downgrade(): - op.drop_constraint(op.f('fk_network_scope_domain_id_domain'), 'network_scope', type_='foreignkey') - op.drop_column('network_scope', 'domain_id') - op.drop_constraint(op.f('fk_network_domain_id_domain'), 'network', type_='foreignkey') - op.drop_column('network', 'domain_id') - op.drop_table('domain') + op.drop_constraint( + op.f("fk_network_scope_domain_id_domain"), "network_scope", type_="foreignkey" + ) + op.drop_column("network_scope", "domain_id") + op.drop_constraint( + op.f("fk_network_domain_id_domain"), "network", type_="foreignkey" + ) + op.drop_column("network", "domain_id") + op.drop_table("domain") diff --git a/migrations/versions/e07c7bc870be_rename_machine_type_to_device_type.py b/migrations/versions/e07c7bc870be_rename_machine_type_to_device_type.py index 59ece83..4933363 100644 --- a/migrations/versions/e07c7bc870be_rename_machine_type_to_device_type.py +++ b/migrations/versions/e07c7bc870be_rename_machine_type_to_device_type.py @@ -10,41 +10,59 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. 
-revision = 'e07c7bc870be' -down_revision = 'a73eeb144fa1' +revision = "e07c7bc870be" +down_revision = "a73eeb144fa1" branch_labels = None depends_on = None def upgrade(): # Rename the table and constraints - op.rename_table('machine_type', 'device_type') - op.execute('ALTER INDEX pk_machine_type RENAME TO pk_device_type') - op.execute('ALTER TABLE device_type RENAME CONSTRAINT uq_machine_type_name TO uq_device_type_name') + op.rename_table("machine_type", "device_type") + op.execute("ALTER INDEX pk_machine_type RENAME TO pk_device_type") + op.execute( + "ALTER TABLE device_type RENAME CONSTRAINT uq_machine_type_name TO uq_device_type_name" + ) # Rename the foreign key in the host table - op.alter_column('host', 'machine_type_id', new_column_name='device_type_id', existing_type=sa.Integer) - op.drop_constraint('fk_host_machine_type_id_machine_type', 'host', type_='foreignkey') + op.alter_column( + "host", + "machine_type_id", + new_column_name="device_type_id", + existing_type=sa.Integer, + ) + op.drop_constraint( + "fk_host_machine_type_id_machine_type", "host", type_="foreignkey" + ) op.create_foreign_key( - op.f('fk_host_device_type_id_device_type'), - 'host', - 'device_type', - ['device_type_id'], - ['id'] + op.f("fk_host_device_type_id_device_type"), + "host", + "device_type", + ["device_type_id"], + ["id"], ) def downgrade(): # Rename the table and constraints - op.rename_table('device_type', 'machine_type') - op.execute('ALTER INDEX pk_device_type RENAME TO pk_machine_type') - op.execute('ALTER TABLE machine_type RENAME CONSTRAINT uq_device_type_name TO uq_machine_type_name') + op.rename_table("device_type", "machine_type") + op.execute("ALTER INDEX pk_device_type RENAME TO pk_machine_type") + op.execute( + "ALTER TABLE machine_type RENAME CONSTRAINT uq_device_type_name TO uq_machine_type_name" + ) # Rename the foreign key in the host table - op.alter_column('host', 'device_type_id', new_column_name='machine_type_id', existing_type=sa.Integer) - op.drop_constraint('fk_host_device_type_id_machine_type', 'host', type_='foreignkey') + op.alter_column( + "host", + "device_type_id", + new_column_name="machine_type_id", + existing_type=sa.Integer, + ) + op.drop_constraint( + "fk_host_device_type_id_machine_type", "host", type_="foreignkey" + ) op.create_foreign_key( - 'fk_host_machine_type_id_machine_type', - 'host', - 'machine_type', - ['machine_type_id'], - ['id'] + "fk_host_machine_type_id_machine_type", + "host", + "machine_type", + ["machine_type_id"], + ["id"], ) diff --git a/migrations/versions/ea606be23b95_rename_physical_device_type.py b/migrations/versions/ea606be23b95_rename_physical_device_type.py index e5e0866..5ff3a42 100644 --- a/migrations/versions/ea606be23b95_rename_physical_device_type.py +++ b/migrations/versions/ea606be23b95_rename_physical_device_type.py @@ -10,17 +10,29 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. 
-revision = 'ea606be23b95' -down_revision = '8f135d5efde2' +revision = "ea606be23b95" +down_revision = "8f135d5efde2" branch_labels = None depends_on = None def upgrade(): - device_type = sa.sql.table('device_type', sa.sql.column('id'), sa.sql.column('name')) - op.execute(device_type.update().where(device_type.c.name == 'Physical').values(name='Physical Machine')) + device_type = sa.sql.table( + "device_type", sa.sql.column("id"), sa.sql.column("name") + ) + op.execute( + device_type.update() + .where(device_type.c.name == "Physical") + .values(name="Physical Machine") + ) def downgrade(): - device_type = sa.sql.table('device_type', sa.sql.column('id'), sa.sql.column('name')) - op.execute(device_type.update().where(device_type.c.name == 'Physical Machine').values(name='Physical')) + device_type = sa.sql.table( + "device_type", sa.sql.column("id"), sa.sql.column("name") + ) + op.execute( + device_type.update() + .where(device_type.c.name == "Physical Machine") + .values(name="Physical") + ) diff --git a/migrations/versions/f5a605c0c835_remove_spaces_from_device_type.py b/migrations/versions/f5a605c0c835_remove_spaces_from_device_type.py index 6373968..7093cd4 100644 --- a/migrations/versions/f5a605c0c835_remove_spaces_from_device_type.py +++ b/migrations/versions/f5a605c0c835_remove_spaces_from_device_type.py @@ -10,19 +10,39 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. -revision = 'f5a605c0c835' -down_revision = 'ea606be23b95' +revision = "f5a605c0c835" +down_revision = "ea606be23b95" branch_labels = None depends_on = None def upgrade(): - device_type = sa.sql.table('device_type', sa.sql.column('id'), sa.sql.column('name')) - op.execute(device_type.update().where(device_type.c.name == 'Physical Machine').values(name='PhysicalMachine')) - op.execute(device_type.update().where(device_type.c.name == 'Virtual Machine').values(name='VirtualMachine')) + device_type = sa.sql.table( + "device_type", sa.sql.column("id"), sa.sql.column("name") + ) + op.execute( + device_type.update() + .where(device_type.c.name == "Physical Machine") + .values(name="PhysicalMachine") + ) + op.execute( + device_type.update() + .where(device_type.c.name == "Virtual Machine") + .values(name="VirtualMachine") + ) def downgrade(): - device_type = sa.sql.table('device_type', sa.sql.column('id'), sa.sql.column('name')) - op.execute(device_type.update().where(device_type.c.name == 'PhysicalMachine').values(name='Physical Machine')) - op.execute(device_type.update().where(device_type.c.name == 'VirtualMachine').values(name='Virtual Machine')) + device_type = sa.sql.table( + "device_type", sa.sql.column("id"), sa.sql.column("name") + ) + op.execute( + device_type.update() + .where(device_type.c.name == "PhysicalMachine") + .values(name="Physical Machine") + ) + op.execute( + device_type.update() + .where(device_type.c.name == "VirtualMachine") + .values(name="Virtual Machine") + ) diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py index c0d6e2c..b9eb657 100644 --- a/tests/functional/conftest.py +++ b/tests/functional/conftest.py @@ -35,17 +35,17 @@ register(factories.CnameFactory) register(factories.TagFactory) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def app(request): """Session-wide test `Flask` application.""" config = { - 'TESTING': True, - 'WTF_CSRF_ENABLED': False, - 'SQLALCHEMY_DATABASE_URI': 'postgresql://ics:icspwd@postgres/csentry_db_test', - 'CSENTRY_LDAP_GROUPS': { - 'admin': ['CSEntry Admin'], - 'create': ['CSEntry User', 'CSEntry 
Consultant'], - } + "TESTING": True, + "WTF_CSRF_ENABLED": False, + "SQLALCHEMY_DATABASE_URI": "postgresql://ics:icspwd@postgres/csentry_db_test", + "CSENTRY_LDAP_GROUPS": { + "admin": ["CSEntry Admin"], + "create": ["CSEntry User", "CSEntry Consultant"], + }, } app = create_app(config=config) ctx = app.app_context() @@ -63,15 +63,16 @@ def client(request, app): return app.test_client() -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def db(app, request): """Session-wide test database.""" + def teardown(): _db.session.remove() _db.drop_all() _db.app = app - _db.engine.execute('CREATE EXTENSION IF NOT EXISTS citext') + _db.engine.execute("CREATE EXTENSION IF NOT EXISTS citext") _db.create_all() request.addfinalizer(teardown) @@ -95,7 +96,7 @@ def session(db, request): # session is actually a scoped_session # for the `after_transaction_end` event, we need a session instance to # listen for, hence the `session()` call - @sa.event.listens_for(session(), 'after_transaction_end') + @sa.event.listens_for(session(), "after_transaction_end") def resetart_savepoint(sess, trans): if trans.nested and not trans._parent.nested: session.expire_all() @@ -113,34 +114,35 @@ def session(db, request): @pytest.fixture(autouse=True) def no_ldap_connection(monkeypatch): """Make sure we don't make any connection to the LDAP server""" - monkeypatch.delattr('flask_ldap3_login.LDAP3LoginManager._make_connection') + monkeypatch.delattr("flask_ldap3_login.LDAP3LoginManager._make_connection") @pytest.fixture(autouse=True) def patch_ldap_authenticate(monkeypatch): - def authenticate(self, username, password): response = AuthenticationResponse() response.user_id = username - response.user_dn = f'cn={username},dc=esss,dc=lu,dc=se' - if username == 'admin' and password == 'adminpasswd': + response.user_dn = f"cn={username},dc=esss,dc=lu,dc=se" + if username == "admin" and password == "adminpasswd": response.status = AuthenticationResponseStatus.success - response.user_info = {'cn': 'Admin User', 'mail': 'admin@example.com'} - response.user_groups = [{'cn': 'CSEntry Admin'}] - elif username == 'user_rw' and password == 'userrw': + response.user_info = {"cn": "Admin User", "mail": "admin@example.com"} + response.user_groups = [{"cn": "CSEntry Admin"}] + elif username == "user_rw" and password == "userrw": response.status = AuthenticationResponseStatus.success - response.user_info = {'cn': 'User RW', 'mail': 'user_rw@example.com'} - response.user_groups = [{'cn': 'CSEntry User'}] - elif username == 'consultant' and password == 'consultantpwd': + response.user_info = {"cn": "User RW", "mail": "user_rw@example.com"} + response.user_groups = [{"cn": "CSEntry User"}] + elif username == "consultant" and password == "consultantpwd": response.status = AuthenticationResponseStatus.success - response.user_info = {'cn': 'Consultant', 'mail': 'consultant@example.com'} - response.user_groups = [{'cn': 'CSEntry Consultant'}] - elif username == 'user_ro' and password == 'userro': + response.user_info = {"cn": "Consultant", "mail": "consultant@example.com"} + response.user_groups = [{"cn": "CSEntry Consultant"}] + elif username == "user_ro" and password == "userro": response.status = AuthenticationResponseStatus.success - response.user_info = {'cn': 'User RO', 'mail': 'user_ro@example.com'} - response.user_groups = [{'cn': 'ESS Employees'}] + response.user_info = {"cn": "User RO", "mail": "user_ro@example.com"} + response.user_groups = [{"cn": "ESS Employees"}] else: response.status = AuthenticationResponseStatus.fail 
return response - monkeypatch.setattr('flask_ldap3_login.LDAP3LoginManager.authenticate', authenticate) + monkeypatch.setattr( + "flask_ldap3_login.LDAP3LoginManager.authenticate", authenticate + ) diff --git a/tests/functional/factories.py b/tests/functional/factories.py index 50de310..571d071 100644 --- a/tests/functional/factories.py +++ b/tests/functional/factories.py @@ -23,65 +23,65 @@ class UserFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: model = models.User sqlalchemy_session = common.Session - sqlalchemy_session_persistence = 'commit' + sqlalchemy_session_persistence = "commit" - username = factory.Sequence(lambda n: f'username{n}') - display_name = factory.LazyAttribute(lambda o: f'long {o.username}') + username = factory.Sequence(lambda n: f"username{n}") + display_name = factory.LazyAttribute(lambda o: f"long {o.username}") class ActionFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: model = models.Action sqlalchemy_session = common.Session - sqlalchemy_session_persistence = 'commit' + sqlalchemy_session_persistence = "commit" - name = factory.Sequence(lambda n: f'action{n}') + name = factory.Sequence(lambda n: f"action{n}") class ManufacturerFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: model = models.Manufacturer sqlalchemy_session = common.Session - sqlalchemy_session_persistence = 'commit' + sqlalchemy_session_persistence = "commit" - name = factory.Sequence(lambda n: f'manufacturer{n}') + name = factory.Sequence(lambda n: f"manufacturer{n}") class ModelFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: model = models.Model sqlalchemy_session = common.Session - sqlalchemy_session_persistence = 'commit' + sqlalchemy_session_persistence = "commit" - name = factory.Sequence(lambda n: f'model{n}') + name = factory.Sequence(lambda n: f"model{n}") class LocationFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: model = models.Location sqlalchemy_session = common.Session - sqlalchemy_session_persistence = 'commit' + sqlalchemy_session_persistence = "commit" - name = factory.Sequence(lambda n: f'location{n}') + name = factory.Sequence(lambda n: f"location{n}") class StatusFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: model = models.Status sqlalchemy_session = common.Session - sqlalchemy_session_persistence = 'commit' + sqlalchemy_session_persistence = "commit" - name = factory.Sequence(lambda n: f'status{n}') + name = factory.Sequence(lambda n: f"status{n}") class ItemFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: model = models.Item sqlalchemy_session = common.Session - sqlalchemy_session_persistence = 'commit' + sqlalchemy_session_persistence = "commit" - ics_id = factory.Sequence(lambda n: f'AAA{n:03}') - serial_number = factory.Faker('isbn10') + ics_id = factory.Sequence(lambda n: f"AAA{n:03}") + serial_number = factory.Faker("isbn10") manufacturer = factory.SubFactory(ManufacturerFactory) model = factory.SubFactory(ModelFactory) location = factory.SubFactory(LocationFactory) @@ -93,22 +93,22 @@ class DomainFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: model = models.Domain sqlalchemy_session = common.Session - sqlalchemy_session_persistence = 'commit' + sqlalchemy_session_persistence = "commit" user = factory.SubFactory(UserFactory) - name = factory.Sequence(lambda n: f'domain{n}.example.org') + name = factory.Sequence(lambda n: f"domain{n}.example.org") class NetworkScopeFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: model = models.NetworkScope 
sqlalchemy_session = common.Session - sqlalchemy_session_persistence = 'commit' + sqlalchemy_session_persistence = "commit" - name = factory.Sequence(lambda n: f'scope{n}') + name = factory.Sequence(lambda n: f"scope{n}") first_vlan = factory.Sequence(lambda n: 1600 + 10 * n) last_vlan = factory.Sequence(lambda n: 1609 + 10 * n) - supernet = factory.Faker('ipv4', network=True) + supernet = factory.Faker("ipv4", network=True) user = factory.SubFactory(UserFactory) domain = factory.SubFactory(DomainFactory) @@ -117,11 +117,11 @@ class NetworkFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: model = models.Network sqlalchemy_session = common.Session - sqlalchemy_session_persistence = 'commit' + sqlalchemy_session_persistence = "commit" - vlan_name = factory.Sequence(lambda n: f'vlan{n}') + vlan_name = factory.Sequence(lambda n: f"vlan{n}") vlan_id = factory.Sequence(lambda n: 1600 + n) - address = factory.Sequence(lambda n: f'192.168.{n}.0/24') + address = factory.Sequence(lambda n: f"192.168.{n}.0/24") scope = factory.SubFactory(NetworkScopeFactory) user = factory.SubFactory(UserFactory) domain = factory.SubFactory(DomainFactory) @@ -143,11 +143,13 @@ class InterfaceFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: model = models.Interface sqlalchemy_session = common.Session - sqlalchemy_session_persistence = 'commit' + sqlalchemy_session_persistence = "commit" - name = factory.Sequence(lambda n: f'host{n}') + name = factory.Sequence(lambda n: f"host{n}") network = factory.SubFactory(NetworkFactory) - ip = factory.LazyAttributeSequence(lambda o, n: str(ipaddress.ip_address(o.network.first_ip) + n)) + ip = factory.LazyAttributeSequence( + lambda o, n: str(ipaddress.ip_address(o.network.first_ip) + n) + ) user = factory.SubFactory(UserFactory) @@ -155,18 +157,18 @@ class DeviceTypeFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: model = models.DeviceType sqlalchemy_session = common.Session - sqlalchemy_session_persistence = 'commit' + sqlalchemy_session_persistence = "commit" - name = factory.Sequence(lambda n: f'Type{n}') + name = factory.Sequence(lambda n: f"Type{n}") class HostFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: model = models.Host sqlalchemy_session = common.Session - sqlalchemy_session_persistence = 'commit' + sqlalchemy_session_persistence = "commit" - name = factory.Sequence(lambda n: f'host{n}') + name = factory.Sequence(lambda n: f"host{n}") user = factory.SubFactory(UserFactory) device_type = factory.SubFactory(DeviceTypeFactory) @@ -175,18 +177,18 @@ class MacFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: model = models.Mac sqlalchemy_session = common.Session - sqlalchemy_session_persistence = 'commit' + sqlalchemy_session_persistence = "commit" - address = factory.Faker('mac_address') + address = factory.Faker("mac_address") class CnameFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: model = models.Cname sqlalchemy_session = common.Session - sqlalchemy_session_persistence = 'commit' + sqlalchemy_session_persistence = "commit" - name = factory.Sequence(lambda n: f'host{n}') + name = factory.Sequence(lambda n: f"host{n}") interface = factory.SubFactory(InterfaceFactory) user = factory.SubFactory(UserFactory) @@ -195,6 +197,6 @@ class TagFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: model = models.Tag sqlalchemy_session = common.Session - sqlalchemy_session_persistence = 'commit' + sqlalchemy_session_persistence = "commit" - name = factory.Sequence(lambda n: f'Tag{n}') + name = 
factory.Sequence(lambda n: f"Tag{n}") diff --git a/tests/functional/test_api.py b/tests/functional/test_api.py index 3fe122d..af30acc 100644 --- a/tests/functional/test_api.py +++ b/tests/functional/test_api.py @@ -15,90 +15,98 @@ import pytest from app import models -API_URL = '/api/v1' +API_URL = "/api/v1" ENDPOINT_MODEL = { - 'inventory/actions': models.Action, - 'inventory/manufacturers': models.Manufacturer, - 'inventory/models': models.Model, - 'inventory/locations': models.Location, - 'inventory/statuses': models.Status, - 'inventory/items': models.Item, - 'network/networks': models.Network, - 'network/interfaces': models.Interface, - 'network/hosts': models.Host, - 'network/macs': models.Mac, - 'network/domains': models.Domain, - 'network/cnames': models.Cname, + "inventory/actions": models.Action, + "inventory/manufacturers": models.Manufacturer, + "inventory/models": models.Model, + "inventory/locations": models.Location, + "inventory/statuses": models.Status, + "inventory/items": models.Item, + "network/networks": models.Network, + "network/interfaces": models.Interface, + "network/hosts": models.Host, + "network/macs": models.Mac, + "network/domains": models.Domain, + "network/cnames": models.Cname, } -GENERIC_GET_ENDPOINTS = [key for key in ENDPOINT_MODEL.keys() - if key.startswith('inventory') and key != 'inventory/items'] -GENERIC_CREATE_ENDPOINTS = [key for key in ENDPOINT_MODEL.keys() - if key.startswith('inventory') and key not in ('inventory/items', 'inventory/actions')] -CREATE_AUTH_ENDPOINTS = [key for key in ENDPOINT_MODEL.keys() if key != 'inventory/actions'] +GENERIC_GET_ENDPOINTS = [ + key + for key in ENDPOINT_MODEL.keys() + if key.startswith("inventory") and key != "inventory/items" +] +GENERIC_CREATE_ENDPOINTS = [ + key + for key in ENDPOINT_MODEL.keys() + if key.startswith("inventory") + and key not in ("inventory/items", "inventory/actions") +] +CREATE_AUTH_ENDPOINTS = [ + key for key in ENDPOINT_MODEL.keys() if key != "inventory/actions" +] def get(client, url, token=None): response = client.get( url, - headers={'Content-Type': 'application/json', - 'Authorization': f'Bearer {token}'}, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {token}", + }, ) - if response.headers['Content-Type'] == 'application/json': + if response.headers["Content-Type"] == "application/json": response.json = json.loads(response.data) return response def post(client, url, data, token=None): - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} if token is not None: - headers['Authorization'] = f'Bearer {token}' + headers["Authorization"] = f"Bearer {token}" response = client.post(url, data=json.dumps(data), headers=headers) - if response.headers['Content-Type'] == 'application/json': + if response.headers["Content-Type"] == "application/json": response.json = json.loads(response.data) return response def patch(client, url, data, token=None): - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} if token is not None: - headers['Authorization'] = f'Bearer {token}' + headers["Authorization"] = f"Bearer {token}" response = client.patch(url, data=json.dumps(data), headers=headers) - if response.headers['Content-Type'] == 'application/json': + if response.headers["Content-Type"] == "application/json": response.json = json.loads(response.data) return response def login(client, username, password): - data = { - 'username': username, - 'password': password - } - return 
post(client, f'{API_URL}/user/login', data) + data = {"username": username, "password": password} + return post(client, f"{API_URL}/user/login", data) def get_token(client, username, password): response = login(client, username, password) - return response.json['access_token'] + return response.json["access_token"] @pytest.fixture() def readonly_token(client): - return get_token(client, 'user_ro', 'userro') + return get_token(client, "user_ro", "userro") @pytest.fixture() def user_token(client): - return get_token(client, 'user_rw', 'userrw') + return get_token(client, "user_rw", "userrw") @pytest.fixture() def consultant_token(client): - return get_token(client, 'consultant', 'consultantpwd') + return get_token(client, "consultant", "consultantpwd") @pytest.fixture() def admin_token(client): - return get_token(client, 'admin', 'adminpasswd') + return get_token(client, "admin", "adminpasswd") def check_response_message(response, msg, status_code=400): @@ -108,344 +116,479 @@ def check_response_message(response, msg, status_code=400): except AttributeError: data = json.loads(response.data) try: - message = data['message'] + message = data["message"] except KeyError: # flask-jwt-extended is using "msg" instead of "message" # in its default callbacks - message = data['msg'] + message = data["msg"] assert message.startswith(msg) def check_names(response, names): - response_names = set(item['name'] for item in response.json) + response_names = set(item["name"] for item in response.json) assert set(names) == response_names def check_input_is_subset_of_response(response, inputs): # Sort the response by id to match the inputs order - response_elts = sorted(response.json, key=lambda d: d['id']) + response_elts = sorted(response.json, key=lambda d: d["id"]) for d1, d2 in zip(inputs, response_elts): for key, value in d1.items(): if isinstance(value, datetime.datetime): - value = value.strftime('%Y-%m-%d %H:%M') + value = value.strftime("%Y-%m-%d %H:%M") assert d2[key] == value def test_login(client): - response = client.post(f'{API_URL}/user/login') - check_response_message(response, 'Body should be a JSON object') - response = post(client, f'{API_URL}/user/login', data={'username': 'foo', 'passwd': ''}) - check_response_message(response, 'Missing mandatory field (username or password)', 422) - response = login(client, 'foo', 'invalid') - check_response_message(response, 'Invalid credentials', 401) - response = login(client, 'user_ro', 'userro') + response = client.post(f"{API_URL}/user/login") + check_response_message(response, "Body should be a JSON object") + response = post( + client, f"{API_URL}/user/login", data={"username": "foo", "passwd": ""} + ) + check_response_message( + response, "Missing mandatory field (username or password)", 422 + ) + response = login(client, "foo", "invalid") + check_response_message(response, "Invalid credentials", 401) + response = login(client, "user_ro", "userro") assert response.status_code == 200 - assert 'access_token' in response.json + assert "access_token" in response.json -@pytest.mark.parametrize('endpoint', GENERIC_GET_ENDPOINTS) +@pytest.mark.parametrize("endpoint", GENERIC_GET_ENDPOINTS) def test_get_generic_model(endpoint, session, client, readonly_token): model = ENDPOINT_MODEL[endpoint] - names = ('Foo', 'Bar', 'Alice') + names = ("Foo", "Bar", "Alice") for name in names: session.add(model(name=name)) session.commit() - response = client.get(f'{API_URL}/{endpoint}') - check_response_message(response, 'Missing Authorization Header', 401) - response 
= get(client, f'{API_URL}/{endpoint}', 'xxxxxxxxx') - check_response_message(response, 'Not enough segments', 422) - response = get(client, f'{API_URL}/{endpoint}', readonly_token) + response = client.get(f"{API_URL}/{endpoint}") + check_response_message(response, "Missing Authorization Header", 401) + response = get(client, f"{API_URL}/{endpoint}", "xxxxxxxxx") + check_response_message(response, "Not enough segments", 422) + response = get(client, f"{API_URL}/{endpoint}", readonly_token) check_names(response, names) - response = get(client, f'{API_URL}/{endpoint}', readonly_token) + response = get(client, f"{API_URL}/{endpoint}", readonly_token) check_names(response, names) for item in response.json: - assert 'qrcode' in item + assert "qrcode" in item -@pytest.mark.parametrize('endpoint', CREATE_AUTH_ENDPOINTS) +@pytest.mark.parametrize("endpoint", CREATE_AUTH_ENDPOINTS) def test_create_model_auth_fail(endpoint, client, readonly_token): - response = client.post(f'{API_URL}/{endpoint}') - check_response_message(response, 'Missing Authorization Header', 401) - response = post(client, f'{API_URL}/{endpoint}', data={}, token='xxxxxxxxx') - check_response_message(response, 'Not enough segments', 422) - response = post(client, f'{API_URL}/{endpoint}', data={}, token=readonly_token) + response = client.post(f"{API_URL}/{endpoint}") + check_response_message(response, "Missing Authorization Header", 401) + response = post(client, f"{API_URL}/{endpoint}", data={}, token="xxxxxxxxx") + check_response_message(response, "Not enough segments", 422) + response = post(client, f"{API_URL}/{endpoint}", data={}, token=readonly_token) check_response_message(response, "User doesn't have the required group", 403) model = ENDPOINT_MODEL[endpoint] assert model.query.count() == 0 -@pytest.mark.parametrize('endpoint', GENERIC_CREATE_ENDPOINTS) +@pytest.mark.parametrize("endpoint", GENERIC_CREATE_ENDPOINTS) def test_create_generic_model(endpoint, client, user_token): - response = post(client, f'{API_URL}/{endpoint}', data={}, token=user_token) + response = post(client, f"{API_URL}/{endpoint}", data={}, token=user_token) check_response_message(response, "Missing mandatory field 'name'", 422) - data = {'name': 'Foo'} - response = post(client, f'{API_URL}/{endpoint}', data=data, token=user_token) + data = {"name": "Foo"} + response = post(client, f"{API_URL}/{endpoint}", data=data, token=user_token) assert response.status_code == 201 - assert {'id', 'name'} <= set(response.json.keys()) - assert response.json['name'] == 'Foo' - response = post(client, f'{API_URL}/{endpoint}', data=data, token=user_token) - check_response_message(response, '(psycopg2.IntegrityError) duplicate key value violates unique constraint', 422) - response = post(client, f'{API_URL}/{endpoint}', data={'name': 'foo'}, token=user_token) - check_response_message(response, '(psycopg2.IntegrityError) duplicate key value violates unique constraint', 422) - response = post(client, f'{API_URL}/{endpoint}', data={'name': 'FOO'}, token=user_token) - check_response_message(response, '(psycopg2.IntegrityError) duplicate key value violates unique constraint', 422) - data = {'name': 'Bar', 'description': 'Bar description'} - response = post(client, f'{API_URL}/{endpoint}', data=data, token=user_token) + assert {"id", "name"} <= set(response.json.keys()) + assert response.json["name"] == "Foo" + response = post(client, f"{API_URL}/{endpoint}", data=data, token=user_token) + check_response_message( + response, + "(psycopg2.IntegrityError) duplicate key value 
violates unique constraint", + 422, + ) + response = post( + client, f"{API_URL}/{endpoint}", data={"name": "foo"}, token=user_token + ) + check_response_message( + response, + "(psycopg2.IntegrityError) duplicate key value violates unique constraint", + 422, + ) + response = post( + client, f"{API_URL}/{endpoint}", data={"name": "FOO"}, token=user_token + ) + check_response_message( + response, + "(psycopg2.IntegrityError) duplicate key value violates unique constraint", + 422, + ) + data = {"name": "Bar", "description": "Bar description"} + response = post(client, f"{API_URL}/{endpoint}", data=data, token=user_token) assert response.status_code == 201 - assert response.json['description'] == 'Bar description' + assert response.json["description"] == "Bar description" model = ENDPOINT_MODEL[endpoint] assert model.query.count() == 2 - response = get(client, f'{API_URL}/{endpoint}', user_token) - check_names(response, ('Foo', 'Bar')) + response = get(client, f"{API_URL}/{endpoint}", user_token) + check_names(response, ("Foo", "Bar")) -@pytest.mark.parametrize('endpoint', GENERIC_CREATE_ENDPOINTS) +@pytest.mark.parametrize("endpoint", GENERIC_CREATE_ENDPOINTS) def test_create_generic_model_invalid_param(endpoint, client, user_token): model = ENDPOINT_MODEL[endpoint] - response = post(client, f'{API_URL}/{endpoint}', data={'name': 'foo', 'hello': 'world'}, token=user_token) - check_response_message(response, f"'hello' is an invalid keyword argument for {model.__name__}", 422) + response = post( + client, + f"{API_URL}/{endpoint}", + data={"name": "foo", "hello": "world"}, + token=user_token, + ) + check_response_message( + response, f"'hello' is an invalid keyword argument for {model.__name__}", 422 + ) def test_create_item(client, user_token): # check that serial_number is mandatory - response = post(client, f'{API_URL}/inventory/items', data={}, token=user_token) + response = post(client, f"{API_URL}/inventory/items", data={}, token=user_token) check_response_message(response, "Missing mandatory field 'serial_number'", 422) # check create with only serial_number - data = {'serial_number': '123456'} - response = post(client, f'{API_URL}/inventory/items', data=data, token=user_token) + data = {"serial_number": "123456"} + response = post(client, f"{API_URL}/inventory/items", data=data, token=user_token) assert response.status_code == 201 - assert {'id', 'ics_id', 'serial_number', 'manufacturer', 'model', 'quantity', - 'location', 'status', 'parent', 'children', 'macs', 'history', 'host', - 'stack_member', 'updated_at', 'created_at', 'user', 'comments'} == set(response.json.keys()) - assert response.json['serial_number'] == '123456' + assert { + "id", + "ics_id", + "serial_number", + "manufacturer", + "model", + "quantity", + "location", + "status", + "parent", + "children", + "macs", + "history", + "host", + "stack_member", + "updated_at", + "created_at", + "user", + "comments", + } == set(response.json.keys()) + assert response.json["serial_number"] == "123456" # Check that serial_number doesn't have to be unique - response = post(client, f'{API_URL}/inventory/items', data=data, token=user_token) + response = post(client, f"{API_URL}/inventory/items", data=data, token=user_token) assert response.status_code == 201 # check that ics_id shall be unique - data2 = {'serial_number': '456789', 'ics_id': 'AAA001'} - response = post(client, f'{API_URL}/inventory/items', data=data2, token=user_token) + data2 = {"serial_number": "456789", "ics_id": "AAA001"} + response = post(client, 
f"{API_URL}/inventory/items", data=data2, token=user_token) assert response.status_code == 201 - response = post(client, f'{API_URL}/inventory/items', data=data2, token=user_token) - check_response_message(response, '(psycopg2.IntegrityError) duplicate key value violates unique constraint', 422) + response = post(client, f"{API_URL}/inventory/items", data=data2, token=user_token) + check_response_message( + response, + "(psycopg2.IntegrityError) duplicate key value violates unique constraint", + 422, + ) # check all items that were created assert models.Item.query.count() == 3 - response = get(client, f'{API_URL}/inventory/items', user_token) + response = get(client, f"{API_URL}/inventory/items", user_token) check_input_is_subset_of_response(response, (data, data, data2)) def test_create_item_with_host_id(client, host_factory, user_token): host = host_factory() # Check that we can pass an host_id - data = {'serial_number': '123456', - 'host_id': host.id} - response = post(client, f'{API_URL}/inventory/items', data=data, token=user_token) + data = {"serial_number": "123456", "host_id": host.id} + response = post(client, f"{API_URL}/inventory/items", data=data, token=user_token) assert response.status_code == 201 - item = models.Item.query.filter_by(serial_number=data['serial_number']).first() + item = models.Item.query.filter_by(serial_number=data["serial_number"]).first() assert item.host_id == host.id def test_create_item_invalid_ics_id(client, user_token): - for ics_id in ('foo', 'AAB1234', 'AZ02', 'WS007', 'AAA01'): - data = {'serial_number': '123456', 'ics_id': ics_id} - response = post(client, f'{API_URL}/inventory/items', data=data, token=user_token) - check_response_message(response, 'ICS id shall match [A-Z]{3}[0-9]{3}', 422) + for ics_id in ("foo", "AAB1234", "AZ02", "WS007", "AAA01"): + data = {"serial_number": "123456", "ics_id": ics_id} + response = post( + client, f"{API_URL}/inventory/items", data=data, token=user_token + ) + check_response_message(response, "ICS id shall match [A-Z]{3}[0-9]{3}", 422) def test_get_item_fail(client, session, readonly_token): - response = get(client, f'{API_URL}/inventory/items/50', token=readonly_token) + response = get(client, f"{API_URL}/inventory/items/50", token=readonly_token) check_response_message(response, "Item id '50' not found", 404) - response = get(client, f'{API_URL}/inventory/items/bar', token=readonly_token) + response = get(client, f"{API_URL}/inventory/items/bar", token=readonly_token) check_response_message(response, "Item id 'bar' not found", 404) def test_get_item(client, status_factory, item_factory, readonly_token): # Create some items - status_factory(name='Stock') - item1 = item_factory(serial_number='123456') - item2 = item_factory(serial_number='234567', ics_id='AAA001', status='Stock') + status_factory(name="Stock") + item1 = item_factory(serial_number="123456") + item2 = item_factory(serial_number="234567", ics_id="AAA001", status="Stock") # we can get items by id... 
- response = get(client, f'{API_URL}/inventory/items/{item1.id}', token=readonly_token) + response = get( + client, f"{API_URL}/inventory/items/{item1.id}", token=readonly_token + ) assert response.status_code == 200 - assert response.json['id'] == item1.id - assert response.json['serial_number'] == item1.serial_number + assert response.json["id"] == item1.id + assert response.json["serial_number"] == item1.serial_number # ...or ics_id - response = get(client, f'{API_URL}/inventory/items/{item2.ics_id}', token=readonly_token) + response = get( + client, f"{API_URL}/inventory/items/{item2.ics_id}", token=readonly_token + ) assert response.status_code == 200 - assert response.json['id'] == item2.id - assert response.json['ics_id'] == item2.ics_id - assert response.json['serial_number'] == item2.serial_number - assert response.json['status'] == str(item2.status) + assert response.json["id"] == item2.id + assert response.json["ics_id"] == item2.ics_id + assert response.json["serial_number"] == item2.serial_number + assert response.json["status"] == str(item2.status) def test_patch_item_auth_fail(client, session, readonly_token): - response = client.patch(f'{API_URL}/inventory/items/50') - check_response_message(response, 'Missing Authorization Header', 401) - response = patch(client, f'{API_URL}/inventory/items/50', data={}, token='xxxxxxxxx') - check_response_message(response, 'Not enough segments', 422) - response = patch(client, f'{API_URL}/inventory/items/50', data={}, token=readonly_token) + response = client.patch(f"{API_URL}/inventory/items/50") + check_response_message(response, "Missing Authorization Header", 401) + response = patch( + client, f"{API_URL}/inventory/items/50", data={}, token="xxxxxxxxx" + ) + check_response_message(response, "Not enough segments", 422) + response = patch( + client, f"{API_URL}/inventory/items/50", data={}, token=readonly_token + ) check_response_message(response, "User doesn't have the required group", 403) def test_patch_item_fail(client, item_factory, user_token): - response = patch(client, f'{API_URL}/inventory/items/50', data={}, token=user_token) - check_response_message(response, 'At least one field is required', 422) - data = {'location': 'ESS', 'foo': 'bar'} - response = patch(client, f'{API_URL}/inventory/items/50', data=data, token=user_token) + response = patch(client, f"{API_URL}/inventory/items/50", data={}, token=user_token) + check_response_message(response, "At least one field is required", 422) + data = {"location": "ESS", "foo": "bar"} + response = patch( + client, f"{API_URL}/inventory/items/50", data=data, token=user_token + ) check_response_message(response, "Invalid field 'foo'", 422) - data = {'location': 'ESS'} - response = patch(client, f'{API_URL}/inventory/items/50', data=data, token=user_token) + data = {"location": "ESS"} + response = patch( + client, f"{API_URL}/inventory/items/50", data=data, token=user_token + ) check_response_message(response, "Item id '50' not found", 404) - response = patch(client, f'{API_URL}/inventory/items/bar', data=data, token=user_token) + response = patch( + client, f"{API_URL}/inventory/items/bar", data=data, token=user_token + ) check_response_message(response, "Item id 'bar' not found", 404) # Create an item - item1 = item_factory(serial_number='234567', ics_id='AAA001') + item1 = item_factory(serial_number="234567", ics_id="AAA001") # check that we can't change the serial_number or ics_id - response = patch(client, f'{API_URL}/inventory/items/{item1.id}', data={'serial_number': '12345'}, 
token=user_token) + response = patch( + client, + f"{API_URL}/inventory/items/{item1.id}", + data={"serial_number": "12345"}, + token=user_token, + ) check_response_message(response, "Invalid field 'serial_number'", 422) - response = patch(client, f'{API_URL}/inventory/items/{item1.id}', data={'ics_id': 'AAA002'}, token=user_token) + response = patch( + client, + f"{API_URL}/inventory/items/{item1.id}", + data={"ics_id": "AAA002"}, + token=user_token, + ) check_response_message(response, "'ics_id' can't be changed", 422) def test_patch_item(client, status, item_factory, user_token): # Create some items - item1 = item_factory(ics_id='ZZZ001') + item1 = item_factory(ics_id="ZZZ001") item2 = item_factory() # we can patch items by id... - data = {'ics_id': 'AAB004'} - response = patch(client, f'{API_URL}/inventory/items/{item1.id}', data=data, token=user_token) + data = {"ics_id": "AAB004"} + response = patch( + client, f"{API_URL}/inventory/items/{item1.id}", data=data, token=user_token + ) assert response.status_code == 200 - assert response.json['id'] == item1.id - assert response.json['serial_number'] == item1.serial_number - assert response.json['ics_id'] == data['ics_id'] + assert response.json["id"] == item1.id + assert response.json["serial_number"] == item1.serial_number + assert response.json["ics_id"] == data["ics_id"] # ...or ics_id - data = {'status': status.name} - response = patch(client, f'{API_URL}/inventory/items/{item2.ics_id}', data=data, token=user_token) + data = {"status": status.name} + response = patch( + client, f"{API_URL}/inventory/items/{item2.ics_id}", data=data, token=user_token + ) assert response.status_code == 200 - assert response.json['id'] == item2.id - assert response.json['ics_id'] == item2.ics_id - assert response.json['serial_number'] == item2.serial_number - assert response.json['status'] == data['status'] + assert response.json["id"] == item2.id + assert response.json["ics_id"] == item2.ics_id + assert response.json["serial_number"] == item2.serial_number + assert response.json["status"] == data["status"] def test_patch_item_integrity_error(client, user_token, item_factory): # Create some items item1 = item_factory() - item2 = item_factory(ics_id='ZZZ001') - data = {'ics_id': item1.ics_id} - response = patch(client, f'{API_URL}/inventory/items/{item2.id}', data=data, token=user_token) - check_response_message(response, '(psycopg2.IntegrityError) duplicate key value violates unique constraint', 422) + item2 = item_factory(ics_id="ZZZ001") + data = {"ics_id": item1.ics_id} + response = patch( + client, f"{API_URL}/inventory/items/{item2.id}", data=data, token=user_token + ) + check_response_message( + response, + "(psycopg2.IntegrityError) duplicate key value violates unique constraint", + 422, + ) -def test_patch_item_parent(client, location_factory, manufacturer_factory, - status_factory, item_factory, user_token): +def test_patch_item_parent( + client, + location_factory, + manufacturer_factory, + status_factory, + item_factory, + user_token, +): # Create some items - location_factory(name='ESS') - manufacturer_factory(name='HP') - status_factory(name='In service') - status_factory(name='Stock') - item1 = item_factory(ics_id='AAA001', status='In service') - item2 = item_factory(ics_id='AAA002') - item3 = item_factory(ics_id='AAA003') + location_factory(name="ESS") + manufacturer_factory(name="HP") + status_factory(name="In service") + status_factory(name="Stock") + item1 = item_factory(ics_id="AAA001", status="In service") + item2 = 
item_factory(ics_id="AAA002") + item3 = item_factory(ics_id="AAA003") item3.parent_id = item1.id # set parent changes the status and location - data1 = {'parent': item1.ics_id} - response = patch(client, f'{API_URL}/inventory/items/{item2.ics_id}', data=data1, token=user_token) + data1 = {"parent": item1.ics_id} + response = patch( + client, + f"{API_URL}/inventory/items/{item2.ics_id}", + data=data1, + token=user_token, + ) assert response.status_code == 200 - assert response.json['id'] == item2.id - assert response.json['ics_id'] == item2.ics_id - assert response.json['serial_number'] == item2.serial_number - assert response.json['parent'] == item1.ics_id - assert response.json['status'] == str(item1.status) - assert response.json['location'] == str(item1.location) + assert response.json["id"] == item2.id + assert response.json["ics_id"] == item2.ics_id + assert response.json["serial_number"] == item2.serial_number + assert response.json["parent"] == item1.ics_id + assert response.json["status"] == str(item1.status) + assert response.json["location"] == str(item1.location) # updating a parent, modifies the status and location of all children # check location - data2 = {'location': 'ESS'} - response = patch(client, f'{API_URL}/inventory/items/{item1.ics_id}', data=data2, token=user_token) + data2 = {"location": "ESS"} + response = patch( + client, + f"{API_URL}/inventory/items/{item1.ics_id}", + data=data2, + token=user_token, + ) assert response.status_code == 200 - assert response.json['id'] == item1.id - assert response.json['ics_id'] == item1.ics_id - assert response.json['serial_number'] == item1.serial_number - assert response.json['status'] == str(item1.status) - assert response.json['location'] == data2['location'] - for ics_id in ('AAA002', 'AAA003'): - response = get(client, f'{API_URL}/inventory/items/{ics_id}', token=user_token) - assert response.json['location'] == data2['location'] - assert response.json['status'] == 'In service' + assert response.json["id"] == item1.id + assert response.json["ics_id"] == item1.ics_id + assert response.json["serial_number"] == item1.serial_number + assert response.json["status"] == str(item1.status) + assert response.json["location"] == data2["location"] + for ics_id in ("AAA002", "AAA003"): + response = get(client, f"{API_URL}/inventory/items/{ics_id}", token=user_token) + assert response.json["location"] == data2["location"] + assert response.json["status"] == "In service" # check status - data3 = {'status': 'Stock'} - response = patch(client, f'{API_URL}/inventory/items/{item1.ics_id}', data=data3, token=user_token) + data3 = {"status": "Stock"} + response = patch( + client, + f"{API_URL}/inventory/items/{item1.ics_id}", + data=data3, + token=user_token, + ) assert response.status_code == 200 - assert response.json['status'] == data3['status'] - for ics_id in ('AAA002', 'AAA003'): - response = get(client, f'{API_URL}/inventory/items/{ics_id}', token=user_token) - assert response.json['location'] == data2['location'] - assert response.json['status'] == data3['status'] + assert response.json["status"] == data3["status"] + for ics_id in ("AAA002", "AAA003"): + response = get(client, f"{API_URL}/inventory/items/{ics_id}", token=user_token) + assert response.json["location"] == data2["location"] + assert response.json["status"] == data3["status"] # manufacturer has no impact on children - data4 = {'manufacturer': 'HP'} - response = patch(client, f'{API_URL}/inventory/items/{item1.ics_id}', data=data4, token=user_token) + data4 = 
{"manufacturer": "HP"} + response = patch( + client, + f"{API_URL}/inventory/items/{item1.ics_id}", + data=data4, + token=user_token, + ) assert response.status_code == 200 - assert response.json['manufacturer'] == 'HP' + assert response.json["manufacturer"] == "HP" # Manufacturer didn't change on children - response = get(client, f'{API_URL}/inventory/items/{item2.ics_id}', token=user_token) - assert response.json['manufacturer'] == str(item2.manufacturer) - assert str(item2.manufacturer) != 'HP' - response = get(client, f'{API_URL}/inventory/items/{item3.ics_id}', token=user_token) - assert response.json['manufacturer'] == str(item3.manufacturer) - assert str(item3.manufacturer) != 'HP' + response = get( + client, f"{API_URL}/inventory/items/{item2.ics_id}", token=user_token + ) + assert response.json["manufacturer"] == str(item2.manufacturer) + assert str(item2.manufacturer) != "HP" + response = get( + client, f"{API_URL}/inventory/items/{item3.ics_id}", token=user_token + ) + assert response.json["manufacturer"] == str(item3.manufacturer) + assert str(item3.manufacturer) != "HP" def test_get_items(client, location_factory, item_factory, readonly_token): # Create some items - location_factory(name='ESS') - item1 = item_factory(location='ESS') - item2 = item_factory(serial_number='234567') + location_factory(name="ESS") + item1 = item_factory(location="ESS") + item2 = item_factory(serial_number="234567") item3 = item_factory() - response = get(client, f'{API_URL}/inventory/items', token=readonly_token) + response = get(client, f"{API_URL}/inventory/items", token=readonly_token) assert response.status_code == 200 assert len(response.json) == 3 - check_input_is_subset_of_response(response, (item1.to_dict(), item2.to_dict(), item3.to_dict())) + check_input_is_subset_of_response( + response, (item1.to_dict(), item2.to_dict(), item3.to_dict()) + ) # test filtering - response = get(client, f'{API_URL}/inventory/items?serial_number=234567', token=readonly_token) + response = get( + client, f"{API_URL}/inventory/items?serial_number=234567", token=readonly_token + ) assert response.status_code == 200 assert len(response.json) == 1 check_input_is_subset_of_response(response, (item2.to_dict(),)) # filtering on location_id works but not location (might want to change that) - response = get(client, f'{API_URL}/inventory/items?location_id={item1.location_id}', token=readonly_token) + response = get( + client, + f"{API_URL}/inventory/items?location_id={item1.location_id}", + token=readonly_token, + ) assert response.status_code == 200 assert len(response.json) == 1 check_input_is_subset_of_response(response, (item1.to_dict(),)) - response = get(client, f'{API_URL}/inventory/items?location=ESS', token=readonly_token) - check_response_message(response, 'Invalid query arguments', 422) + response = get( + client, f"{API_URL}/inventory/items?location=ESS", token=readonly_token + ) + check_response_message(response, "Invalid query arguments", 422) # using an unknown key raises a 422 - response = get(client, f'{API_URL}/inventory/items?foo=bar', token=readonly_token) - check_response_message(response, 'Invalid query arguments', 422) + response = get(client, f"{API_URL}/inventory/items?foo=bar", token=readonly_token) + check_response_message(response, "Invalid query arguments", 422) def test_get_networks(client, network_factory, readonly_token): # Create some networks - network1 = network_factory(address='172.16.1.0/24', first_ip='172.16.1.1', last_ip='172.16.1.254') - network2 = 
network_factory(address='172.16.20.0/22', first_ip='172.16.20.11', last_ip='172.16.20.250') - network3 = network_factory(address='172.16.5.0/24', first_ip='172.16.5.10', last_ip='172.16.5.254') + network1 = network_factory( + address="172.16.1.0/24", first_ip="172.16.1.1", last_ip="172.16.1.254" + ) + network2 = network_factory( + address="172.16.20.0/22", first_ip="172.16.20.11", last_ip="172.16.20.250" + ) + network3 = network_factory( + address="172.16.5.0/24", first_ip="172.16.5.10", last_ip="172.16.5.254" + ) - response = get(client, f'{API_URL}/network/networks', token=readonly_token) + response = get(client, f"{API_URL}/network/networks", token=readonly_token) assert response.status_code == 200 assert len(response.json) == 3 - check_input_is_subset_of_response(response, (network1.to_dict(), network2.to_dict(), network3.to_dict())) + check_input_is_subset_of_response( + response, (network1.to_dict(), network2.to_dict(), network3.to_dict()) + ) # test filtering by address - response = get(client, f'{API_URL}/network/networks?address=172.16.20.0/22', token=readonly_token) + response = get( + client, + f"{API_URL}/network/networks?address=172.16.20.0/22", + token=readonly_token, + ) assert response.status_code == 200 assert len(response.json) == 1 check_input_is_subset_of_response(response, (network2.to_dict(),)) @@ -453,72 +596,139 @@ def test_get_networks(client, network_factory, readonly_token): def test_create_network_auth_fail(client, session, user_token): # admin is required to create networks - response = post(client, f'{API_URL}/network/networks', data={}, token=user_token) + response = post(client, f"{API_URL}/network/networks", data={}, token=user_token) check_response_message(response, "User doesn't have the required group", 403) def test_create_network(client, admin_token, network_scope_factory): - scope = network_scope_factory(supernet='172.16.0.0/16') + scope = network_scope_factory(supernet="172.16.0.0/16") # check that vlan_name, vlan_id, address, first_ip, last_ip and scope are mandatory - response = post(client, f'{API_URL}/network/networks', data={}, token=admin_token) + response = post(client, f"{API_URL}/network/networks", data={}, token=admin_token) check_response_message(response, "Missing mandatory field 'vlan_name'", 422) - response = post(client, f'{API_URL}/network/networks', data={'first_ip': '172.16.1.10', 'last_ip': '172.16.1.250'}, token=admin_token) + response = post( + client, + f"{API_URL}/network/networks", + data={"first_ip": "172.16.1.10", "last_ip": "172.16.1.250"}, + token=admin_token, + ) check_response_message(response, "Missing mandatory field 'vlan_name'", 422) - response = post(client, f'{API_URL}/network/networks', data={'address': '172.16.1.0/24'}, token=admin_token) + response = post( + client, + f"{API_URL}/network/networks", + data={"address": "172.16.1.0/24"}, + token=admin_token, + ) check_response_message(response, "Missing mandatory field 'vlan_name'", 422) - response = post(client, f'{API_URL}/network/networks', data={'vlan_name': 'network1'}, token=admin_token) + response = post( + client, + f"{API_URL}/network/networks", + data={"vlan_name": "network1"}, + token=admin_token, + ) check_response_message(response, "Missing mandatory field 'vlan_id'", 422) - response = post(client, f'{API_URL}/network/networks', data={'vlan_name': 'network1', 'vlan_id': 1600}, token=admin_token) + response = post( + client, + f"{API_URL}/network/networks", + data={"vlan_name": "network1", "vlan_id": 1600}, + token=admin_token, + ) 
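Aside: the two kinds of change running through these hunks are black's defaults. String literals are normalized to double quotes, and any statement that no longer fits within black's default 88-character line length is exploded to one argument per line with a trailing ("magic") comma, while shorter statements are kept or re-joined onto a single line. A small sketch of both outcomes, assuming the project runs black with its default settings:

    # sketch only -- assumes black's default 88-character line length
    # fits on one line: black only normalizes the quotes
    response = post(client, f"{API_URL}/network/hosts", data=data, token=user_token)

    # too long for one line: black splits the call to one argument per line
    # and adds a trailing comma after the last argument
    response = post(
        client,
        f"{API_URL}/network/networks",
        data={"vlan_name": "network1", "vlan_id": 1600},
        token=admin_token,
    )
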
check_response_message(response, "Missing mandatory field 'address'", 422) - response = post(client, f'{API_URL}/network/networks', data={'vlan_name': 'network1', 'vlan_id': 1600, 'address': '172.16.1.0/24', 'first_ip': '172.16.1.10'}, token=admin_token) + response = post( + client, + f"{API_URL}/network/networks", + data={ + "vlan_name": "network1", + "vlan_id": 1600, + "address": "172.16.1.0/24", + "first_ip": "172.16.1.10", + }, + token=admin_token, + ) check_response_message(response, "Missing mandatory field 'last_ip'", 422) - data = {'vlan_name': 'network1', - 'vlan_id': 1600, - 'address': '172.16.1.0/24', - 'first_ip': '172.16.1.10', - 'last_ip': '172.16.1.250', - 'scope': scope.name} - response = post(client, f'{API_URL}/network/networks', data=data, token=admin_token) + data = { + "vlan_name": "network1", + "vlan_id": 1600, + "address": "172.16.1.0/24", + "first_ip": "172.16.1.10", + "last_ip": "172.16.1.250", + "scope": scope.name, + } + response = post(client, f"{API_URL}/network/networks", data=data, token=admin_token) assert response.status_code == 201 - assert {'id', 'vlan_name', 'vlan_id', 'address', 'netmask', - 'first_ip', 'last_ip', 'description', 'admin_only', - 'scope', 'domain', 'interfaces', 'created_at', - 'updated_at', 'user'} == set(response.json.keys()) - assert response.json['vlan_name'] == 'network1' - assert response.json['vlan_id'] == 1600 - assert response.json['address'] == '172.16.1.0/24' - assert response.json['first_ip'] == '172.16.1.10' - assert response.json['last_ip'] == '172.16.1.250' - assert response.json['netmask'] == '255.255.255.0' + assert { + "id", + "vlan_name", + "vlan_id", + "address", + "netmask", + "first_ip", + "last_ip", + "description", + "admin_only", + "scope", + "domain", + "interfaces", + "created_at", + "updated_at", + "user", + } == set(response.json.keys()) + assert response.json["vlan_name"] == "network1" + assert response.json["vlan_id"] == 1600 + assert response.json["address"] == "172.16.1.0/24" + assert response.json["first_ip"] == "172.16.1.10" + assert response.json["last_ip"] == "172.16.1.250" + assert response.json["netmask"] == "255.255.255.0" # Check that address and name shall be unique - response = post(client, f'{API_URL}/network/networks', data=data, token=admin_token) - check_response_message(response, '(psycopg2.IntegrityError) duplicate key value violates unique constraint', 422) + response = post(client, f"{API_URL}/network/networks", data=data, token=admin_token) + check_response_message( + response, + "(psycopg2.IntegrityError) duplicate key value violates unique constraint", + 422, + ) data_same_address = data.copy() - data_same_address['vlan_name'] = 'networkX' - response = post(client, f'{API_URL}/network/networks', data=data_same_address, token=admin_token) - check_response_message(response, '(psycopg2.IntegrityError) duplicate key value violates unique constraint', 422) - data_same_name = {'vlan_name': 'network1', - 'vlan_id': '1600', - 'address': '172.16.2.0/24', - 'first_ip': '172.16.2.10', - 'last_ip': '172.16.2.250', - 'scope': scope.name} - response = post(client, f'{API_URL}/network/networks', data=data_same_name, token=admin_token) - check_response_message(response, '(psycopg2.IntegrityError) duplicate key value violates unique constraint', 422) + data_same_address["vlan_name"] = "networkX" + response = post( + client, f"{API_URL}/network/networks", data=data_same_address, token=admin_token + ) + check_response_message( + response, + "(psycopg2.IntegrityError) duplicate key value violates unique 
constraint", + 422, + ) + data_same_name = { + "vlan_name": "network1", + "vlan_id": "1600", + "address": "172.16.2.0/24", + "first_ip": "172.16.2.10", + "last_ip": "172.16.2.250", + "scope": scope.name, + } + response = post( + client, f"{API_URL}/network/networks", data=data_same_name, token=admin_token + ) + check_response_message( + response, + "(psycopg2.IntegrityError) duplicate key value violates unique constraint", + 422, + ) # Check that all parameters can be passed - data2 = {'vlan_name': 'network2', - 'vlan_id': '1601', - 'address': '172.16.5.0/24', - 'first_ip': '172.16.5.11', - 'last_ip': '172.16.5.250', - 'description': 'long description', - 'scope': scope.name} - response = post(client, f'{API_URL}/network/networks', data=data2, token=admin_token) + data2 = { + "vlan_name": "network2", + "vlan_id": "1601", + "address": "172.16.5.0/24", + "first_ip": "172.16.5.11", + "last_ip": "172.16.5.250", + "description": "long description", + "scope": scope.name, + } + response = post( + client, f"{API_URL}/network/networks", data=data2, token=admin_token + ) assert response.status_code == 201 - assert response.json['description'] == 'long description' + assert response.json["description"] == "long description" # check all items that were created assert models.Network.query.count() == 2 @@ -526,192 +736,306 @@ def test_create_network(client, admin_token, network_scope_factory): def test_create_network_invalid_address(client, admin_token, network_scope): # invalid network address - data = {'vlan_name': 'network1', - 'vlan_id': '1600', - 'address': 'foo', - 'first_ip': '172.16.1.10', - 'last_ip': '172.16.1.250', - 'scope': network_scope.name} - response = post(client, f'{API_URL}/network/networks', data=data, token=admin_token) - check_response_message(response, "'foo' does not appear to be an IPv4 or IPv6 network", 422) - data['address'] = '172.16.1' - response = post(client, f'{API_URL}/network/networks', data=data, token=admin_token) - check_response_message(response, "'172.16.1' does not appear to be an IPv4 or IPv6 network", 422) + data = { + "vlan_name": "network1", + "vlan_id": "1600", + "address": "foo", + "first_ip": "172.16.1.10", + "last_ip": "172.16.1.250", + "scope": network_scope.name, + } + response = post(client, f"{API_URL}/network/networks", data=data, token=admin_token) + check_response_message( + response, "'foo' does not appear to be an IPv4 or IPv6 network", 422 + ) + data["address"] = "172.16.1" + response = post(client, f"{API_URL}/network/networks", data=data, token=admin_token) + check_response_message( + response, "'172.16.1' does not appear to be an IPv4 or IPv6 network", 422 + ) # address has host bits set - data['address'] = '172.16.1.1/24' - response = post(client, f'{API_URL}/network/networks', data=data, token=admin_token) - check_response_message(response, '172.16.1.1/24 has host bits set', 422) + data["address"] = "172.16.1.1/24" + response = post(client, f"{API_URL}/network/networks", data=data, token=admin_token) + check_response_message(response, "172.16.1.1/24 has host bits set", 422) -@pytest.mark.parametrize('address', ('', 'foo', '192.168')) -def test_create_network_invalid_ip(address, client, session, admin_token, network_scope): +@pytest.mark.parametrize("address", ("", "foo", "192.168")) +def test_create_network_invalid_ip( + address, client, session, admin_token, network_scope +): # invalid first IP address - data = {'vlan_name': 'network1', - 'vlan_id': '1600', - 'address': '192.168.0.0/24', - 'first_ip': address, - 'last_ip': 
'192.168.0.250', - 'scope': network_scope.name} - response = post(client, f'{API_URL}/network/networks', data=data, token=admin_token) - check_response_message(response, f"'{address}' does not appear to be an IPv4 or IPv6 address", 422) + data = { + "vlan_name": "network1", + "vlan_id": "1600", + "address": "192.168.0.0/24", + "first_ip": address, + "last_ip": "192.168.0.250", + "scope": network_scope.name, + } + response = post(client, f"{API_URL}/network/networks", data=data, token=admin_token) + check_response_message( + response, f"'{address}' does not appear to be an IPv4 or IPv6 address", 422 + ) # invalid last IP address - data = {'vlan_name': 'network1', - 'vlan_id': '1600', - 'address': '192.168.0.0/24', - 'first_ip': '192.168.0.250', - 'last_ip': address, - 'scope': network_scope.name} - response = post(client, f'{API_URL}/network/networks', data=data, token=admin_token) - check_response_message(response, f"'{address}' does not appear to be an IPv4 or IPv6 address", 422) + data = { + "vlan_name": "network1", + "vlan_id": "1600", + "address": "192.168.0.0/24", + "first_ip": "192.168.0.250", + "last_ip": address, + "scope": network_scope.name, + } + response = post(client, f"{API_URL}/network/networks", data=data, token=admin_token) + check_response_message( + response, f"'{address}' does not appear to be an IPv4 or IPv6 address", 422 + ) def test_create_network_invalid_range(client, session, admin_token, network_scope): # first_ip not in network address - data = {'vlan_name': 'network1', - 'vlan_id': '1600', - 'address': '172.16.1.0/24', - 'first_ip': '172.16.2.10', - 'last_ip': '172.16.1.250', - 'scope': network_scope.name} - response = post(client, f'{API_URL}/network/networks', data=data, token=admin_token) - check_response_message(response, 'IP address 172.16.2.10 is not in network 172.16.1.0/24', 422) + data = { + "vlan_name": "network1", + "vlan_id": "1600", + "address": "172.16.1.0/24", + "first_ip": "172.16.2.10", + "last_ip": "172.16.1.250", + "scope": network_scope.name, + } + response = post(client, f"{API_URL}/network/networks", data=data, token=admin_token) + check_response_message( + response, "IP address 172.16.2.10 is not in network 172.16.1.0/24", 422 + ) # last_ip not in network address - data = {'vlan_name': 'network1', - 'vlan_id': '1600', - 'address': '172.16.1.0/24', - 'first_ip': '172.16.1.10', - 'last_ip': '172.16.5.250', - 'scope': network_scope.name} - response = post(client, f'{API_URL}/network/networks', data=data, token=admin_token) - check_response_message(response, 'IP address 172.16.5.250 is not in network 172.16.1.0/24', 422) + data = { + "vlan_name": "network1", + "vlan_id": "1600", + "address": "172.16.1.0/24", + "first_ip": "172.16.1.10", + "last_ip": "172.16.5.250", + "scope": network_scope.name, + } + response = post(client, f"{API_URL}/network/networks", data=data, token=admin_token) + check_response_message( + response, "IP address 172.16.5.250 is not in network 172.16.1.0/24", 422 + ) # first_ip > last_ip - data = {'vlan_name': 'network1', - 'vlan_id': '1600', - 'address': '172.16.1.0/24', - 'first_ip': '172.16.1.10', - 'last_ip': '172.16.1.9', - 'scope': network_scope.name} - response = post(client, f'{API_URL}/network/networks', data=data, token=admin_token) - check_response_message(response, 'Last IP address 172.16.1.9 is less than the first address 172.16.1.10', 422) + data = { + "vlan_name": "network1", + "vlan_id": "1600", + "address": "172.16.1.0/24", + "first_ip": "172.16.1.10", + "last_ip": "172.16.1.9", + "scope": 
network_scope.name, + } + response = post(client, f"{API_URL}/network/networks", data=data, token=admin_token) + check_response_message( + response, + "Last IP address 172.16.1.9 is less than the first address 172.16.1.10", + 422, + ) def test_get_interfaces(client, network_factory, interface_factory, readonly_token): # Create some interfaces - network1 = network_factory(address='192.168.1.0/24', first_ip='192.168.1.10', last_ip='192.168.1.250') - network2 = network_factory(address='192.168.2.0/24', first_ip='192.168.2.10', last_ip='192.168.2.250') - interface1 = interface_factory(network=network1, ip='192.168.1.10') - interface2 = interface_factory(network=network1, ip='192.168.1.11', name='interface2') - interface3 = interface_factory(network=network2, ip='192.168.2.10') + network1 = network_factory( + address="192.168.1.0/24", first_ip="192.168.1.10", last_ip="192.168.1.250" + ) + network2 = network_factory( + address="192.168.2.0/24", first_ip="192.168.2.10", last_ip="192.168.2.250" + ) + interface1 = interface_factory(network=network1, ip="192.168.1.10") + interface2 = interface_factory( + network=network1, ip="192.168.1.11", name="interface2" + ) + interface3 = interface_factory(network=network2, ip="192.168.2.10") - response = get(client, f'{API_URL}/network/interfaces', token=readonly_token) + response = get(client, f"{API_URL}/network/interfaces", token=readonly_token) assert response.status_code == 200 assert len(response.json) == 3 - check_input_is_subset_of_response(response, (interface1.to_dict(), interface2.to_dict(), interface3.to_dict())) + check_input_is_subset_of_response( + response, (interface1.to_dict(), interface2.to_dict(), interface3.to_dict()) + ) # test filtering by network_id - response = get(client, f'{API_URL}/network/interfaces?network_id={network2.id}', token=readonly_token) + response = get( + client, + f"{API_URL}/network/interfaces?network_id={network2.id}", + token=readonly_token, + ) assert response.status_code == 200 assert len(response.json) == 1 check_input_is_subset_of_response(response, (interface3.to_dict(),)) -def test_get_interfaces_by_domain(client, domain_factory, network_factory, interface_factory, readonly_token): +def test_get_interfaces_by_domain( + client, domain_factory, network_factory, interface_factory, readonly_token +): # Create some interfaces - domain1 = domain_factory(name='tn.esss.lu.se') - domain2 = domain_factory(name='ics.esss.lu.se') - network1 = network_factory(address='192.168.1.0/24', first_ip='192.168.1.10', last_ip='192.168.1.250', domain=domain1) - network2 = network_factory(address='192.168.2.0/24', first_ip='192.168.2.10', last_ip='192.168.2.250', domain=domain2) - interface1 = interface_factory(network=network1, ip='192.168.1.10') - interface2 = interface_factory(network=network1, ip='192.168.1.11', name='interface2') - interface3 = interface_factory(network=network2, ip='192.168.2.10') + domain1 = domain_factory(name="tn.esss.lu.se") + domain2 = domain_factory(name="ics.esss.lu.se") + network1 = network_factory( + address="192.168.1.0/24", + first_ip="192.168.1.10", + last_ip="192.168.1.250", + domain=domain1, + ) + network2 = network_factory( + address="192.168.2.0/24", + first_ip="192.168.2.10", + last_ip="192.168.2.250", + domain=domain2, + ) + interface1 = interface_factory(network=network1, ip="192.168.1.10") + interface2 = interface_factory( + network=network1, ip="192.168.1.11", name="interface2" + ) + interface3 = interface_factory(network=network2, ip="192.168.2.10") # test filtering by domain - response = 
get(client, f'{API_URL}/network/interfaces?domain=tn.esss.lu.se', token=readonly_token) + response = get( + client, + f"{API_URL}/network/interfaces?domain=tn.esss.lu.se", + token=readonly_token, + ) assert response.status_code == 200 assert len(response.json) == 2 - check_input_is_subset_of_response(response, (interface1.to_dict(), interface2.to_dict())) + check_input_is_subset_of_response( + response, (interface1.to_dict(), interface2.to_dict()) + ) - response = get(client, f'{API_URL}/network/interfaces?domain=ics.esss.lu.se', token=readonly_token) + response = get( + client, + f"{API_URL}/network/interfaces?domain=ics.esss.lu.se", + token=readonly_token, + ) assert response.status_code == 200 assert len(response.json) == 1 check_input_is_subset_of_response(response, (interface3.to_dict(),)) def test_create_interface(client, network_factory, user_token): - network = network_factory(address='192.168.1.0/24', first_ip='192.168.1.10', last_ip='192.168.1.250') + network = network_factory( + address="192.168.1.0/24", first_ip="192.168.1.10", last_ip="192.168.1.250" + ) # check that network_id and ip are mandatory - response = post(client, f'{API_URL}/network/interfaces', data={}, token=user_token) + response = post(client, f"{API_URL}/network/interfaces", data={}, token=user_token) check_response_message(response, "Missing mandatory field 'network'", 422) - response = post(client, f'{API_URL}/network/interfaces', data={'ip': '192.168.1.20'}, token=user_token) + response = post( + client, + f"{API_URL}/network/interfaces", + data={"ip": "192.168.1.20"}, + token=user_token, + ) check_response_message(response, "Missing mandatory field 'network'", 422) - response = post(client, f'{API_URL}/network/interfaces', data={'network': network.address}, token=user_token) + response = post( + client, + f"{API_URL}/network/interfaces", + data={"network": network.address}, + token=user_token, + ) check_response_message(response, "Missing mandatory field 'ip'", 422) - data = {'network': network.vlan_name, - 'ip': '192.168.1.20', - 'name': 'interface1'} - response = post(client, f'{API_URL}/network/interfaces', data=data, token=user_token) + data = {"network": network.vlan_name, "ip": "192.168.1.20", "name": "interface1"} + response = post( + client, f"{API_URL}/network/interfaces", data=data, token=user_token + ) assert response.status_code == 201 - assert {'id', 'network', 'ip', 'name', 'mac', 'domain', - 'host', 'device_type', 'cnames', 'tags', 'created_at', - 'updated_at', 'user'} == set(response.json.keys()) - assert response.json['network'] == network.vlan_name - assert response.json['ip'] == '192.168.1.20' - assert response.json['name'] == 'interface1' + assert { + "id", + "network", + "ip", + "name", + "mac", + "domain", + "host", + "device_type", + "cnames", + "tags", + "created_at", + "updated_at", + "user", + } == set(response.json.keys()) + assert response.json["network"] == network.vlan_name + assert response.json["ip"] == "192.168.1.20" + assert response.json["name"] == "interface1" # Check that IP and name shall be unique - response = post(client, f'{API_URL}/network/interfaces', data=data, token=user_token) - check_response_message(response, '(psycopg2.IntegrityError) duplicate key value violates unique constraint', 422) + response = post( + client, f"{API_URL}/network/interfaces", data=data, token=user_token + ) + check_response_message( + response, + "(psycopg2.IntegrityError) duplicate key value violates unique constraint", + 422, + ) # Check that all parameters can be passed - data2 = 
{'network': network.vlan_name, - 'ip': '192.168.1.21', - 'name': 'myhostname'} - response = post(client, f'{API_URL}/network/interfaces', data=data2, token=user_token) + data2 = {"network": network.vlan_name, "ip": "192.168.1.21", "name": "myhostname"} + response = post( + client, f"{API_URL}/network/interfaces", data=data2, token=user_token + ) assert response.status_code == 201 # check all items that were created assert models.Interface.query.count() == 2 -@pytest.mark.parametrize('ip', ('', 'foo', '192.168')) +@pytest.mark.parametrize("ip", ("", "foo", "192.168")) def test_create_interface_invalid_ip(ip, client, network_factory, user_token): - network = network_factory(address='192.168.1.0/24', first_ip='192.168.1.10', last_ip='192.168.1.250') + network = network_factory( + address="192.168.1.0/24", first_ip="192.168.1.10", last_ip="192.168.1.250" + ) # invalid IP address - data = {'network': network.vlan_name, - 'ip': ip, - 'name': 'hostname'} - response = post(client, f'{API_URL}/network/interfaces', data=data, token=user_token) - check_response_message(response, f"'{ip}' does not appear to be an IPv4 or IPv6 address", 422) + data = {"network": network.vlan_name, "ip": ip, "name": "hostname"} + response = post( + client, f"{API_URL}/network/interfaces", data=data, token=user_token + ) + check_response_message( + response, f"'{ip}' does not appear to be an IPv4 or IPv6 address", 422 + ) def test_create_interface_ip_not_in_network(client, network_factory, user_token): - network = network_factory(address='192.168.1.0/24', first_ip='192.168.1.10', last_ip='192.168.1.250') + network = network_factory( + address="192.168.1.0/24", first_ip="192.168.1.10", last_ip="192.168.1.250" + ) # IP address not in range - data = {'network': network.vlan_name, - 'ip': '192.168.2.4', - 'name': 'hostname'} - response = post(client, f'{API_URL}/network/interfaces', data=data, token=user_token) - check_response_message(response, 'IP address 192.168.2.4 is not in network 192.168.1.0/24', 422) + data = {"network": network.vlan_name, "ip": "192.168.2.4", "name": "hostname"} + response = post( + client, f"{API_URL}/network/interfaces", data=data, token=user_token + ) + check_response_message( + response, "IP address 192.168.2.4 is not in network 192.168.1.0/24", 422 + ) def test_create_interface_ip_not_in_range(client, network_factory, user_token): - network = network_factory(address='192.168.1.0/24', first_ip='192.168.1.10', last_ip='192.168.1.250') + network = network_factory( + address="192.168.1.0/24", first_ip="192.168.1.10", last_ip="192.168.1.250" + ) # IP address not in range - data = {'network': network.vlan_name, - 'ip': '192.168.1.4', - 'name': 'hostname'} - response = post(client, f'{API_URL}/network/interfaces', data=data, token=user_token) - check_response_message(response, 'IP address 192.168.1.4 is not in range 192.168.1.10 - 192.168.1.250', 422) + data = {"network": network.vlan_name, "ip": "192.168.1.4", "name": "hostname"} + response = post( + client, f"{API_URL}/network/interfaces", data=data, token=user_token + ) + check_response_message( + response, + "IP address 192.168.1.4 is not in range 192.168.1.10 - 192.168.1.250", + 422, + ) -def test_create_interface_ip_not_in_range_as_admin(client, network_factory, admin_token): - network = network_factory(address='192.168.1.0/24', first_ip='192.168.1.10', last_ip='192.168.1.250') +def test_create_interface_ip_not_in_range_as_admin( + client, network_factory, admin_token +): + network = network_factory( + address="192.168.1.0/24", 
first_ip="192.168.1.10", last_ip="192.168.1.250" + ) # IP address not in range - data = {'network': network.vlan_name, - 'ip': '192.168.1.4', - 'name': 'hostname'} - response = post(client, f'{API_URL}/network/interfaces', data=data, token=admin_token) + data = {"network": network.vlan_name, "ip": "192.168.1.4", "name": "hostname"} + response = post( + client, f"{API_URL}/network/interfaces", data=data, token=admin_token + ) assert response.status_code == 201 @@ -720,7 +1044,7 @@ def test_get_macs(client, mac_factory, readonly_token): mac1 = mac_factory() mac2 = mac_factory() - response = get(client, f'{API_URL}/network/macs', token=readonly_token) + response = get(client, f"{API_URL}/network/macs", token=readonly_token) assert response.status_code == 200 assert len(response.json) == 2 check_input_is_subset_of_response(response, (mac1.to_dict(), mac2.to_dict())) @@ -729,111 +1053,139 @@ def test_get_macs(client, mac_factory, readonly_token): def test_create_mac(client, item_factory, user_token): item = item_factory() # check that address is mandatory - response = post(client, f'{API_URL}/network/macs', data={}, token=user_token) + response = post(client, f"{API_URL}/network/macs", data={}, token=user_token) check_response_message(response, "Missing mandatory field 'address'", 422) - data = {'address': 'b5:4b:7d:a4:23:43'} - response = post(client, f'{API_URL}/network/macs', data=data, token=user_token) + data = {"address": "b5:4b:7d:a4:23:43"} + response = post(client, f"{API_URL}/network/macs", data=data, token=user_token) assert response.status_code == 201 - assert {'id', 'address', 'item', 'interfaces'} == set(response.json.keys()) - assert response.json['address'] == data['address'] + assert {"id", "address", "item", "interfaces"} == set(response.json.keys()) + assert response.json["address"] == data["address"] # Check that address shall be unique - response = post(client, f'{API_URL}/network/macs', data=data, token=user_token) - check_response_message(response, '(psycopg2.IntegrityError) duplicate key value violates unique constraint', 422) + response = post(client, f"{API_URL}/network/macs", data=data, token=user_token) + check_response_message( + response, + "(psycopg2.IntegrityError) duplicate key value violates unique constraint", + 422, + ) # Check that all parameters can be passed - data2 = {'address': 'b5:4b:7d:a4:23:44', - 'item_id': item.id} - response = post(client, f'{API_URL}/network/macs', data=data2, token=user_token) + data2 = {"address": "b5:4b:7d:a4:23:44", "item_id": item.id} + response = post(client, f"{API_URL}/network/macs", data=data2, token=user_token) assert response.status_code == 201 # check that all items were created assert models.Mac.query.count() == 2 -@pytest.mark.parametrize('address', ('', 'foo', 'b5:4b:7d:a4:23')) +@pytest.mark.parametrize("address", ("", "foo", "b5:4b:7d:a4:23")) def test_create_mac_invalid_address(address, client, user_token): - data = {'address': address} - response = post(client, f'{API_URL}/network/macs', data=data, token=user_token) - check_response_message(response, f"'{address}' does not appear to be a MAC address", 422) + data = {"address": address} + response = post(client, f"{API_URL}/network/macs", data=data, token=user_token) + check_response_message( + response, f"'{address}' does not appear to be a MAC address", 422 + ) def test_get_hosts(client, host_factory, readonly_token): # Create some hosts host1 = host_factory() host2 = host_factory() - response = get(client, f'{API_URL}/network/hosts', token=readonly_token) + 
response = get(client, f"{API_URL}/network/hosts", token=readonly_token) assert response.status_code == 200 assert len(response.json) == 2 check_input_is_subset_of_response(response, (host1.to_dict(), host2.to_dict())) def test_create_host(client, device_type_factory, user_token): - device_type = device_type_factory(name='Virtual') + device_type = device_type_factory(name="Virtual") # check that name and device_type are mandatory - response = post(client, f'{API_URL}/network/hosts', data={}, token=user_token) + response = post(client, f"{API_URL}/network/hosts", data={}, token=user_token) check_response_message(response, "Missing mandatory field 'name'", 422) - response = post(client, f'{API_URL}/network/hosts', data={'name': 'myhost'}, token=user_token) + response = post( + client, f"{API_URL}/network/hosts", data={"name": "myhost"}, token=user_token + ) check_response_message(response, "Missing mandatory field 'device_type'", 422) - response = post(client, f'{API_URL}/network/hosts', data={'device_type': 'Physical'}, token=user_token) + response = post( + client, + f"{API_URL}/network/hosts", + data={"device_type": "Physical"}, + token=user_token, + ) check_response_message(response, "Missing mandatory field 'name'", 422) - data = {'name': 'my-hostname', - 'device_type': device_type.name} - response = post(client, f'{API_URL}/network/hosts', data=data, token=user_token) + data = {"name": "my-hostname", "device_type": device_type.name} + response = post(client, f"{API_URL}/network/hosts", data=data, token=user_token) assert response.status_code == 201 - assert {'id', 'name', 'device_type', 'description', - 'items', 'interfaces', 'created_at', - 'updated_at', 'user'} == set(response.json.keys()) - assert response.json['name'] == data['name'] + assert { + "id", + "name", + "device_type", + "description", + "items", + "interfaces", + "created_at", + "updated_at", + "user", + } == set(response.json.keys()) + assert response.json["name"] == data["name"] # Check that name shall be unique - response = post(client, f'{API_URL}/network/hosts', data=data, token=user_token) - check_response_message(response, '(psycopg2.IntegrityError) duplicate key value violates unique constraint', 422) + response = post(client, f"{API_URL}/network/hosts", data=data, token=user_token) + check_response_message( + response, + "(psycopg2.IntegrityError) duplicate key value violates unique constraint", + 422, + ) # check that the number of items created assert models.Host.query.count() == 1 def test_create_host_with_items(client, item_factory, device_type_factory, user_token): - device_type = device_type_factory(name='Network') - item1 = item_factory(ics_id='AAA001') - item2 = item_factory(ics_id='AAA002') + device_type = device_type_factory(name="Network") + item1 = item_factory(ics_id="AAA001") + item2 = item_factory(ics_id="AAA002") # Check that we can pass a list of items ics_id - data = {'name': 'my-switch', - 'device_type': device_type.name, - 'items': [item1.ics_id, item2.ics_id]} - response = post(client, f'{API_URL}/network/hosts', data=data, token=user_token) + data = { + "name": "my-switch", + "device_type": device_type.name, + "items": [item1.ics_id, item2.ics_id], + } + response = post(client, f"{API_URL}/network/hosts", data=data, token=user_token) assert response.status_code == 201 - host = models.Host.query.filter_by(name='my-switch').first() + host = models.Host.query.filter_by(name="my-switch").first() assert models.Item.query.get(item1.id).host_id == host.id assert 
models.Item.query.get(item2.id).host_id == host.id -def test_create_host_as_consultant(client, item_factory, device_type_factory, consultant_token): +def test_create_host_as_consultant( + client, item_factory, device_type_factory, consultant_token +): device_type = device_type_factory() - data = {'name': 'my-hostname', - 'device_type': device_type.name} - response = post(client, f'{API_URL}/network/hosts', data=data, token=consultant_token) + data = {"name": "my-hostname", "device_type": device_type.name} + response = post( + client, f"{API_URL}/network/hosts", data=data, token=consultant_token + ) assert response.status_code == 201 def test_get_user_profile(client, readonly_token): - response = get(client, f'{API_URL}/user/profile', token=readonly_token) + response = get(client, f"{API_URL}/user/profile", token=readonly_token) assert response.status_code == 200 user = response.json - assert {'id', 'username', 'groups', 'email', 'display_name'} == set(user.keys()) - assert user['username'] == 'user_ro' - assert user['display_name'] == 'User RO' - assert user['email'] == 'user_ro@example.com' + assert {"id", "username", "groups", "email", "display_name"} == set(user.keys()) + assert user["username"] == "user_ro" + assert user["display_name"] == "User RO" + assert user["email"] == "user_ro@example.com" def test_get_domains(client, domain_factory, readonly_token): # Create some domains domain1 = domain_factory() domain2 = domain_factory() - response = get(client, f'{API_URL}/network/domains', token=readonly_token) + response = get(client, f"{API_URL}/network/domains", token=readonly_token) assert response.status_code == 200 assert len(response.json) == 2 check_input_is_subset_of_response(response, (domain1.to_dict(), domain2.to_dict())) @@ -841,36 +1193,53 @@ def test_get_domains(client, domain_factory, readonly_token): def test_create_domain(client, admin_token): # check that name is mandatory - response = post(client, f'{API_URL}/network/domains', data={}, token=admin_token) + response = post(client, f"{API_URL}/network/domains", data={}, token=admin_token) check_response_message(response, "Missing mandatory field 'name'", 422) - data = {'name': 'tn.esss.lu.se'} - response = post(client, f'{API_URL}/network/domains', data=data, token=admin_token) + data = {"name": "tn.esss.lu.se"} + response = post(client, f"{API_URL}/network/domains", data=data, token=admin_token) assert response.status_code == 201 - assert {'id', 'name', 'scopes', - 'networks', 'created_at', - 'updated_at', 'user'} == set(response.json.keys()) - assert response.json['name'] == data['name'] + assert { + "id", + "name", + "scopes", + "networks", + "created_at", + "updated_at", + "user", + } == set(response.json.keys()) + assert response.json["name"] == data["name"] # Check that name shall be unique - response = post(client, f'{API_URL}/network/domains', data=data, token=admin_token) - check_response_message(response, '(psycopg2.IntegrityError) duplicate key value violates unique constraint', 422) + response = post(client, f"{API_URL}/network/domains", data=data, token=admin_token) + check_response_message( + response, + "(psycopg2.IntegrityError) duplicate key value violates unique constraint", + 422, + ) def test_get_cnames(client, cname_factory, readonly_token): # Create some cnames cname1 = cname_factory() cname2 = cname_factory() - response = get(client, f'{API_URL}/network/cnames', token=readonly_token) + response = get(client, f"{API_URL}/network/cnames", token=readonly_token) assert response.status_code == 200 assert 
len(response.json) == 2 check_input_is_subset_of_response(response, (cname1.to_dict(), cname2.to_dict())) -def test_get_cnames_by_domain(client, domain_factory, network_factory, interface_factory, cname_factory, readonly_token): +def test_get_cnames_by_domain( + client, + domain_factory, + network_factory, + interface_factory, + cname_factory, + readonly_token, +): # Create some cnames - domain_a = domain_factory(name='a.esss.lu.se') - domain_b = domain_factory(name='b.esss.lu.se') + domain_a = domain_factory(name="a.esss.lu.se") + domain_b = domain_factory(name="b.esss.lu.se") network_a = network_factory(domain=domain_a) network_b = network_factory(domain=domain_b) interface_a1 = interface_factory(network=network_a) @@ -881,36 +1250,55 @@ def test_get_cnames_by_domain(client, domain_factory, network_factory, interface cname_a3 = cname_factory(interface=interface_a2) cname_b1 = cname_factory(interface=interface_b1) cname_b2 = cname_factory(interface=interface_b1) - response = get(client, f'{API_URL}/network/cnames', token=readonly_token) + response = get(client, f"{API_URL}/network/cnames", token=readonly_token) assert response.status_code == 200 assert len(response.json) == 5 - response = get(client, f'{API_URL}/network/cnames?domain=a.esss.lu.se', token=readonly_token) + response = get( + client, f"{API_URL}/network/cnames?domain=a.esss.lu.se", token=readonly_token + ) assert response.status_code == 200 assert len(response.json) == 3 - check_input_is_subset_of_response(response, (cname_a1.to_dict(), cname_a2.to_dict(), cname_a3.to_dict())) - response = get(client, f'{API_URL}/network/cnames?domain=b.esss.lu.se', token=readonly_token) + check_input_is_subset_of_response( + response, (cname_a1.to_dict(), cname_a2.to_dict(), cname_a3.to_dict()) + ) + response = get( + client, f"{API_URL}/network/cnames?domain=b.esss.lu.se", token=readonly_token + ) assert response.status_code == 200 assert len(response.json) == 2 - check_input_is_subset_of_response(response, (cname_b1.to_dict(), cname_b2.to_dict())) + check_input_is_subset_of_response( + response, (cname_b1.to_dict(), cname_b2.to_dict()) + ) def test_create_cname(client, interface, admin_token): # check that name and interface_id are mandatory - response = post(client, f'{API_URL}/network/cnames', data={}, token=admin_token) + response = post(client, f"{API_URL}/network/cnames", data={}, token=admin_token) check_response_message(response, "Missing mandatory field 'name'", 422) - response = post(client, f'{API_URL}/network/cnames', data={'name': 'myhost'}, token=admin_token) + response = post( + client, f"{API_URL}/network/cnames", data={"name": "myhost"}, token=admin_token + ) check_response_message(response, "Missing mandatory field 'interface_id'", 422) - response = post(client, f'{API_URL}/network/cnames', data={'interface_id': interface.id}, token=admin_token) + response = post( + client, + f"{API_URL}/network/cnames", + data={"interface_id": interface.id}, + token=admin_token, + ) check_response_message(response, "Missing mandatory field 'name'", 422) - data = {'name': 'myhost.tn.esss.lu.se', - 'interface_id': interface.id} - response = post(client, f'{API_URL}/network/cnames', data=data, token=admin_token) + data = {"name": "myhost.tn.esss.lu.se", "interface_id": interface.id} + response = post(client, f"{API_URL}/network/cnames", data=data, token=admin_token) assert response.status_code == 201 - assert {'id', 'name', 'interface', - 'created_at', 'updated_at', 'user'} == set(response.json.keys()) - assert response.json['name'] == 
data['name'] + assert {"id", "name", "interface", "created_at", "updated_at", "user"} == set( + response.json.keys() + ) + assert response.json["name"] == data["name"] # Check that name shall be unique - response = post(client, f'{API_URL}/network/cnames', data=data, token=admin_token) - check_response_message(response, '(psycopg2.IntegrityError) duplicate key value violates unique constraint', 422) + response = post(client, f"{API_URL}/network/cnames", data=data, token=admin_token) + check_response_message( + response, + "(psycopg2.IntegrityError) duplicate key value violates unique constraint", + 422, + ) diff --git a/tests/functional/test_models.py b/tests/functional/test_models.py index bf2185f..e000d18 100644 --- a/tests/functional/test_models.py +++ b/tests/functional/test_models.py @@ -17,99 +17,121 @@ from wtforms import ValidationError def test_user_groups(user_factory): user = user_factory() assert user.groups == [] - groups = ['foo', 'Another group'] + groups = ["foo", "Another group"] user = user_factory(groups=groups) assert user.groups == groups def test_user_is_admin(user_factory): - user = user_factory(groups=['foo', 'CSEntry User']) + user = user_factory(groups=["foo", "CSEntry User"]) assert not user.is_admin - user = user_factory(groups=['foo', 'CSEntry Admin']) + user = user_factory(groups=["foo", "CSEntry Admin"]) assert user.is_admin def test_user_is_member_of_one_group(user_factory): - user = user_factory(groups=['one', 'two']) - assert not user.is_member_of_one_group(['create', 'admin']) - user = user_factory(groups=['one', 'CSEntry Consultant']) - assert user.is_member_of_one_group(['create']) - assert user.is_member_of_one_group(['create', 'admin']) - assert not user.is_member_of_one_group(['admin']) - user = user_factory(groups=['one', 'CSEntry Admin']) - assert not user.is_member_of_one_group(['create']) - assert user.is_member_of_one_group(['create', 'admin']) - assert user.is_member_of_one_group(['admin']) + user = user_factory(groups=["one", "two"]) + assert not user.is_member_of_one_group(["create", "admin"]) + user = user_factory(groups=["one", "CSEntry Consultant"]) + assert user.is_member_of_one_group(["create"]) + assert user.is_member_of_one_group(["create", "admin"]) + assert not user.is_member_of_one_group(["admin"]) + user = user_factory(groups=["one", "CSEntry Admin"]) + assert not user.is_member_of_one_group(["create"]) + assert user.is_member_of_one_group(["create", "admin"]) + assert user.is_member_of_one_group(["admin"]) def test_network_ip_properties(network_factory): # Create some networks - network1 = network_factory(address='172.16.1.0/24', first_ip='172.16.1.10', last_ip='172.16.1.250') - network2 = network_factory(address='172.16.20.0/26', first_ip='172.16.20.11', last_ip='172.16.20.14') - - assert network1.network_ip == ipaddress.ip_network('172.16.1.0/24') - assert network1.first == ipaddress.ip_address('172.16.1.10') - assert network1.last == ipaddress.ip_address('172.16.1.250') + network1 = network_factory( + address="172.16.1.0/24", first_ip="172.16.1.10", last_ip="172.16.1.250" + ) + network2 = network_factory( + address="172.16.20.0/26", first_ip="172.16.20.11", last_ip="172.16.20.14" + ) + + assert network1.network_ip == ipaddress.ip_network("172.16.1.0/24") + assert network1.first == ipaddress.ip_address("172.16.1.10") + assert network1.last == ipaddress.ip_address("172.16.1.250") assert len(network1.ip_range()) == 241 - assert network1.ip_range() == [ipaddress.ip_address(f'172.16.1.{i}') for i in range(10, 251)] + assert 
network1.ip_range() == [ + ipaddress.ip_address(f"172.16.1.{i}") for i in range(10, 251) + ] assert network1.ip_range() == network1.available_ips() assert network1.used_ips() == [] - assert network2.network_ip == ipaddress.ip_network('172.16.20.0/26') - assert network2.first == ipaddress.ip_address('172.16.20.11') - assert network2.last == ipaddress.ip_address('172.16.20.14') + assert network2.network_ip == ipaddress.ip_network("172.16.20.0/26") + assert network2.first == ipaddress.ip_address("172.16.20.11") + assert network2.last == ipaddress.ip_address("172.16.20.14") assert len(network2.ip_range()) == 4 - assert network2.ip_range() == [ipaddress.ip_address(f'172.16.20.{i}') for i in range(11, 15)] + assert network2.ip_range() == [ + ipaddress.ip_address(f"172.16.20.{i}") for i in range(11, 15) + ] assert network2.ip_range() == network2.available_ips() assert network2.used_ips() == [] def test_network_available_and_used_ips(network_factory, interface_factory): # Create some networks and interfaces - network1 = network_factory(address='172.16.1.0/24', first_ip='172.16.1.10', last_ip='172.16.1.250') - network2 = network_factory(address='172.16.20.0/26', first_ip='172.16.20.11', last_ip='172.16.20.14') + network1 = network_factory( + address="172.16.1.0/24", first_ip="172.16.1.10", last_ip="172.16.1.250" + ) + network2 = network_factory( + address="172.16.20.0/26", first_ip="172.16.20.11", last_ip="172.16.20.14" + ) for i in range(10, 20): - interface_factory(network=network1, ip=f'172.16.1.{i}') - interface_factory(network=network2, ip='172.16.20.13') + interface_factory(network=network1, ip=f"172.16.1.{i}") + interface_factory(network=network2, ip="172.16.20.13") # Check available and used IPs - assert network1.used_ips() == [ipaddress.ip_address(f'172.16.1.{i}') for i in range(10, 20)] - assert network1.available_ips() == [ipaddress.ip_address(f'172.16.1.{i}') for i in range(20, 251)] - assert network2.used_ips() == [ipaddress.ip_address('172.16.20.13')] - assert network2.available_ips() == [ipaddress.ip_address('172.16.20.11'), - ipaddress.ip_address('172.16.20.12'), - ipaddress.ip_address('172.16.20.14')] + assert network1.used_ips() == [ + ipaddress.ip_address(f"172.16.1.{i}") for i in range(10, 20) + ] + assert network1.available_ips() == [ + ipaddress.ip_address(f"172.16.1.{i}") for i in range(20, 251) + ] + assert network2.used_ips() == [ipaddress.ip_address("172.16.20.13")] + assert network2.available_ips() == [ + ipaddress.ip_address("172.16.20.11"), + ipaddress.ip_address("172.16.20.12"), + ipaddress.ip_address("172.16.20.14"), + ] # Add more interfaces - interface_factory(network=network2, ip='172.16.20.11') - interface_factory(network=network2, ip='172.16.20.14') + interface_factory(network=network2, ip="172.16.20.11") + interface_factory(network=network2, ip="172.16.20.14") assert len(network2.used_ips()) == 3 - assert network2.used_ips() == [ipaddress.ip_address('172.16.20.11'), - ipaddress.ip_address('172.16.20.13'), - ipaddress.ip_address('172.16.20.14')] - assert network2.available_ips() == [ipaddress.ip_address('172.16.20.12')] + assert network2.used_ips() == [ + ipaddress.ip_address("172.16.20.11"), + ipaddress.ip_address("172.16.20.13"), + ipaddress.ip_address("172.16.20.14"), + ] + assert network2.available_ips() == [ipaddress.ip_address("172.16.20.12")] # Add last available IP - interface_factory(network=network2, ip='172.16.20.12') - assert network2.used_ips() == [ipaddress.ip_address(f'172.16.20.{i}') for i in range(11, 15)] + interface_factory(network=network2, 
ip="172.16.20.12") + assert network2.used_ips() == [ + ipaddress.ip_address(f"172.16.20.{i}") for i in range(11, 15) + ] assert list(network2.available_ips()) == [] def test_network_gateway(network_factory): - network = network_factory(address='192.168.0.0/24') - assert str(network.gateway) == '192.168.0.254' - network = network_factory(address='172.16.110.0/23') - assert str(network.gateway) == '172.16.111.254' + network = network_factory(address="192.168.0.0/24") + assert str(network.gateway) == "192.168.0.254" + network = network_factory(address="172.16.110.0/23") + assert str(network.gateway) == "172.16.111.254" def test_mac_address_validation(mac_factory): - mac = mac_factory(address='F4:A7:39:15:DA:01') - assert mac.address == 'f4:a7:39:15:da:01' - mac = mac_factory(address='F4-A7-39-15-DA-02') - assert mac.address == 'f4:a7:39:15:da:02' - mac = mac_factory(address='F4A73915DA06') - assert mac.address == 'f4:a7:39:15:da:06' + mac = mac_factory(address="F4:A7:39:15:DA:01") + assert mac.address == "f4:a7:39:15:da:01" + mac = mac_factory(address="F4-A7-39-15-DA-02") + assert mac.address == "f4:a7:39:15:da:02" + mac = mac_factory(address="F4A73915DA06") + assert mac.address == "f4:a7:39:15:da:06" with pytest.raises(ValidationError) as excinfo: - mac = mac_factory(address='F4A73915DA') + mac = mac_factory(address="F4A73915DA") assert "'F4A73915DA' does not appear to be a MAC address" in str(excinfo.value) @@ -131,16 +153,16 @@ def test_manufacturer_favorite_users(user_factory, manufacturer_factory): def test_device_type_validation(device_type_factory): - device_type = device_type_factory(name='PhysicalMachine') - assert device_type.name == 'PhysicalMachine' + device_type = device_type_factory(name="PhysicalMachine") + assert device_type.name == "PhysicalMachine" with pytest.raises(ValidationError) as excinfo: - device_type = device_type_factory(name='Physical Machine') + device_type = device_type_factory(name="Physical Machine") assert "'Physical Machine' is an invalid device type name" in str(excinfo.value) def test_tag_validation(tag_factory): - tag = tag_factory(name='IOC') - assert tag.name == 'IOC' + tag = tag_factory(name="IOC") + assert tag.name == "IOC" with pytest.raises(ValidationError) as excinfo: - tag = tag_factory(name='My tag') + tag = tag_factory(name="My tag") assert "'My tag' is an invalid tag name" in str(excinfo.value) diff --git a/tests/functional/test_web.py b/tests/functional/test_web.py index 6da1cb8..f4bb959 100644 --- a/tests/functional/test_web.py +++ b/tests/functional/test_web.py @@ -16,76 +16,70 @@ import re def get(client, url): response = client.get(url) - if response.headers['Content-Type'] == 'application/json': + if response.headers["Content-Type"] == "application/json": response.json = json.loads(response.data) return response def login(client, username, password): - data = { - 'username': username, - 'password': password - } - return client.post('/user/login', data=data, follow_redirects=True) + data = {"username": username, "password": password} + return client.post("/user/login", data=data, follow_redirects=True) def logout(client): - return client.get('/user/logout', follow_redirects=True) + return client.get("/user/logout", follow_redirects=True) @pytest.fixture def logged_client(client): - login(client, 'user_ro', 'userro') + login(client, "user_ro", "userro") return client def test_login_logout(client): - response = login(client, 'unknown', 'invalid') - assert b'<title>Login - CSEntry</title>' in response.data - response = login(client, 
'user_rw', 'invalid') - assert b'<title>Login - CSEntry</title>' in response.data - response = login(client, 'user_rw', 'userrw') - assert b'Welcome to CSEntry!' in response.data - assert b'User RW' in response.data + response = login(client, "unknown", "invalid") + assert b"<title>Login - CSEntry</title>" in response.data + response = login(client, "user_rw", "invalid") + assert b"<title>Login - CSEntry</title>" in response.data + response = login(client, "user_rw", "userrw") + assert b"Welcome to CSEntry!" in response.data + assert b"User RW" in response.data response = logout(client) - assert b'<title>Login - CSEntry</title>' in response.data + assert b"<title>Login - CSEntry</title>" in response.data def test_index(logged_client): - response = logged_client.get('/') - assert b'Welcome to CSEntry!' in response.data - assert b'User RO' in response.data + response = logged_client.get("/") + assert b"Welcome to CSEntry!" in response.data + assert b"User RO" in response.data -@pytest.mark.parametrize('url', [ - '/', - '/inventory/items', - '/inventory/_retrieve_items', - '/network/networks', -]) +@pytest.mark.parametrize( + "url", ["/", "/inventory/items", "/inventory/_retrieve_items", "/network/networks"] +) def test_protected_url(url, client): response = client.get(url) assert response.status_code == 302 - assert '/user/login' in response.headers['Location'] - login(client, 'user_ro', 'userro') + assert "/user/login" in response.headers["Location"] + login(client, "user_ro", "userro") response = client.get(url) assert response.status_code == 200 def test_retrieve_items(logged_client, item_factory): - response = get(logged_client, '/inventory/_retrieve_items') - assert response.json['data'] == [] - serial_numbers = ('12345', '45678') + response = get(logged_client, "/inventory/_retrieve_items") + assert response.json["data"] == [] + serial_numbers = ("12345", "45678") for sn in serial_numbers: item_factory(serial_number=sn) - response = get(logged_client, '/inventory/_retrieve_items') - items = response.json['data'] + response = get(logged_client, "/inventory/_retrieve_items") + items = response.json["data"] assert set(serial_numbers) == set(item[4] for item in items) assert len(items[0]) == 11 def test_generate_random_mac(logged_client): - response = get(logged_client, '/network/_generate_random_mac') - mac = response.json['data']['mac'] - assert re.match('^(?:[0-9a-fA-F]{2}:){5}[0-9a-fA-F]{2}$', mac) is not None - assert mac.startswith('02:42:42') + response = get(logged_client, "/network/_generate_random_mac") + mac = response.json["data"]["mac"] + assert re.match("^(?:[0-9a-fA-F]{2}:){5}[0-9a-fA-F]{2}$", mac) is not None + assert mac.startswith("02:42:42") -- GitLab