# -*- coding: utf-8 -*-
"""
app.models
~~~~~~~~~~

This module implements the models used in the app.

:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.

"""
import ipaddress
import re
import qrcode
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import validates
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy_continuum import make_versioned, version_class
from citext import CIText
from flask import current_app
from flask_login import UserMixin
from wtforms import ValidationError
from .extensions import db, login_manager, ldap_manager, jwt
from .plugins import FlaskUserPlugin
from . import utils


ICS_ID_RE = re.compile('[A-Z]{3}[0-9]{3}')
# Raw string: '\-' is not a valid escape sequence in a normal string literal
HOST_NAME_RE = re.compile(r'^[a-z0-9\-]{2,20}$')

make_versioned(plugins=[FlaskUserPlugin()])


def _get_user_or_none(raw_id):
    """Return the user identified by raw_id or None

    :param raw_id: value convertible to an integer user id
    :returns: corresponding user object or None (also when raw_id
              can't be converted to an integer)
    """
    try:
        user_id = int(raw_id)
    except (TypeError, ValueError):
        # Loader callbacks shall return None on invalid id instead of raising
        return None
    return User.query.get(user_id)


@login_manager.user_loader
def load_user(user_id):
    """User loader callback for flask-login

    :param str user_id: unicode ID of a user
    :returns: corresponding user object or None
    """
    return _get_user_or_none(user_id)


@jwt.user_loader_callback_loader
def user_loader_callback(identity):
    """User loader callback for flask-jwt-extended

    :param str identity: identity from the token (user_id)
    :returns: corresponding user object or None
    """
    return _get_user_or_none(identity)


@ldap_manager.save_user
def save_user(dn, username, data, memberships):
    """User saver for flask-ldap3-login

    This method is called whenever a LDAPLoginForm()
    successfully validates.

    :param dn: distinguished name of the user (not used here)
    :param username: username used to look up or create the user
    :param data: user attributes; 'cn' and 'mail' are read when creating the user
    :param memberships: iterable of groups; the 'cn' of each one is stored
    :returns: the saved user object
    """
    user = User.query.filter_by(username=username).first()
    if user is None:
        user = User(username=username,
                    name=utils.attribute_to_string(data['cn']),
                    email=utils.attribute_to_string(data['mail']))
    # Always update the user groups to keep them up-to-date
    user.groups = [utils.attribute_to_string(group['cn']) for group in memberships]
    db.session.add(user)
    db.session.commit()
    return user


# Table required for Many-to-Many relationships between users and groups
usergroups_table = db.Table(
    'usergroups',
    db.Column('user_id', db.Integer, db.ForeignKey('user_account.id')),
    db.Column('group_id', db.Integer, db.ForeignKey('group.id'))
)


class Group(db.Model):
    """Group a user can be member of (identified by its unique name)"""

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.Text, nullable=False, unique=True)

    def __str__(self):
        return self.name


def find_or_create_group(name):
    """Return the existing group or a newly created one"""
    group = Group.query.filter_by(name=name).first()
    return group or Group(name=name)


class User(db.Model, UserMixin):
    """User account with its group memberships"""

    # "user" is a reserved word in postgresql
    # so let's use another name
    __tablename__ = 'user_account'

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.Text, unique=True)
    name = db.Column(db.Text)
    email = db.Column(db.Text)
    grp = db.relationship('Group', secondary=usergroups_table,
                          backref=db.backref('members', lazy='dynamic'))
    # Proxy the 'name' attribute from the 'grp' relationship
    # See http://docs.sqlalchemy.org/en/latest/orm/extensions/associationproxy.html
    groups = association_proxy('grp', 'name', creator=find_or_create_group)

    def get_id(self):
        """Return the user id as unicode

        Required by flask-login
        """
        return str(self.id)

    @property
    def is_admin(self):
        """Return True if the user is member of the configured admin group"""
        return current_app.config['CSENTRY_LDAP_GROUPS']['admin'] in self.groups

    def is_member_of_one_group(self, groups):
        """Return True if the user is at least member of one of the given groups

        :param groups: keys of the CSENTRY_LDAP_GROUPS configuration dictionary
        """
        names = [current_app.config['CSENTRY_LDAP_GROUPS'].get(group) for group in groups]
        return bool(set(self.groups) & set(names))

    def is_member_of_all_groups(self, groups):
        """Return True if the user is member of all the given groups

        :param groups: keys of the CSENTRY_LDAP_GROUPS configuration dictionary
        """
        names = [current_app.config['CSENTRY_LDAP_GROUPS'].get(group) for group in groups]
        return set(names).issubset(self.groups)

    def __str__(self):
        return self.name


