# -*- coding: utf-8 -*-
"""
app.models
~~~~~~~~~~

This module implements the models used in the app.

:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.
"""
import ipaddress
import re

import qrcode
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import validates
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy_continuum import make_versioned, version_class
from citext import CIText
from flask import current_app
from flask_login import UserMixin

from .extensions import db, login_manager, ldap_manager, jwt
from .plugins import FlaskUserPlugin
from . import utils


# Format of an ICS id: 3 uppercase letters followed by 3 digits (e.g. ABC123)
ICS_ID_RE = re.compile('[A-Z]{3}[0-9]{3}')

# Enable sqlalchemy_continuum versioning (recording the user through
# FlaskUserPlugin) before the versioned models below are declared
make_versioned(plugins=[FlaskUserPlugin()])


@login_manager.user_loader
def load_user(user_id):
    """User loader callback for flask-login

    :param str user_id: unicode ID of a user
    :returns: corresponding user object or None
    """
    return User.query.get(int(user_id))


@jwt.user_loader_callback_loader
def user_loader_callback(identity):
    """User loader callback for flask-jwt-extended

    :param str identity: identity from the token (user_id)
    :returns: corresponding user object or None
    """
    return User.query.get(int(identity))


@ldap_manager.save_user
def save_user(dn, username, data, memberships):
    """User saver for flask-ldap3-login

    This method is called whenever an LDAPLoginForm() successfully validates.

    A User is created the first time a username is seen, with its name and
    email taken from the LDAP ``cn`` and ``mail`` attributes. Existing users
    keep their stored name and email. The user groups are replaced on every
    call by the ``cn`` of each LDAP membership.

    :param dn: distinguished name of the user (not used here)
    :param str username: name the user logged in with
    :param data: LDAP attributes of the user (indexed by 'cn' and 'mail')
    :param memberships: LDAP groups of the user (each indexed by 'cn')
    :returns: the saved (and committed) User
    """
    user = User.query.filter_by(username=username).first()
    if user is None:
        user = User(username,
                    name=utils.attribute_to_string(data['cn']),
                    email=utils.attribute_to_string(data['mail']))
    # Always update the user groups to keep them up-to-date
    user.groups = [utils.attribute_to_string(group['cn'])
                   for group in memberships]
    db.session.add(user)
    db.session.commit()
    return user


# Table required for Many-to-Many relationships between users and groups
usergroups_table = db.Table(
    'usergroups',
    db.Column('user_id', db.Integer, db.ForeignKey('user_account.id')),
    db.Column('group_id', db.Integer, db.ForeignKey('group.id'))
)


class Group(db.Model):
    """Group of users, identified by its unique name"""
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.Text, nullable=False, unique=True)

    def __init__(self, name):
        self.name = name

    def __str__(self):
        return self.name


def find_or_create_group(name):
    """Return the existing group or a newly created one

    Used as the association proxy creator of User.groups. A new Group is
    only instantiated here; it is not explicitly added to the session.
    """
    group = Group.query.filter_by(name=name).first()
    return group or Group(name=name)


