# -*- coding: utf-8 -*-
"""
app.models
~~~~~~~~~~

This module implements the models used in the app.

:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.

"""
import ipaddress
import string
import qrcode
import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import validates
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy_continuum import make_versioned, version_class
from citext import CIText
from flask import current_app
from flask_login import UserMixin
from wtforms import ValidationError
from .extensions import db, login_manager, ldap_manager
from .plugins import FlaskUserPlugin
from .validators import ICS_ID_RE, HOST_NAME_RE, VLAN_NAME_RE
from . import utils

make_versioned(plugins=[FlaskUserPlugin()])


# SQL function returning the current UTC timestamp.
# Only a PostgreSQL compilation is registered below.
# See http://docs.sqlalchemy.org/en/latest/core/compiler.html#utc-timestamp-function
class utcnow(sa.sql.expression.FunctionElement):
    type = sa.types.DateTime()


@compiles(utcnow, 'postgresql')
def pg_utcnow(element, compiler, **kw):
    """Render utcnow() for PostgreSQL"""
    return "TIMEZONE('utc', CURRENT_TIMESTAMP)"


def temporary_ics_ids():
    """Generator that returns the full list of temporary ICS ids

    Ids are built as <TEMPORARY_ICS_ID prefix><A-Z><000-999>.
    """
    return (f'{current_app.config["TEMPORARY_ICS_ID"]}{letter}{number:0=3d}'
            for letter in string.ascii_uppercase
            for number in range(0, 1000))


def used_temporary_ics_ids():
    """Generator that returns the list of temporary ICS ids used

    Performs one database query for the items whose ics_id starts with
    the TEMPORARY_ICS_ID prefix.
    """
    temporary_items = Item.query.filter(
        Item.ics_id.startswith(
            current_app.config['TEMPORARY_ICS_ID'])).all()
    return (item.ics_id for item in temporary_items)


def get_temporary_ics_id():
    """Return a temporary ICS id that is available

    The ids already in use are fetched once and kept in a set so that
    checking each candidate doesn't trigger a new database query.

    :returns: the first temporary ICS id (in generation order) not in use
    :raises: ValueError if all temporary ICS ids are already used
    """
    used = set(used_temporary_ics_ids())
    for ics_id in temporary_ics_ids():
        if ics_id not in used:
            return ics_id
    raise ValueError('No temporary ICS id available')


@login_manager.user_loader
def load_user(user_id):
    """User loader callback for flask-login

    :param str user_id: unicode ID of a user
    :returns: corresponding user object or None
    """
    return User.query.get(int(user_id))


@ldap_manager.save_user
def save_user(dn, username, data, memberships):
    """User saver for flask-ldap3-login

    This method is called whenever a LDAPLoginForm()
    successfully validates.
    The user is created on first login; its groups are refreshed
    from the LDAP memberships on every login.
    """
    user = User.query.filter_by(username=username).first()
    if user is None:
        user = User(username=username,
                    display_name=utils.attribute_to_string(data['cn']),
                    email=utils.attribute_to_string(data['mail']))
    # Always update the user groups to keep them up-to-date
    user.groups = [utils.attribute_to_string(group['cn']) for group in memberships]
    db.session.add(user)
    db.session.commit()
    return user


# Table required for Many-to-Many relationships between users and groups
usergroups_table = db.Table(
    'usergroups',
    db.Column('user_id', db.Integer, db.ForeignKey('user_account.id')),
    db.Column('group_id', db.Integer, db.ForeignKey('group.id'))
)


class Group(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.Text, nullable=False, unique=True)

    def __str__(self):
        return self.name


def find_or_create_group(name):
    """Return the existing group or a newly created one"""
    group = Group.query.filter_by(name=name).first()
    return group or Group(name=name)


