Skip to content
Snippets Groups Projects
models.py 13.4 KiB
Newer Older
Benjamin Bertrand's avatar
Benjamin Bertrand committed
# -*- coding: utf-8 -*-
"""
app.models
~~~~~~~~~~

This module implements the models used in the app.

:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.

"""
import ipaddress
import re

import qrcode
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import validates
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy_continuum import make_versioned, version_class
from citext import CIText
from flask import current_app
from flask_login import UserMixin
from wtforms import ValidationError

from .extensions import db, login_manager, ldap_manager, jwt
from .plugins import FlaskUserPlugin
from . import utils


# ICS ids are three upper-case letters followed by three digits (e.g. ABC123)
ICS_ID_RE = re.compile(r'[A-Z]{3}[0-9]{3}')

# SQLAlchemy-Continuum versioning is set up here, before any model class
# below is declared; FlaskUserPlugin is the project's versioning plugin.
make_versioned(plugins=[FlaskUserPlugin()])




@login_manager.user_loader
def load_user(user_id):
    """User loader callback for flask-login

    :param str user_id: unicode ID of a user
    :returns: corresponding user object, or None if the id is unknown
              or is not a valid integer
    """
    try:
        user_pk = int(user_id)
    except (TypeError, ValueError):
        # flask-login expects None (not an exception) for an invalid id,
        # e.g. a malformed value coming back from the session
        return None
    return User.query.get(user_pk)


@jwt.user_loader_callback_loader
def user_loader_callback(identity):
    """User loader callback for flask-jwt-extended

    :param str identity: identity from the token (user_id)
    :returns: corresponding user object, or None if the identity is unknown
              or is not a valid integer
    """
    try:
        user_pk = int(identity)
    except (TypeError, ValueError):
        # Treat a non-integer identity like an unknown user instead of
        # letting the conversion error escape from the callback
        return None
    return User.query.get(user_pk)


@ldap_manager.save_user
def save_user(dn, username, data, memberships):
    """Persist the user who just logged in through flask-ldap3-login

    Invoked each time a LDAPLoginForm() validates successfully. When the
    username is not in the database yet, a User is built from the LDAP
    'cn' and 'mail' attributes. The group list is rewritten on every call.

    :returns: the saved User instance
    """
    user = User.query.filter_by(username=username).first()
    if user is None:
        full_name = utils.attribute_to_string(data['cn'])
        mail = utils.attribute_to_string(data['mail'])
        user = User(username, name=full_name, email=mail)
    # Rewrite the groups on every login so they stay in sync with LDAP
    group_names = [utils.attribute_to_string(entry['cn']) for entry in memberships]
    user.groups = group_names
    db.session.add(user)
    db.session.commit()
    return user


# Association table for the many-to-many link between users and groups
# (used as the `secondary` of User.grp)
usergroups_table = db.Table(
    'usergroups',
    db.Column('user_id', db.Integer, db.ForeignKey('user_account.id')),
    db.Column('group_id', db.Integer, db.ForeignKey('group.id')),
)


class Group(db.Model):
    """A named group of users; linked to users through usergroups_table"""

    id = db.Column(db.Integer, primary_key=True)
    # Group names are mandatory and unique
    name = db.Column(db.Text, nullable=False, unique=True)

    def __init__(self, name):
        self.name = name

    def __str__(self):
        return self.name


def find_or_create_group(name):
    """Return the Group called *name*, or a new Group instance if none exists

    The new instance is not added to the session here. This function is the
    association-proxy creator for User.groups.
    """
    existing = Group.query.filter_by(name=name).first()
    if existing is not None:
        return existing
    return Group(name=name)