class User(db.Model, UserMixin):
    """Application user, linked to groups through usergroups_table"""
    # "user" is a reserved word in postgresql
    # so let's use another name
    __tablename__ = 'user_account'

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.Text, unique=True)
    name = db.Column(db.Text)
    email = db.Column(db.Text)
    grp = db.relationship('Group', secondary=usergroups_table,
                          backref=db.backref('members', lazy='dynamic'))
    # Proxy the 'name' attribute from the 'grp' relationship
    # See http://docs.sqlalchemy.org/en/latest/orm/extensions/associationproxy.html
    groups = association_proxy('grp', 'name', creator=find_or_create_group)

    def __init__(self, username, name, email):
        self.username = username
        self.name = name
        self.email = email

    def get_id(self):
        """Return the user id as unicode

        Required by flask-login
        """
        return str(self.id)

    @property
    def is_admin(self):
        """True if the user belongs to the group configured as 'admin'

        The group name is read from ``CSENTRY_LDAP_GROUPS['admin']`` in the
        application config.
        """
        return current_app.config['CSENTRY_LDAP_GROUPS']['admin'] in self.groups

    def is_member_of_one_group(self, groups):
        """Return True if the user is at least member of one of the given groups

        :param groups: keys of the ``CSENTRY_LDAP_GROUPS`` config mapping;
            keys missing from the config never match
        """
        names = [current_app.config['CSENTRY_LDAP_GROUPS'].get(group)
                 for group in groups]
        return bool(set(self.groups) & set(names))

    def is_member_of_all_groups(self, groups):
        """Return True if the user is member of all the given groups

        :param groups: keys of the ``CSENTRY_LDAP_GROUPS`` config mapping;
            a key missing from the config makes the result False, and an
            empty ``groups`` returns True
        """
        names = [current_app.config['CSENTRY_LDAP_GROUPS'].get(group)
                 for group in groups]
        return set(names).issubset(self.groups)

    def __str__(self):
        return self.name


class QRCodeMixin:
    """Mixin adding an id, a unique case-insensitive name and a QR code"""
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)

    def __init__(self, name=None):
        self.name = name

    def image(self):
        """Return a QRCode image to identify a record

        The QRCode includes:
            - CSE prefix
            - the table name
            - the name of the record
        """
        data = ':'.join(['CSE', self.__tablename__, self.name])
        return qrcode.make(data, version=1, box_size=5)

    def __str__(self):
        return self.name

    def to_dict(self, qrcode=False):
        """Return a dict with the record id and name

        :param bool qrcode: if True, also include the QR code image
            (converted by utils.image_to_base64) under the 'qrcode' key
        """
        d = {'id': self.id, 'name': self.name}
        if qrcode:
            d['qrcode'] = utils.image_to_base64(self.image())
        return d


class Action(QRCodeMixin, db.Model):
    pass


class Manufacturer(QRCodeMixin, db.Model):
    items = db.relationship('Item', back_populates='manufacturer')


class Model(QRCodeMixin, db.Model):
    description = db.Column(db.Text)
    items = db.relationship('Item', back_populates='model')

    def to_dict(self, qrcode=False):
        """Return QRCodeMixin.to_dict with the description added"""
        d = super().to_dict(qrcode)
        d['description'] = self.description
        return d


class Location(QRCodeMixin, db.Model):
    items = db.relationship('Item', back_populates='location')
    networks = db.relationship('Network', backref='location')


class Status(QRCodeMixin, db.Model):
    items = db.relationship('Item', back_populates='status')


