Skip to content
Snippets Groups Projects
models.py 9.22 KiB
Newer Older
Benjamin Bertrand's avatar
Benjamin Bertrand committed
# -*- coding: utf-8 -*-
"""
app.models
~~~~~~~~~~

This module implements the models used in the app.

:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.

"""

import re

import qrcode
import sqlalchemy as sa
from sqlalchemy.orm import validates
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy_continuum import make_versioned, version_class
from citext import CIText
from flask import current_app
from flask_login import UserMixin

from .extensions import db, login_manager, ldap_manager, jwt
from .plugins import FlaskUserPlugin
from . import utils


# ICS ids are three uppercase letters followed by three digits (e.g. "AAA001").
# Used with ``fullmatch`` in Item.validate_ics_id, so no anchors are needed.
ICS_ID_RE = re.compile(r'[A-Z]{3}[0-9]{3}')
# SQLAlchemy-Continuum must be set up before the versioned models below are
# declared; the plugin is project-specific (see .plugins).
make_versioned(plugins=[FlaskUserPlugin()])


@login_manager.user_loader
def load_user(user_id):
    """User loader callback for flask-login

    Looks the user up by primary key in the ``user_account`` table.

    :param str user_id: unicode ID of a user (as produced by ``User.get_id``)
    :returns: corresponding user object or None
    :raises ValueError: if ``user_id`` is not an integer string
    """
    return User.query.get(int(user_id))


@jwt.user_loader_callback_loader
def user_loader_callback(identity):
    """Resolve the identity carried by a JWT into a User (flask-jwt-extended hook).

    :param str identity: identity stored in the token, i.e. the user id
    :returns: the matching user object, or None when no such user exists
    """
    user_id = int(identity)
    return User.query.get(user_id)


@ldap_manager.save_user
def save_user(dn, username, data, memberships):
    """User saver for flask-ldap3-login

    This method is called whenever a LDAPLoginForm()
    successfully validates. A local User is created on first login; on
    every login its group memberships are refreshed from LDAP and the user
    is committed.

    :param str dn: distinguished name of the user (not used here)
    :param str username: login name used to find/create the local user
    :param data: LDAP attributes of the user; ``cn`` and ``mail`` are read
    :param memberships: LDAP group entries; the ``cn`` of each is read
    :returns: the persisted User
    """
    user = User.query.filter_by(username=username).first()
    if user is None:
        user = User(username,
                    name=utils.attribute_to_string(data['cn']),
                    email=utils.attribute_to_string(data['mail']))
    # Always update the user groups to keep them up-to-date
    user.groups = [utils.attribute_to_string(group['cn']) for group in memberships]
    db.session.add(user)
    db.session.commit()
    return user


# Association table for the many-to-many link between users and groups:
# each row pairs one user_account.id with one group.id.
usergroups_table = db.Table(
    'usergroups',
    *[
        db.Column(column_name, db.Integer, db.ForeignKey(target))
        for column_name, target in (
            ('user_id', 'user_account.id'),
            ('group_id', 'group.id'),
        )
    ]
)


class Group(db.Model):
    """A named group of users.

    Groups are created from LDAP memberships when ``User.groups`` is
    assigned in ``save_user``. The ``members`` attribute (a dynamic query of
    users) is added by the backref declared on ``User.grp``.
    """

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(100), nullable=False, unique=True)

    def __init__(self, name):
        self.name = name

    def __str__(self):
        return self.name


def find_or_create_group(name):
    """Look up a Group by name, building a new one if none exists.

    Used as the ``creator`` of the ``User.groups`` association proxy. A newly
    built Group is not added to the session here.
    """
    existing = Group.query.filter_by(name=name).first()
    if existing:
        return existing
    return Group(name=name)