class User(db.Model, UserMixin):
    # "user" is a reserved word in postgresql
    # so let's use another name
    __tablename__ = 'user_account'

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.Text, nullable=False, unique=True)
    display_name = db.Column(db.Text, nullable=False)
    email = db.Column(db.Text)
    grp = db.relationship('Group', secondary=usergroups_table,
                          backref=db.backref('members', lazy='dynamic'))

    # Proxy the 'name' attribute from the 'grp' relationship
    # See http://docs.sqlalchemy.org/en/latest/orm/extensions/associationproxy.html
    groups = association_proxy('grp', 'name',
                               creator=find_or_create_group)
    tokens = db.relationship("Token", backref="user")

    def get_id(self):
        """Return the user id as unicode

        Required by flask-login
        """
        return str(self.id)

    @property
    def csentry_groups(self):
        """Return the CSENTRY_LDAP_GROUPS keys whose LDAP group the user belongs to"""
        return [key
                for key, value in current_app.config['CSENTRY_LDAP_GROUPS'].items()
                if value in self.groups]

    @property
    def is_admin(self):
        """Return True if the user is member of the 'admin' LDAP group"""
        return current_app.config['CSENTRY_LDAP_GROUPS']['admin'] in self.groups

    def is_member_of_one_group(self, groups):
        """Return True if the user is at least member of one of the given groups

        :param groups: keys of CSENTRY_LDAP_GROUPS (unknown keys never match)
        """
        names = [current_app.config['CSENTRY_LDAP_GROUPS'].get(group) for group in groups]
        return bool(set(self.groups) & set(names))

    def is_member_of_all_groups(self, groups):
        """Return True if the user is member of all the given groups

        :param groups: keys of CSENTRY_LDAP_GROUPS (an unknown key makes this False)
        """
        names = [current_app.config['CSENTRY_LDAP_GROUPS'].get(group) for group in groups]
        return set(names).issubset(self.groups)

    def __str__(self):
        return self.display_name

    def to_dict(self):
        return {
            'id': self.id,
            'username': self.username,
            'display_name': self.display_name,
            'email': self.email,
            'groups': self.csentry_groups,
        }


class Token(db.Model):
    """Table to store valid tokens"""
    id = db.Column(db.Integer, primary_key=True)
    jti = db.Column(postgresql.UUID, nullable=False)
    token_type = db.Column(db.Text, nullable=False)
    user_id = db.Column(db.Integer, db.ForeignKey('user_account.id'), nullable=False)
    issued_at = db.Column(db.DateTime, nullable=False)
    # expires can be set to None for tokens that never expire
    expires = db.Column(db.DateTime)
    description = db.Column(db.Text)

    __table_args__ = (
        sa.UniqueConstraint(jti, user_id),
    )

    def __str__(self):
        return self.jti


class QRCodeMixin:
    """Mixin for simple named records that can be identified by a QRCode"""
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)

    def image(self):
        """Return a QRCode image to identify a record

        The QRCode includes:
             - CSE prefix
             - the table name
             - the name of the record
        """
        data = ':'.join(['CSE', self.__tablename__, self.name])
        return qrcode.make(data, version=1, box_size=5)

    def __str__(self):
        return self.name

    def to_dict(self):
        return {
            'id': self.id,
            'name': self.name,
            'qrcode': utils.image_to_base64(self.image()),
        }


class Action(QRCodeMixin, db.Model):
    pass


class Manufacturer(QRCodeMixin, db.Model):
    items = db.relationship('Item', back_populates='manufacturer')


class Model(QRCodeMixin, db.Model):
    description = db.Column(db.Text)
    items = db.relationship('Item', back_populates='model')

    def to_dict(self):
        d = super().to_dict()
        d['description'] = self.description
        return d


class Location(QRCodeMixin, db.Model):
    items = db.relationship('Item', back_populates='location')


class Status(QRCodeMixin, db.Model):
    items = db.relationship('Item', back_populates='status')


class CreatedMixin:
    """Mixin adding creation/update timestamps and the user who made the change"""
    id = db.Column(db.Integer, primary_key=True)
    created_at = db.Column(db.DateTime, default=utcnow())
    updated_at = db.Column(db.DateTime, default=utcnow(), onupdate=utcnow())

    # Using ForeignKey and relationship in mixin requires the @declared_attr decorator
    # See http://docs.sqlalchemy.org/en/latest/orm/extensions/declarative/mixins.html
    @declared_attr
    def user_id(cls):
        return db.Column(db.Integer, db.ForeignKey('user_account.id'),
                         nullable=False, default=utils.fetch_current_user_id)

    @declared_attr
    def user(cls):
        return db.relationship('User')

    def to_dict(self):
        """Return the fields common to all CreatedMixin models as a dict"""
        return {
            'id': self.id,
            'created_at': utils.format_field(self.created_at),
            'updated_at': utils.format_field(self.updated_at),
            'user': str(self.user),
        }