class Item(db.Model):
    """Inventory item, versioned by sqlalchemy_continuum"""
    # Columns listed here are not tracked in the version history
    __versioned__ = {
        'exclude': ['_created', 'ics_id', 'serial_number',
                    'manufacturer_id', 'model_id']
    }

    id = db.Column(db.Integer, primary_key=True)
    _created = db.Column(db.DateTime, default=db.func.now())
    _updated = db.Column(db.DateTime, default=db.func.now(),
                         onupdate=db.func.now())
    ics_id = db.Column(db.Text, unique=True, index=True)
    serial_number = db.Column(db.Text, nullable=False)
    manufacturer_id = db.Column(db.Integer, db.ForeignKey('manufacturer.id'))
    model_id = db.Column(db.Integer, db.ForeignKey('model.id'))
    location_id = db.Column(db.Integer, db.ForeignKey('location.id'))
    status_id = db.Column(db.Integer, db.ForeignKey('status.id'))
    parent_id = db.Column(db.Integer, db.ForeignKey('item.id'))

    manufacturer = db.relationship('Manufacturer', back_populates='items')
    model = db.relationship('Model', back_populates='items')
    location = db.relationship('Location', back_populates='items')
    status = db.relationship('Status', back_populates='items')
    # Self-referential relationship: remote_side=[id] makes 'parent' the
    # many-to-one side
    children = db.relationship('Item', backref=db.backref('parent', remote_side=[id]))
    macs = db.relationship('Mac', backref='item')

    def __init__(self, ics_id=None, serial_number=None, manufacturer=None,
                 model=None, location=None, status=None):
        # All arguments must be optional for this class to work with flask-admin!
        self.ics_id = ics_id
        self.serial_number = serial_number
        self.manufacturer = utils.convert_to_model(manufacturer, Manufacturer)
        self.model = utils.convert_to_model(model, Model)
        self.location = utils.convert_to_model(location, Location)
        self.status = utils.convert_to_model(status, Status)

    def __str__(self):
        return str(self.ics_id)

    @validates('ics_id')
    def validate_ics_id(self, key, string):
        """Ensure the ICS id field matches the required format

        None is accepted.

        :raises CSEntryError: (status code 422) if the whole string does not
            match ICS_ID_RE
        """
        if string is not None:
            if ICS_ID_RE.fullmatch(string) is None:
                raise utils.CSEntryError('ICS id shall match [A-Z]{3}[0-9]{3}',
                                         status_code=422)
        return string

    def to_dict(self, long=False):
        """Return the item as a dict

        :param bool long: if True, also include the children, the MAC
            addresses (long form) and the version history
        """
        d = {
            'id': self.id,
            'ics_id': self.ics_id,
            'serial_number': self.serial_number,
            'manufacturer': utils.format_field(self.manufacturer),
            'model': utils.format_field(self.model),
            'location': utils.format_field(self.location),
            'status': utils.format_field(self.status),
            'updated': utils.format_field(self._updated),
            'created': utils.format_field(self._created),
            'parent': utils.format_field(self.parent),
        }
        if long:
            d['children'] = [utils.format_field(child) for child in self.children]
            d['macs'] = [mac.to_dict(long=True) for mac in self.macs]
            d['history'] = self.history()
        return d

    def history(self):
        """Return the item versions as a list of dicts

        Each entry holds the update date, location, status and parent
        recorded in that version.
        """
        versions = []
        for version in self.versions:
            # parent is an attribute used by SQLAlchemy-Continuum
            # version.parent refers to an ItemVersion instance (and has no link with
            # the item parent_id)
            # We need to retrieve the parent "manually"
            if version.parent_id is None:
                parent = None
            else:
                parent = Item.query.get(version.parent_id)
            versions.append({
                'updated': utils.format_field(version._updated),
                'location': utils.format_field(version.location),
                'status': utils.format_field(version.status),
                'parent': utils.format_field(parent),
            })
        return versions