class User(db.Model, UserMixin):
    """Application user account (created/refreshed from LDAP by save_user)"""

    # "user" is a reserved word in postgresql
    # so let's use another name
    __tablename__ = 'user_account'

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.Text, unique=True)
    name = db.Column(db.Text)
    email = db.Column(db.Text)
    grp = db.relationship('Group', secondary=usergroups_table,
                          backref=db.backref('members', lazy='dynamic'))
    # Proxy the 'name' attribute from the 'grp' relationship
    # See http://docs.sqlalchemy.org/en/latest/orm/extensions/associationproxy.html
    groups = association_proxy('grp', 'name',
                               creator=find_or_create_group)

    def __init__(self, username, name, email):
        self.username = username
        self.name = name
        self.email = email

    def get_id(self):
        """Return the user id as unicode

        Required by flask-login
        """
        return str(self.id)

    @property
    def is_admin(self):
        """True if the user belongs to the group configured under the 'admin' key

        Reads current_app.config['CSENTRY_LDAP_GROUPS']['admin'].
        """
        return current_app.config['CSENTRY_LDAP_GROUPS']['admin'] in self.groups

    @staticmethod
    def _configured_group_names(groups):
        """Map group keys to the names in CSENTRY_LDAP_GROUPS (None for unknown keys)"""
        mapping = current_app.config['CSENTRY_LDAP_GROUPS']
        return {mapping.get(group) for group in groups}

    def is_member_of_one_group(self, groups):
        """Return True if the user is at least member of one of the given groups"""
        return bool(set(self.groups) & self._configured_group_names(groups))

    def is_member_of_all_groups(self, groups):
        """Return True if the user is member of all the given groups

        An unknown group key maps to None, which is never a member name,
        so it makes the result False.
        """
        return self._configured_group_names(groups).issubset(self.groups)

    def __str__(self):
        # name is a nullable column but __str__ must return a str:
        # fall back to the username when no name is set
        if self.name is not None:
            return self.name
        return str(self.username)


class QRCodeMixin:
    """Mixin adding an id, a unique case-insensitive name and QR-code helpers"""

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)

    def __init__(self, name=None):
        self.name = name

    def image(self):
        """Build a QR code image identifying this record

        The encoded text is 'CSE:<table name>:<record name>'.
        """
        parts = ('CSE', self.__tablename__, self.name)
        return qrcode.make(':'.join(parts), version=1, box_size=5)

    def __str__(self):
        return self.name

    def to_dict(self, qrcode=False):
        """Serialise id and name; when qrcode is true, add the QR image
        converted with utils.image_to_base64 under 'qrcode'."""
        result = dict(id=self.id, name=self.name)
        if qrcode:
            result['qrcode'] = utils.image_to_base64(self.image())
        return result



class Action(QRCodeMixin, db.Model):
    # No extra columns: id, name and the QR-code helpers come from QRCodeMixin
    pass


class Manufacturer(QRCodeMixin, db.Model):
    # Items whose manufacturer is this record (paired with Item.manufacturer)
    items = db.relationship('Item', back_populates='manufacturer')




class Model(QRCodeMixin, db.Model):
    """QR-coded model record with an optional free-text description"""

    description = db.Column(db.Text)
    # Items of this model (paired with Item.model)
    items = db.relationship('Item', back_populates='model')

    def to_dict(self, qrcode=False):
        """Extend QRCodeMixin.to_dict with the description"""
        result = super().to_dict(qrcode)
        result.update(description=self.description)
        return result


class Location(QRCodeMixin, db.Model):
    # Items whose location is this record (paired with Item.location)
    items = db.relationship('Item', back_populates='location')
    # Networks at this location; also defines the Network.location backref
    networks = db.relationship('Network', backref='location')


class Status(QRCodeMixin, db.Model):
    # Items currently in this status (paired with Item.status)
    items = db.relationship('Item', back_populates='status')