class Item(CreatedMixin, db.Model):
    __versioned__ = {
        'exclude': ['created_at', 'ics_id', 'serial_number',
                    'manufacturer_id', 'model_id']
    }

    # WARNING! Inheriting id from CreatedMixin doesn't play well with
    # SQLAlchemy-Continuum. It has to be defined here.
    id = db.Column(db.Integer, primary_key=True)
    ics_id = db.Column(db.Text, unique=True, nullable=False,
                       index=True, default=get_temporary_ics_id)
    serial_number = db.Column(db.Text, nullable=False)
    manufacturer_id = db.Column(db.Integer, db.ForeignKey('manufacturer.id'))
    model_id = db.Column(db.Integer, db.ForeignKey('model.id'))
    location_id = db.Column(db.Integer, db.ForeignKey('location.id'))
    status_id = db.Column(db.Integer, db.ForeignKey('status.id'))
    parent_id = db.Column(db.Integer, db.ForeignKey('item.id'))

    manufacturer = db.relationship('Manufacturer', back_populates='items')
    model = db.relationship('Model', back_populates='items')
    location = db.relationship('Location', back_populates='items')
    status = db.relationship('Status', back_populates='items')
    children = db.relationship('Item', backref=db.backref('parent', remote_side=[id]))
    macs = db.relationship('Mac', backref='item')
    host = db.relationship('Host', uselist=False, backref='item')
    comments = db.relationship('ItemComment', backref='item')

    def __init__(self, **kwargs):
        # Automatically convert manufacturer/model/location/status to an
        # instance of their class if passed as a string
        for key, cls in [('manufacturer', Manufacturer),
                         ('model', Model),
                         ('location', Location),
                         ('status', Status)]:
            if key in kwargs:
                kwargs[key] = utils.convert_to_model(kwargs[key], cls)
        super().__init__(**kwargs)

    def __str__(self):
        return str(self.ics_id)

    @validates('ics_id')
    def validate_ics_id(self, key, string):
        """Ensure the ICS id field matches the required format

        None is accepted here (the column default generates a temporary id).
        """
        if string is not None and ICS_ID_RE.fullmatch(string) is None:
            raise ValidationError('ICS id shall match [A-Z]{3}[0-9]{3}')
        return string

    def to_dict(self):
        d = super().to_dict()
        d.update({
            'ics_id': self.ics_id,
            'serial_number': self.serial_number,
            'manufacturer': utils.format_field(self.manufacturer),
            'model': utils.format_field(self.model),
            'location': utils.format_field(self.location),
            'status': utils.format_field(self.status),
            'parent': utils.format_field(self.parent),
            'children': [str(child) for child in self.children],
            'macs': [str(mac) for mac in self.macs],
            'history': self.history(),
            'comments': [str(comment) for comment in self.comments],
        })
        return d

    def history(self):
        """Return the item versions as a list of dicts

        Each dict contains the updated_at, location, status and parent
        fields of one SQLAlchemy-Continuum version.
        """
        versions = []
        for version in self.versions:
            # parent is an attribute used by SQLAlchemy-Continuum
            # version.parent refers to an ItemVersion instance (and has no link with
            # the item parent_id)
            # We need to retrieve the parent "manually"
            if version.parent_id is None:
                parent = None
            else:
                parent = Item.query.get(version.parent_id)
            versions.append({
                'updated_at': utils.format_field(version.updated_at),
                'location': utils.format_field(version.location),
                'status': utils.format_field(version.status),
                'parent': utils.format_field(parent),
            })
        return versions


class ItemComment(CreatedMixin, db.Model):
    body = db.Column(db.Text, nullable=False)
    item_id = db.Column(db.Integer, db.ForeignKey('item.id'), nullable=False)

    def __str__(self):
        return self.body

    def to_dict(self):
        d = super().to_dict()
        d.update({
            'body': self.body,
            'item': str(self.item),
        })
        return d