class QRCodeMixin:
    """Mixin for models identified by a unique name that can be shown as a QRCode"""

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)

    def image(self):
        """Return a QRCode image to identify a record

        The QRCode includes:
             - CSE prefix
             - the table name
             - the name of the record
        """
        data = ':'.join(['CSE', self.__tablename__, self.name])
        return qrcode.make(data, version=1, box_size=5)

    def __str__(self):
        return self.name

    def to_dict(self, qrcode=False):
        """Return the record as a dictionary

        :param bool qrcode: include the base64 encoded QRCode image if True
        """
        d = {'id': self.id, 'name': self.name}
        if qrcode:
            d['qrcode'] = utils.image_to_base64(self.image())
        return d


class Action(QRCodeMixin, db.Model):
    pass


class Manufacturer(QRCodeMixin, db.Model):
    items = db.relationship('Item', back_populates='manufacturer')


class Model(QRCodeMixin, db.Model):
    description = db.Column(db.Text)
    items = db.relationship('Item', back_populates='model')

    def to_dict(self, qrcode=False):
        """Return the record as a dictionary (including the description)"""
        d = super().to_dict(qrcode)
        d['description'] = self.description
        return d


class Location(QRCodeMixin, db.Model):
    items = db.relationship('Item', back_populates='location')


class Status(QRCodeMixin, db.Model):
    items = db.relationship('Item', back_populates='status')


class Item(db.Model):
    """Versioned inventory item"""

    # Columns listed in 'exclude' are not tracked by SQLAlchemy-Continuum
    __versioned__ = {
        'exclude': ['_created', 'ics_id', 'serial_number',
                    'manufacturer_id', 'model_id']
    }

    id = db.Column(db.Integer, primary_key=True)
    _created = db.Column(db.DateTime, default=db.func.now())
    _updated = db.Column(db.DateTime, default=db.func.now(), onupdate=db.func.now())
    ics_id = db.Column(db.Text, unique=True, index=True)
    serial_number = db.Column(db.Text, nullable=False)
    manufacturer_id = db.Column(db.Integer, db.ForeignKey('manufacturer.id'))
    model_id = db.Column(db.Integer, db.ForeignKey('model.id'))
    location_id = db.Column(db.Integer, db.ForeignKey('location.id'))
    status_id = db.Column(db.Integer, db.ForeignKey('status.id'))
    parent_id = db.Column(db.Integer, db.ForeignKey('item.id'))

    manufacturer = db.relationship('Manufacturer', back_populates='items')
    model = db.relationship('Model', back_populates='items')
    location = db.relationship('Location', back_populates='items')
    status = db.relationship('Status', back_populates='items')
    children = db.relationship('Item', backref=db.backref('parent', remote_side=[id]))
    macs = db.relationship('Mac', backref='item')

    def __init__(self, **kwargs):
        # Automatically convert manufacturer/model/location/status to an
        # instance of their class if passed as a string
        for key, cls in [('manufacturer', Manufacturer), ('model', Model),
                         ('location', Location), ('status', Status)]:
            if key in kwargs:
                kwargs[key] = utils.convert_to_model(kwargs[key], cls)
        super().__init__(**kwargs)

    def __str__(self):
        return str(self.ics_id)

    @validates('ics_id')
    def validate_ics_id(self, key, string):
        """Ensure the ICS id field matches the required format

        :raises: ValidationError if the (not None) value doesn't match
        """
        if string is not None and ICS_ID_RE.fullmatch(string) is None:
            raise ValidationError('ICS id shall match [A-Z]{3}[0-9]{3}')
        return string

    def to_dict(self, long=False):
        """Return the item as a dictionary

        :param bool long: include the children, macs and history if True
        """
        d = {
            'id': self.id,
            'ics_id': self.ics_id,
            'serial_number': self.serial_number,
            'manufacturer': utils.format_field(self.manufacturer),
            'model': utils.format_field(self.model),
            'location': utils.format_field(self.location),
            'status': utils.format_field(self.status),
            'updated': utils.format_field(self._updated),
            'created': utils.format_field(self._created),
            'parent': utils.format_field(self.parent),
        }
        if long:
            d['children'] = [utils.format_field(child) for child in self.children]
            d['macs'] = [mac.to_dict(long=True) for mac in self.macs]
            d['history'] = self.history()
        return d

    def history(self):
        """Return the list of versions of the item as dictionaries"""
        versions = []
        for version in self.versions:
            # parent is an attribute used by SQLAlchemy-Continuum
            # version.parent refers to an ItemVersion instance (and has no link with
            # the item parent_id)
            # We need to retrieve the parent "manually"
            if version.parent_id is None:
                parent = None
            else:
                parent = Item.query.get(version.parent_id)
            versions.append({
                'updated': utils.format_field(version._updated),
                'location': utils.format_field(version.location),
                'status': utils.format_field(version.status),
                'parent': utils.format_field(parent),
            })
        return versions