class Item(db.Model):
    """Inventory item, versioned with SQLAlchemy-Continuum"""

    # SQLAlchemy-Continuum options: the listed columns are excluded from versioning
    __versioned__ = {
        'exclude': ['_created', 'ics_id', 'serial_number',
                    'manufacturer_id', 'model_id']
    }

    id = db.Column(db.Integer, primary_key=True)
    _created = db.Column(db.DateTime, default=db.func.now())
    _updated = db.Column(db.DateTime, default=db.func.now(), onupdate=db.func.now())
    ics_id = db.Column(db.Text, unique=True, index=True)
    serial_number = db.Column(db.Text, nullable=False)
    manufacturer_id = db.Column(db.Integer, db.ForeignKey('manufacturer.id'))
    model_id = db.Column(db.Integer, db.ForeignKey('model.id'))
    location_id = db.Column(db.Integer, db.ForeignKey('location.id'))
    status_id = db.Column(db.Integer, db.ForeignKey('status.id'))
    parent_id = db.Column(db.Integer, db.ForeignKey('item.id'))

    manufacturer = db.relationship('Manufacturer', back_populates='items')
    model = db.relationship('Model', back_populates='items')
    location = db.relationship('Location', back_populates='items')
    status = db.relationship('Status', back_populates='items')
    # Self-referential link: remote_side=[id] makes 'parent' the many-to-one side
    children = db.relationship('Item', backref=db.backref('parent', remote_side=[id]))
    macs = db.relationship('Mac', backref='item')

    def __init__(self, ics_id=None, serial_number=None, manufacturer=None, model=None, location=None, status=None):
        # All arguments must be optional for this class to work with flask-admin!
        # ics_id goes through validate_ics_id on assignment
        self.ics_id = ics_id
        self.serial_number = serial_number
        self.manufacturer = utils.convert_to_model(manufacturer, Manufacturer)
        self.model = utils.convert_to_model(model, Model)
        self.location = utils.convert_to_model(location, Location)
        self.status = utils.convert_to_model(status, Status)

    def __str__(self):
        return str(self.ics_id)

    @validates('ics_id')
    def validate_ics_id(self, key, string):
        """Ensure the ICS id field matches the required format

        :returns: the value unchanged (SQLAlchemy stores what a validator returns)
        :raises ValidationError: if a non-None id does not match [A-Z]{3}[0-9]{3}
        """
        if string is not None and ICS_ID_RE.fullmatch(string) is None:
            raise ValidationError('ICS id shall match [A-Z]{3}[0-9]{3}')
        return string

    def to_dict(self, long=False):
        """Serialise the item; with long=True also add children, MACs and history"""
        d = {
            # NOTE(review): confirm 'id'/'ics_id' are the intended leading keys
            'id': self.id,
            'ics_id': self.ics_id,
            'serial_number': self.serial_number,
            'manufacturer': utils.format_field(self.manufacturer),
            'model': utils.format_field(self.model),
            'location': utils.format_field(self.location),
            'status': utils.format_field(self.status),
            'updated': utils.format_field(self._updated),
            'created': utils.format_field(self._created),
            'parent': utils.format_field(self.parent),
        }
        if long:
            d['children'] = [utils.format_field(child) for child in self.children]
            d['macs'] = [mac.to_dict(long=True) for mac in self.macs]
            d['history'] = self.history()
        return d

    def history(self):
        """Return one dict (updated, location, status, parent) per recorded version"""
        versions = []
        for version in self.versions:
            # parent is an attribute used by SQLAlchemy-Continuum
            # version.parent refers to an ItemVersion instance (and has no link with
            # the item parent_id)
            # We need to retrieve the parent "manually"
            if version.parent_id is None:
                parent = None
            else:
                parent = Item.query.get(version.parent_id)
            versions.append({
                'updated': utils.format_field(version._updated),
                'location': utils.format_field(version.location),
                'status': utils.format_field(version.status),
                'parent': utils.format_field(parent),
            })
        return versions