class Network(CreatedMixin, db.Model):
    vlan_name = db.Column(CIText, nullable=False, unique=True)
    vlan_id = db.Column(db.Integer, nullable=False, unique=True)
    address = db.Column(postgresql.CIDR, nullable=False, unique=True)
    first_ip = db.Column(postgresql.INET, nullable=False, unique=True)
    last_ip = db.Column(postgresql.INET, nullable=False, unique=True)
    description = db.Column(db.Text)
    admin_only = db.Column(db.Boolean, nullable=False, default=False)
    scope_id = db.Column(db.Integer, db.ForeignKey('network_scope.id'), nullable=False)

    interfaces = db.relationship('Interface', backref='network')

    __table_args__ = (
        sa.CheckConstraint('first_ip < last_ip', name='first_ip_less_than_last_ip'),
        sa.CheckConstraint('first_ip << address', name='first_ip_in_network'),
        sa.CheckConstraint('last_ip << address', name='last_ip_in_network'),
    )

    def __init__(self, **kwargs):
        # Automatically convert scope to an instance of NetworkScope if it was passed
        # as a string
        if 'scope' in kwargs:
            kwargs['scope'] = utils.convert_to_model(kwargs['scope'], NetworkScope, 'name')
        # NOTE(review): the first_ip/last_ip validators read self.address (and
        # last_ip reads self.first), so callers must pass address and first_ip
        # before last_ip in kwargs.
        super().__init__(**kwargs)

    def __str__(self):
        return str(self.vlan_name)

    @property
    def network_ip(self):
        return ipaddress.ip_network(self.address)

    @property
    def first(self):
        return ipaddress.ip_address(self.first_ip)

    @property
    def last(self):
        return ipaddress.ip_address(self.last_ip)

    def ip_range(self):
        """Return the list of IP addresses that can be assigned for this network

        The range is defined by the first and last IP
        """
        # Parse the bounds once instead of once per host address
        first, last = self.first, self.last
        return [addr for addr in self.network_ip.hosts()
                if first <= addr <= last]

    def used_ips(self):
        """Return the list of IP addresses in use

        The list is sorted
        """
        return sorted(interface.address for interface in self.interfaces)

    def available_ips(self):
        """Return the list of IP addresses available"""
        # Build the set of used addresses once: calling used_ips() for each
        # candidate made this quadratic in the size of the network
        used = set(self.used_ips())
        return [addr for addr in self.ip_range() if addr not in used]

    def gateway(self):
        """Return the network gateway

        :returns: the first interface tagged 'gateway' or None
        """
        for interface in self.interfaces:
            if 'gateway' in [tag.name for tag in interface.tags]:
                return interface
        return None

    @staticmethod
    def ip_in_network(ip, address):
        """Ensure the IP is in the network

        :param str ip: IP address to check
        :param str address: network address (CIDR notation)
        :returns: a tuple with the IP and network as (IPv4Address, IPv4Network)
        :raises: ValidationError if the IP is not in the network
        """
        addr = ipaddress.ip_address(ip)
        net = ipaddress.ip_network(address)
        if addr not in net:
            raise ValidationError(f'IP address {ip} is not in network {address}')
        return (addr, net)

    @validates('first_ip')
    def validate_first_ip(self, key, ip):
        """Ensure the first IP is in the network"""
        self.ip_in_network(ip, self.address)
        return ip

    @validates('last_ip')
    def validate_last_ip(self, key, ip):
        """Ensure the last IP is in the network and greater than first_ip"""
        addr, net = self.ip_in_network(ip, self.address)
        if addr < self.first:
            raise ValidationError(f'Last IP address {ip} is less than the first address {self.first}')
        return ip

    @validates('interfaces')
    def validate_interfaces(self, key, interface):
        """Ensure the interface IP is in the network range"""
        addr, net = self.ip_in_network(interface.ip, self.address)
        if addr < self.first or addr > self.last:
            raise ValidationError(f'IP address {interface.ip} is not in range {self.first} - {self.last}')
        return interface

    @validates('vlan_name')
    def validate_vlan_name(self, key, string):
        """Ensure the name matches the required format"""
        if string is None:
            return None
        if VLAN_NAME_RE.fullmatch(string) is None:
            # Raw string: same text as before, without the invalid '\-' escape warning
            raise ValidationError(r'Vlan name shall match [A-Za-z0-9\-]{3,25}')
        return string

    def to_dict(self):
        d = super().to_dict()
        d.update({
            'vlan_name': self.vlan_name,
            'vlan_id': self.vlan_id,
            'address': self.address,
            'first_ip': self.first_ip,
            'last_ip': self.last_ip,
            'description': self.description,
            'admin_only': self.admin_only,
            'scope': utils.format_field(self.scope),
            'interfaces': [str(interface) for interface in self.interfaces],
        })
        return d