class User(db.Model, UserMixin):
    """Application user, created and kept in sync with LDAP by ``save_user``."""

    # "user" is a reserved word in postgresql
    # so let's use another name
    __tablename__ = 'user_account'

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(50), unique=True)
    name = db.Column(db.String(100))
    email = db.Column(db.String(100))
    grp = db.relationship('Group', secondary=usergroups_table,
                          backref=db.backref('members', lazy='dynamic'))
    # Proxy the 'name' attribute from the 'grp' relationship so that
    # ``user.groups`` reads and writes plain group names; assigning a name
    # goes through find_or_create_group.
    # See http://docs.sqlalchemy.org/en/latest/orm/extensions/associationproxy.html
    groups = association_proxy('grp', 'name',
                               creator=find_or_create_group)

    def __init__(self, username, name, email):
        self.username = username
        self.name = name
        self.email = email

    def get_id(self):
        """Return the user id as unicode

        Required by flask-login
        """
        return str(self.id)

    @property
    def is_admin(self):
        """True if the user belongs to the group configured under the
        ``admin`` key of ``CSENTRY_LDAP_GROUPS``."""
        return current_app.config['CSENTRY_LDAP_GROUPS']['admin'] in self.groups

    def is_member_of_one_group(self, groups):
        """Return True if the user is at least member of one of the given groups

        :param groups: keys of the ``CSENTRY_LDAP_GROUPS`` config mapping;
            keys missing from the config never match
        """
        names = [current_app.config['CSENTRY_LDAP_GROUPS'].get(group) for group in groups]
        return bool(set(self.groups) & set(names))

    def is_member_of_all_groups(self, groups):
        """Return True if the user is member of all the given groups

        :param groups: keys of the ``CSENTRY_LDAP_GROUPS`` config mapping;
            a key missing from the config makes the result False
        """
        names = [current_app.config['CSENTRY_LDAP_GROUPS'].get(group) for group in groups]
        return set(names).issubset(self.groups)

    def __str__(self):
        return self.name


class QRCodeMixin:
    """Mixin giving a model an ``id``, a unique case-insensitive ``name``
    (CIText) and helpers to render the record as a QR code."""

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)

    def __init__(self, name=None):
        self.name = name

    def image(self):
        """Return a QRCode image to identify a record

        The QRCode encodes ``ICS:<table name>:<record name>``, i.e.:
             - ICS prefix
             - the table name
             - the name of the record
        """
        data = ':'.join(['ICS', self.__tablename__, self.name])
        return qrcode.make(data, version=1, box_size=5)

    def __str__(self):
        return self.name

    def to_dict(self, qrcode=False):
        """Return the record as a dict with its ``id`` and ``name``.

        :param bool qrcode: also include the QR code image, converted by
            ``utils.image_to_base64``, under the ``qrcode`` key
        """
        # NOTE: the ``qrcode`` parameter shadows the ``qrcode`` module inside
        # this method only; the module itself is used in ``image()``.
        d = {'id': self.id, 'name': self.name}
        if qrcode:
            d['qrcode'] = utils.image_to_base64(self.image())
        return d

class Action(QRCodeMixin, db.Model):
    """Action record; only the ``id``/``name`` columns from QRCodeMixin."""


class Manufacturer(QRCodeMixin, db.Model):
    """Manufacturer referenced by ``Item.manufacturer``."""

    items = db.relationship('Item', back_populates='manufacturer')


class Model(QRCodeMixin, db.Model):
    """Model referenced by ``Item.model``, with an optional description."""

    description = db.Column(db.Text)
    items = db.relationship('Item', back_populates='model')

    def to_dict(self, qrcode=False):
        """Extend ``QRCodeMixin.to_dict`` with the ``description`` field."""
        d = super().to_dict(qrcode)
        d['description'] = self.description
        return d

class Location(QRCodeMixin, db.Model):
    """Location referenced by ``Item.location``."""

    # Reverse side of Item.location.
    items = db.relationship(
        'Item',
        back_populates='location',
    )


class Status(QRCodeMixin, db.Model):
    """Status referenced by ``Item.status``."""

    # Reverse side of Item.status.
    items = db.relationship(
        'Item',
        back_populates='status',
    )