class Network(db.Model):
    """Network (vlan) with the range of IP addresses that can be assigned to hosts"""

    id = db.Column(db.Integer, primary_key=True)
    vlan_name = db.Column(CIText, nullable=False, unique=True)
    vlan_id = db.Column(db.Integer, nullable=False, unique=True)
    address = db.Column(postgresql.CIDR, nullable=False, unique=True)
    first_ip = db.Column(postgresql.INET, nullable=False, unique=True)
    last_ip = db.Column(postgresql.INET, nullable=False, unique=True)
    gateway = db.Column(postgresql.INET)
    description = db.Column(db.Text)
    admin_only = db.Column(db.Boolean, nullable=False, default=False)
    scope_id = db.Column(db.Integer, db.ForeignKey('network_scope.id'), nullable=False)

    hosts = db.relationship('Host', backref='network')

    __table_args__ = (
        sa.CheckConstraint('first_ip < last_ip', name='first_ip_less_than_last_ip'),
        sa.CheckConstraint('first_ip << address', name='first_ip_in_network'),
        sa.CheckConstraint('last_ip << address', name='last_ip_in_network'),
    )

    def __init__(self, **kwargs):
        # Automatically convert scope to an instance of NetworkScope if it was passed
        # as a string
        if 'scope' in kwargs:
            kwargs['scope'] = utils.convert_to_model(kwargs['scope'], NetworkScope, 'name')
        super().__init__(**kwargs)

    def __str__(self):
        return str(self.vlan_name)

    @property
    def network_ip(self):
        """Return the network address as an ipaddress network object"""
        return ipaddress.ip_network(self.address)

    @property
    def first(self):
        """Return the first IP as an ipaddress address object"""
        return ipaddress.ip_address(self.first_ip)

    @property
    def last(self):
        """Return the last IP as an ipaddress address object"""
        return ipaddress.ip_address(self.last_ip)

    def ip_range(self):
        """Return the list of IP addresses that can be assigned for this network

        The range is defined by the first and last IP
        """
        return [addr for addr in self.network_ip.hosts()
                if self.first <= addr <= self.last]

    def used_ips(self):
        """Return the list of IP addresses in use

        The list is sorted
        """
        return sorted(host.address for host in self.hosts)

    def available_ips(self):
        """Return the list of IP addresses available"""
        # Compute the used addresses only once (and not for every
        # address of the range) and use a set for fast lookups
        used = set(self.used_ips())
        return [addr for addr in self.ip_range() if addr not in used]

    @staticmethod
    def ip_in_network(ip, address):
        """Ensure the IP is in the network

        :param ip: IP address (passed to ipaddress.ip_address)
        :param address: network address (passed to ipaddress.ip_network)
        :returns: a tuple with the IP and network as ipaddress objects (address, network)
        :raises: ValidationError if the IP is not in the network
        """
        addr = ipaddress.ip_address(ip)
        net = ipaddress.ip_network(address)
        if addr not in net:
            raise ValidationError(f'IP address {ip} is not in network {address}')
        return (addr, net)

    @validates('first_ip')
    def validate_first_ip(self, key, ip):
        """Ensure the first IP is in the network"""
        self.ip_in_network(ip, self.address)
        return ip

    @validates('last_ip')
    def validate_last_ip(self, key, ip):
        """Ensure the last IP is in the network and greater than first_ip"""
        addr, net = self.ip_in_network(ip, self.address)
        if addr < self.first:
            raise ValidationError(f'Last IP address {ip} is less than the first address {self.first}')
        return ip

    @validates('hosts')
    def validate_hosts(self, key, host):
        """Ensure the host IP is in the network range"""
        addr, net = self.ip_in_network(host.ip, self.address)
        if addr < self.first or addr > self.last:
            raise ValidationError(f'IP address {host.ip} is not in range {self.first} - {self.last}')
        return host

    def to_dict(self):
        """Return the network as a dictionary"""
        return {
            'id': self.id,
            'vlan_name': self.vlan_name,
            'address': self.address,
            'first_ip': self.first_ip,
            'last_ip': self.last_ip,
            'gateway': self.gateway,
            'vlan_id': self.vlan_id,
            'description': self.description,
            'admin_only': self.admin_only,
            'scope': utils.format_field(self.scope),
        }


# Table required for Many-to-Many relationships between hosts and tags
hosttags_table = db.Table(
    'hosttags',
    db.Column('tag_id', db.Integer, db.ForeignKey('tag.id'), primary_key=True),
    db.Column('host_id', db.Integer, db.ForeignKey('host.id'), primary_key=True)
)