# Table required for Many-to-Many relationships between interfaces and tags
interfacetags_table = db.Table(
    'interfacetags',
    db.Column('tag_id', db.Integer, db.ForeignKey('tag.id'), primary_key=True),
    db.Column('interface_id', db.Integer, db.ForeignKey('interface.id'), primary_key=True)
)


class Tag(QRCodeMixin, db.Model):
    pass


def _validate_host_name(string, label):
    """Return the lowercased name if it matches HOST_NAME_RE

    Shared by the Host and Interface name validators.

    :param str string: name to validate (None is returned unchanged)
    :param str label: prefix of the error message ('Host' or 'Interface')
    :returns: the name forced to lowercase
    :raises: ValidationError if the lowercased name doesn't match HOST_NAME_RE
    """
    if string is None:
        return None
    # Force the string to lowercase
    lower_string = string.lower()
    if HOST_NAME_RE.fullmatch(lower_string) is None:
        raise ValidationError(label + r' name shall match [a-z0-9\-]{2,20}')
    return lower_string


class Host(CreatedMixin, db.Model):
    name = db.Column(db.Text, nullable=False, unique=True)
    type = db.Column(db.Text)
    description = db.Column(db.Text)
    item_id = db.Column(db.Integer, db.ForeignKey('item.id'))

    interfaces = db.relationship('Interface', backref='host')

    def __str__(self):
        return str(self.name)

    @validates('name')
    def validate_name(self, key, string):
        """Ensure the name matches the required format (stored lowercase)"""
        return _validate_host_name(string, 'Host')

    def to_dict(self):
        d = super().to_dict()
        d.update({
            'name': self.name,
            'type': self.type,
            'description': self.description,
            'item': utils.format_field(self.item),
            'interfaces': [str(interface) for interface in self.interfaces],
        })
        return d


class Interface(CreatedMixin, db.Model):
    network_id = db.Column(db.Integer, db.ForeignKey('network.id'), nullable=False)
    ip = db.Column(postgresql.INET, nullable=False, unique=True)
    name = db.Column(db.Text, nullable=False, unique=True)
    mac_id = db.Column(db.Integer, db.ForeignKey('mac.id'))
    host_id = db.Column(db.Integer, db.ForeignKey('host.id'))

    # Add delete and delete-orphan options to automatically delete cnames when:
    # - deleting an interface
    # - de-associating a cname (removing it from the interface.cnames list)
    cnames = db.relationship('Cname', backref='interface',
                             cascade='all, delete, delete-orphan')
    tags = db.relationship('Tag', secondary=interfacetags_table, lazy='subquery',
                           backref=db.backref('interfaces', lazy=True))

    def __init__(self, **kwargs):
        # Automatically convert network to an instance of Network if it was passed
        # as an address string
        if 'network' in kwargs:
            kwargs['network'] = utils.convert_to_model(kwargs['network'], Network, 'address')
        # WARNING! Setting self.network will call validates_interfaces in the Network class
        # For the validation to work, self.ip must be set before!
        # Ensure that ip is passed before network (dicts keep insertion order)
        if 'ip' in kwargs:
            kwargs = {'ip': kwargs.pop('ip'), **kwargs}
        super().__init__(**kwargs)

    @validates('name')
    def validate_name(self, key, string):
        """Ensure the name matches the required format (stored lowercase)"""
        return _validate_host_name(string, 'Interface')

    @property
    def address(self):
        return ipaddress.ip_address(self.ip)

    def __str__(self):
        return str(self.name)

    def __repr__(self):
        return f'Interface(id={self.id}, network_id={self.network_id}, ip={self.ip}, name={self.name}, mac={self.mac})'

    def to_dict(self):
        d = super().to_dict()
        d.update({
            'network': str(self.network),
            'ip': self.ip,
            'name': self.name,
            'mac': utils.format_field(self.mac),
            'host': utils.format_field(self.host),
            'cnames': [str(cname) for cname in self.cnames],
            'tags': [str(tag) for tag in self.tags],
        })
        return d