class Item(db.Model):
    """Item identified by an ICS id and a serial number, linked to its
    manufacturer, model, location, status and (optionally) a parent item.

    Changes are tracked by SQLAlchemy-Continuum (see ``history``).
    """

    # SQLAlchemy-Continuum configuration: the listed columns are excluded from
    # versioning, so the history only covers the remaining fields
    # (_updated, location, status, parent).
    __versioned__ = {
        'exclude': ['_created', 'ics_id', 'serial_number',
                    'manufacturer_id', 'model_id']
    }

    id = db.Column(db.Integer, primary_key=True)
    _created = db.Column(db.DateTime, default=db.func.now())
    _updated = db.Column(db.DateTime, default=db.func.now(), onupdate=db.func.now())
    ics_id = db.Column(db.String(6), unique=True, index=True)
    serial_number = db.Column(db.String(100), nullable=False)
    manufacturer_id = db.Column(db.Integer, db.ForeignKey('manufacturer.id'))
    model_id = db.Column(db.Integer, db.ForeignKey('model.id'))
    location_id = db.Column(db.Integer, db.ForeignKey('location.id'))
    status_id = db.Column(db.Integer, db.ForeignKey('status.id'))
    parent_id = db.Column(db.Integer, db.ForeignKey('item.id'))

    manufacturer = db.relationship('Manufacturer', back_populates='items')
    model = db.relationship('Model', back_populates='items')
    location = db.relationship('Location', back_populates='items')
    status = db.relationship('Status', back_populates='items')
    # Self-referential parent/children tree: remote_side=[id] makes the
    # ``parent`` backref the many-to-one side.
    children = db.relationship('Item', backref=db.backref('parent', remote_side=[id]))

    def __init__(self, ics_id=None, serial_number=None, manufacturer=None, model=None, location=None, status=None):
        """Create an item.

        manufacturer/model/location/status are passed through
        ``utils.convert_to_model`` (presumably accepting either a name or an
        instance — confirm in utils).
        """
        # All arguments must be optional for this class to work with flask-admin!
        # Assigning ics_id runs validate_ics_id (None is accepted).
        self.ics_id = ics_id
        self.serial_number = serial_number
        self.manufacturer = utils.convert_to_model(manufacturer, Manufacturer)
        self.model = utils.convert_to_model(model, Model)
        self.location = utils.convert_to_model(location, Location)
        self.status = utils.convert_to_model(status, Status)

    def __str__(self):
        return str(self.ics_id)

    @validates('ics_id')
    def validate_ics_id(self, key, string):
        """Ensure the ICS id field matches the required format

        :returns: the value unchanged (SQLAlchemy stores what a validator returns)
        :raises utils.CSEntryError: (422) if the value does not match
        """
        if string is not None:
            if ICS_ID_RE.fullmatch(string) is None:
                raise utils.CSEntryError('ICS id shall match [A-Z]{3}[0-9]{3}', status_code=422)
        return string

    def to_dict(self, extra=False):
        """Return the item as a dict, fields formatted by ``utils.format_field``.

        :param bool extra: also include ``children`` and ``history``
        """
        d = {
            'id': self.id,
            'ics_id': self.ics_id,
            'serial_number': self.serial_number,
            'manufacturer': utils.format_field(self.manufacturer),
            'model': utils.format_field(self.model),
            'location': utils.format_field(self.location),
            'status': utils.format_field(self.status),
            'updated': utils.format_field(self._updated),
            'created': utils.format_field(self._created),
            'parent': utils.format_field(self.parent),
        }
        if extra:
            d['children'] = [utils.format_field(child) for child in self.children]
            d['history'] = self.history()
        return d

    def history(self):
        """Return one dict per recorded version of the item, with the
        ``updated``, ``location``, ``status`` and ``parent`` fields."""
        versions = []
        for version in self.versions:
            # parent is an attribute used by SQLAlchemy-Continuum
            # version.parent refers to an ItemVersion instance (and has no link with
            # the item parent_id)
            # We need to retrieve the parent "manually"
            if version.parent_id is None:
                parent = None
            else:
                parent = Item.query.get(version.parent_id)
            versions.append({
                'updated': utils.format_field(version._updated),
                'location': utils.format_field(version.location),
                'status': utils.format_field(version.status),
                'parent': utils.format_field(parent),
            })
        return versions


# call configure_mappers after defining all the models
# required by sqlalchemy_continuum: the version classes are only available
# once the mappers are configured, so this must run before version_class().
sa.orm.configure_mappers()
# Version class that SQLAlchemy-Continuum generated for Item (presumably the
# type of the objects iterated in Item.history() via self.versions).
ItemVersion = version_class(Item)