class Tag(QRCodeMixin, db.Model):
    pass


class Host(db.Model):
    """Host with a unique name and IP address on a network"""

    id = db.Column(db.Integer, primary_key=True)
    network_id = db.Column(db.Integer, db.ForeignKey('network.id'), nullable=False)
    ip = db.Column(postgresql.INET, nullable=False, unique=True)
    name = db.Column(db.Text, nullable=False, unique=True)
    description = db.Column(db.Text)
    mac_id = db.Column(db.Integer, db.ForeignKey('mac.id'))

    cnames = db.relationship('Cname', backref='host')
    tags = db.relationship('Tag', secondary=hosttags_table, lazy='subquery',
                           backref=db.backref('hosts', lazy=True))

    def __init__(self, **kwargs):
        # Automatically convert network to an instance of Network if it was passed
        # as an address string
        if 'network' in kwargs:
            kwargs['network'] = utils.convert_to_model(kwargs['network'], Network, 'address')
        # WARNING! Setting self.network will call validates_hosts in the Network class
        # For the validation to work, self.ip must be set before!
        # Ensure that ip is passed before network
        try:
            ip = kwargs.pop('ip')
        except KeyError:
            super().__init__(**kwargs)
        else:
            super().__init__(ip=ip, **kwargs)

    @validates('name')
    def validate_name(self, key, string):
        """Ensure the hostname matches the required format

        :returns: the name forced to lowercase (or None if None was given)
        :raises: ValidationError if the lowercase name doesn't match
        """
        if string is None:
            return None
        # Force the string to lowercase
        lower_string = string.lower()
        if HOST_NAME_RE.fullmatch(lower_string) is None:
            # Raw string: same message without the invalid '\-' escape sequence
            raise ValidationError(r'Host name shall match [a-z0-9\-]{2,20}')
        return lower_string

    @property
    def address(self):
        """Return the host IP as an ipaddress address object"""
        return ipaddress.ip_address(self.ip)

    def __str__(self):
        return str(self.name)

    def __repr__(self):
        return (f'Host(id={self.id}, network_id={self.network_id}, ip={self.ip}, '
                f'name={self.name}, description={self.description}, mac={self.mac})')

    def to_dict(self, long=False):
        """Return the host as a dictionary

        :param bool long: include the mac address (None if no mac) if True
        """
        d = {
            'id': self.id,
            'ip': self.ip,
            'name': self.name,
            'description': self.description,
            'network_id': self.network_id,
        }
        if long:
            d['mac'] = getattr(self.mac, 'address', None)
        return d


class Mac(db.Model):
    """MAC address linked to an item"""

    id = db.Column(db.Integer, primary_key=True)
    address = db.Column(postgresql.MACADDR, nullable=False, unique=True)
    item_id = db.Column(db.Integer, db.ForeignKey('item.id'), nullable=False)

    hosts = db.relationship('Host', backref='mac')

    def __str__(self):
        return str(self.address)

    def to_dict(self, long=False):
        """Return the mac as a dictionary

        :param bool long: include the item ICS id and the hosts if True
        """
        d = {
            'id': self.id,
            'address': self.address,
            'item_id': self.item_id,
        }
        if long:
            d['item_ics_id'] = self.item.ics_id
            d['hosts'] = [host.to_dict() for host in self.hosts]
        return d


class Cname(db.Model):
    """Unique alias name linked to a host"""

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.Text, nullable=False, unique=True)
    host_id = db.Column(db.Integer, db.ForeignKey('host.id'), nullable=False, unique=True)

    def __str__(self):
        return str(self.name)

    def to_dict(self):
        """Return the cname as a dictionary"""
        return {
            'id': self.id,
            'name': self.name,
            'host_id': self.host_id,
        }


class NetworkScope(db.Model):
    """Scope (range of vlans and subnet) that networks belong to"""

    __tablename__ = 'network_scope'

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)
    first_vlan = db.Column(db.Integer, nullable=False, unique=True)
    last_vlan = db.Column(db.Integer, nullable=False, unique=True)
    subnet = db.Column(postgresql.CIDR, nullable=False, unique=True)

    networks = db.relationship('Network', backref='scope')

    def __str__(self):
        return str(self.name)

    def to_dict(self):
        """Return the network scope as a dictionary"""
        return {
            'id': self.id,
            'name': self.name,
            'first_vlan': self.first_vlan,
            'last_vlan': self.last_vlan,
            'subnet': self.subnet,
        }


# call configure_mappers after defining all the models
# required by sqlalchemy_continuum
sa.orm.configure_mappers()
ItemVersion = version_class(Item)