class Mac(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    address = db.Column(postgresql.MACADDR, nullable=False, unique=True)
    item_id = db.Column(db.Integer, db.ForeignKey('item.id'))

    interfaces = db.relationship('Interface', backref='mac')

    def __str__(self):
        return str(self.address)

    def to_dict(self):
        return {
            'id': self.id,
            'address': self.address,
            'item': utils.format_field(self.item),
            'interfaces': [str(interface) for interface in self.interfaces],
        }


class Cname(CreatedMixin, db.Model):
    name = db.Column(db.Text, nullable=False, unique=True)
    interface_id = db.Column(db.Integer, db.ForeignKey('interface.id'), nullable=False)

    def __str__(self):
        return str(self.name)

    def to_dict(self):
        d = super().to_dict()
        d.update({
            'name': self.name,
            'interface': str(self.interface),
        })
        return d


class NetworkScope(CreatedMixin, db.Model):
    __tablename__ = 'network_scope'

    name = db.Column(CIText, nullable=False, unique=True)
    first_vlan = db.Column(db.Integer, nullable=False, unique=True)
    last_vlan = db.Column(db.Integer, nullable=False, unique=True)
    supernet = db.Column(postgresql.CIDR, nullable=False, unique=True)
    description = db.Column(db.Text)

    networks = db.relationship('Network', backref='scope')

    __table_args__ = (
        sa.CheckConstraint('first_vlan < last_vlan', name='first_vlan_less_than_last_vlan'),
    )

    def __str__(self):
        return str(self.name)

    @property
    def supernet_ip(self):
        return ipaddress.ip_network(self.supernet)

    def prefix_range(self):
        """Return the list of subnet prefix that can be used for this network scope"""
        # NOTE(review): the upper bound 31 assumes an IPv4 supernet
        return list(range(self.supernet_ip.prefixlen + 1, 31))

    def vlan_range(self):
        """Return the list of vlan ids that can be assigned for this network scope

        The range is defined by the first and last vlan
        """
        return range(self.first_vlan, self.last_vlan + 1)

    def used_vlans(self):
        """Return the list of vlan ids in use

        The list is sorted
        """
        return sorted(network.vlan_id for network in self.networks)

    def available_vlans(self):
        """Return the list of vlan ids available"""
        # Compute the used vlans once instead of once per candidate
        used = set(self.used_vlans())
        return [vlan for vlan in self.vlan_range() if vlan not in used]

    def used_subnets(self):
        """Return the list of subnets in use

        The list is sorted
        """
        return sorted(network.network_ip for network in self.networks)

    def available_subnets(self, prefix):
        """Return the list of available subnets with the given prefix

        :param int prefix: prefix length of the subnets to generate
        :returns: list of subnets (as strings) of the supernet not in use
        """
        # Compute the used subnets once instead of once per candidate
        used = set(self.used_subnets())
        return [str(subnet) for subnet in self.supernet_ip.subnets(new_prefix=prefix)
                if subnet not in used]

    def to_dict(self):
        d = super().to_dict()
        d.update({
            'name': self.name,
            'first_vlan': self.first_vlan,
            'last_vlan': self.last_vlan,
            'supernet': self.supernet,
            'description': self.description,
            'networks': [str(network) for network in self.networks],
        })
        return d


# call configure_mappers after defining all the models
# required by sqlalchemy_continuum
sa.orm.configure_mappers()
ItemVersion = version_class(Item)