class Network(db.Model):
    """IP network (CIDR prefix) with an address range from first to last"""

    id = db.Column(db.Integer, primary_key=True)
    label = db.Column(db.Text)
    prefix = db.Column(postgresql.CIDR, nullable=False, unique=True)
    first = db.Column(postgresql.INET, nullable=False, unique=True)
    last = db.Column(postgresql.INET, nullable=False, unique=True)
    gateway = db.Column(postgresql.INET)
    vlanid = db.Column(db.Integer, unique=True)
    location_id = db.Column(db.Integer, db.ForeignKey('location.id'))

    hosts = db.relationship('Host', backref='network')

    # Database-side guarantees: first < last, and both lie inside prefix
    __table_args__ = (
        sa.CheckConstraint('first < last', name='first_less_than_last'),
        sa.CheckConstraint('first << prefix', name='first_in_prefix'),
        sa.CheckConstraint('last << prefix', name='last_in_prefix'),
    )

    def __str__(self):
        return str(self.prefix)

    @staticmethod
    def ip_in_network(ip, prefix):
        """Ensure the IP is in the network

        :param ip: IP address to check (any value ipaddress.ip_address accepts)
        :param prefix: network prefix (any value ipaddress.ip_network accepts)
        :returns: a tuple (address, network) of ipaddress objects (IPv4 or IPv6)
        :raises: ValidationError if the IP is not in the network
        """
        addr = ipaddress.ip_address(ip)
        net = ipaddress.ip_network(prefix)
        if addr not in net:
            raise ValidationError(f'IP address {ip} is not in network {prefix}')
        return (addr, net)

    @validates('first')
    def validate_first(self, key, ip):
        """Ensure the first IP is in the network

        Relies on prefix being assigned before first.
        """
        self.ip_in_network(ip, self.prefix)
        return ip

    @validates('last')
    def validate_last(self, key, ip):
        """Ensure the last IP is in the network and not below the first IP

        Relies on prefix and first being assigned before last.
        """
        addr, net = self.ip_in_network(ip, self.prefix)
        if addr < ipaddress.ip_address(self.first):
            raise ValidationError(f'Last IP address {ip} is less than the first address {self.first}')
        # A validator must return the value to store; without this, last became None
        return ip

    @validates('hosts')
    def validate_hosts(self, key, host):
        """Ensure the host IP is in the network range"""
        addr, net = self.ip_in_network(host.ip, self.prefix)
        if addr < ipaddress.ip_address(self.first) or addr > ipaddress.ip_address(self.last):
            raise ValidationError(f'IP address {host.ip} is not in range {self.first} - {self.last}')
        return host

    def to_dict(self):
        return {
            'id': self.id,
            'label': self.label,
            'prefix': self.prefix,
            'first': self.first,
            'last': self.last,
            'gateway': self.gateway,
            'vlanid': self.vlanid,
            'location': utils.format_field(self.location),
        }


class Host(db.Model):
    """Host with a unique IP inside a Network"""

    id = db.Column(db.Integer, primary_key=True)
    network_id = db.Column(db.Integer, db.ForeignKey('network.id'), nullable=False)
    ip = db.Column(postgresql.INET, nullable=False, unique=True)
    name = db.Column(db.Text, unique=True)

    # One-to-many by default (no uselist=False), so this is a list of Mac;
    # Mac.host_id is unique, so it holds at most one entry
    mac = db.relationship('Mac', backref='host')

    def __str__(self):
        return str(self.ip)

    def to_dict(self, long=False):
        """Serialise the host; with long=True add the address of its MAC (or None)"""
        d = {
            'id': self.id,
            'ip': self.ip,
            'name': self.name,
            'network_id': self.network_id,
        }
        if long:
            # self.mac is a list: getattr(list, 'address', None) was always None
            d['mac'] = self.mac[0].address if self.mac else None
        return d


class Mac(db.Model):
    """MAC address belonging to an Item and optionally bound to one Host"""

    id = db.Column(db.Integer, primary_key=True)
    address = db.Column(postgresql.MACADDR, nullable=False, unique=True)
    # unique: a host is bound to at most one MAC; nullable: the host is optional
    host_id = db.Column(db.Integer, db.ForeignKey('host.id'), unique=True)
    item_id = db.Column(db.Integer, db.ForeignKey('item.id'), nullable=False)

    def __str__(self):
        return str(self.address)

    def to_dict(self, long=False):
        """Serialise the MAC; with long=True add the item ICS id and the host dict"""
        d = {
            'id': self.id,
            'address': self.address,
            'host_id': self.host_id,
            'item_id': self.item_id,
        }
        if long:
            d['item_ics_id'] = self.item.ics_id
            # Check for a missing host explicitly: catching AttributeError would
            # also hide attribute errors raised inside Host.to_dict()
            d['host'] = self.host.to_dict() if self.host is not None else None
        return d


# Call configure_mappers after defining all the models;
# this is required by sqlalchemy_continuum
sa.orm.configure_mappers()
# Version class that SQLAlchemy-Continuum generated for Item
ItemVersion = version_class(Item)