class Network(db.Model):
    """IP network (prefix) with its usable range [first, last]"""
    id = db.Column(db.Integer, primary_key=True)
    label = db.Column(db.Text)
    prefix = db.Column(postgresql.CIDR, nullable=False, unique=True)
    first = db.Column(postgresql.INET, nullable=False, unique=True)
    last = db.Column(postgresql.INET, nullable=False, unique=True)
    gateway = db.Column(postgresql.INET)
    vlanid = db.Column(db.Integer, unique=True)
    location_id = db.Column(db.Integer, db.ForeignKey('location.id'))

    hosts = db.relationship('Host', backref='network')

    __table_args__ = (
        sa.CheckConstraint('first < last', name='first_less_than_last'),
        # "<<" is the PostgreSQL "is contained within" network operator
        sa.CheckConstraint('first << prefix', name='first_in_prefix'),
        sa.CheckConstraint('last << prefix', name='last_in_prefix'),
    )

    def __str__(self):
        return str(self.prefix)

    @staticmethod
    def ip_in_network(ip, prefix):
        """Ensure the IP is in the network

        :param str ip: IP address to check
        :param str prefix: network prefix (e.g. '192.168.1.0/24')
        :returns: a tuple with the IP and network as (IPv4Address, IPv4Network)
            (or the IPv6 equivalents)
        :raises: CSEntryError if the IP is not in the network
        """
        # NOTE(review): ipaddress raises ValueError (not CSEntryError) for a
        # malformed ip/prefix or a None value
        addr = ipaddress.ip_address(ip)
        net = ipaddress.ip_network(prefix)
        if addr not in net:
            raise utils.CSEntryError(f'IP address {ip} is not in network {prefix}',
                                     status_code=422)
        return (addr, net)

    @validates('first')
    def validate_first(self, key, ip):
        """Ensure the first IP is in the network

        Reads self.prefix, so the prefix must already be set on the instance.
        """
        self.ip_in_network(ip, self.prefix)
        return ip

    @validates('last')
    def validate_last(self, key, ip):
        """Ensure the last IP is in the network

        Reads self.prefix and self.first, so both must already be set.
        """
        addr, net = self.ip_in_network(ip, self.prefix)
        # NOTE(review): only last < first is rejected here, while the
        # 'first_less_than_last' check constraint also rejects last == first
        if addr < ipaddress.ip_address(self.first):
            raise utils.CSEntryError(f'Last IP address {ip} is less than the first address {self.first}',
                                     status_code=422)
        return ip

    @validates('hosts')
    def validate_hosts(self, key, host):
        """Ensure the host IP is in the network range [first, last]"""
        addr, net = self.ip_in_network(host.ip, self.prefix)
        if addr < ipaddress.ip_address(self.first) or addr > ipaddress.ip_address(self.last):
            raise utils.CSEntryError(f'IP address {host.ip} is not in range {self.first} - {self.last}',
                                     status_code=422)
        return host

    def to_dict(self):
        """Return the network as a dict"""
        return {
            'id': self.id,
            'label': self.label,
            'prefix': self.prefix,
            'first': self.first,
            'last': self.last,
            'gateway': self.gateway,
            'vlanid': self.vlanid,
            'location': utils.format_field(self.location),
        }


class Host(db.Model):
    """Host with an IP address in a network"""
    id = db.Column(db.Integer, primary_key=True)
    network_id = db.Column(db.Integer, db.ForeignKey('network.id'), nullable=False)
    ip = db.Column(postgresql.INET, nullable=False, unique=True)
    name = db.Column(db.Text, unique=True)

    # The foreign key lives on Mac, so this relationship is one-to-many and
    # 'mac' is a list (at most one element since Mac.host_id is unique)
    mac = db.relationship('Mac', backref='host')

    def __str__(self):
        return str(self.ip)

    def to_dict(self, long=False):
        """Return the host as a dict

        :param bool long: if True, also include the MAC address of the host
            (None if it has none)
        """
        d = {
            'id': self.id,
            'ip': self.ip,
            'name': self.name,
            'network_id': self.network_id,
        }
        if long:
            # self.mac is a list: getattr(self.mac, 'address', None) would
            # always give None
            d['mac'] = self.mac[0].address if self.mac else None
        return d


class Mac(db.Model):
    """MAC address belonging to an item, optionally assigned to a host"""
    id = db.Column(db.Integer, primary_key=True)
    address = db.Column(postgresql.MACADDR, nullable=False, unique=True)
    # unique: a host has at most one MAC address
    host_id = db.Column(db.Integer, db.ForeignKey('host.id'), unique=True)
    item_id = db.Column(db.Integer, db.ForeignKey('item.id'), nullable=False)

    def __str__(self):
        return str(self.address)

    def to_dict(self, long=False):
        """Return the MAC address as a dict

        :param bool long: if True, also include the item ICS id and the host
            (as a dict, or None if the MAC is not assigned to a host)
        """
        d = {
            'id': self.id,
            'address': self.address,
            'host_id': self.host_id,
            'item_id': self.item_id,
        }
        if long:
            d['item_ics_id'] = self.item.ics_id
            # Explicit check instead of catching AttributeError, which would
            # also hide errors raised inside Host.to_dict
            d['host'] = None if self.host is None else self.host.to_dict()
        return d


# call configure_mappers after defining all the models
# required by sqlalchemy_continuum
sa.orm.configure_mappers()
ItemVersion = version_class(Item)