Skip to content
Snippets Groups Projects
models.py 5.42 KiB
Newer Older
Benjamin Bertrand's avatar
Benjamin Bertrand committed
# -*- coding: utf-8 -*-
"""
app.models
~~~~~~~~~~

This module implements the models used in the app.

:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.

"""
import uuid
import qrcode
from sqlalchemy.types import TypeDecorator, CHAR
from sqlalchemy.dialects.postgresql import UUID
from flask_login import UserMixin
from .extensions import db, login_manager, ldap_manager
from . import utils


class GUID(TypeDecorator):
    """Platform-independent GUID type.

    Uses Postgresql's UUID type, otherwise uses
    CHAR(32), storing as stringified hex values.
    Values read back from the database are returned as :class:`uuid.UUID`
    (or None).

    From http://docs.sqlalchemy.org/en/rel_0_9/core/custom_types.html?highlight=guid#backend-agnostic-guid-type
    """
    impl = CHAR
    # This type has no constructor arguments that change the emitted SQL, so
    # SQLAlchemy (1.4+) may safely cache statements that use it. Older
    # versions ignore the attribute; newer ones warn when it is missing.
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Use the native UUID type on PostgreSQL and CHAR(32) elsewhere."""
        if dialect.name == 'postgresql':
            return dialect.type_descriptor(UUID())
        return dialect.type_descriptor(CHAR(32))

    def process_bind_param(self, value, dialect):
        """Convert a Python value into its database representation.

        :param value: None, a :class:`uuid.UUID`, or anything
            :class:`uuid.UUID` accepts as its first argument (e.g. a string)
        :returns: None; ``str(value)`` on PostgreSQL; otherwise the
            32-character lowercase hex string without dashes
        """
        if value is None:
            return value
        if dialect.name == 'postgresql':
            return str(value)
        if not isinstance(value, uuid.UUID):
            value = uuid.UUID(value)
        # UUID.hex is the zero-padded 32-digit hex string, i.e. "%.32x" % value.int
        return value.hex

    def process_result_value(self, value, dialect):
        """Convert a database value back into a :class:`uuid.UUID` (or None)."""
        if value is None:
            return value
        # Some drivers/SQLAlchemy versions (e.g. PostgreSQL UUID with
        # as_uuid=True, the SQLAlchemy 2.0 default) already hand back
        # uuid.UUID objects, which uuid.UUID() itself does not accept.
        if isinstance(value, uuid.UUID):
            return value
        return uuid.UUID(value)


@login_manager.user_loader
def load_user(user_id):
    """User loader callback for flask-login

    :param str user_id: unicode ID of a user, as returned by ``User.get_id``
    :returns: corresponding user object, or None when *user_id* is not a
        valid integer ID or (per ``Query.get``) no such user exists
    """
    try:
        pk = int(user_id)
    except (TypeError, ValueError):
        # Flask-Login requires None (not an exception) for an invalid ID,
        # e.g. a tampered or stale session cookie.
        return None
    return User.query.get(pk)


@ldap_manager.save_user
def save_user(dn, username, data, memberships):
    """Return the local User matching an LDAP login.

    Registered with flask-ldap3-login and invoked whenever a
    LDAPLoginForm() successfully validates. When no User with *username*
    exists yet, one is created from the ``cn`` and ``mail`` entries of
    *data*, added to the session and committed. An existing User is
    returned unchanged.

    :param dn: distinguished name of the entry (not used here)
    :param str username: the login name
    :param data: mapping of LDAP attributes; must contain ``cn`` and ``mail``
    :param memberships: group memberships (not used here)
    :returns: the matching :class:`User`
    """
    existing = User.query.filter_by(username=username).first()
    if existing is not None:
        # TODO: update the user in the database?
        # probably not needed for the name and email fields
        # maybe when we add groups from LDAP
        return existing

    created = User(username, name=data['cn'], email=data['mail'])
    db.session.add(created)
    db.session.commit()
    return created


class Role(db.Model):
    """A named role; each User references at most one Role via role_id.

    The ``users`` relationship also installs a ``role`` backref on User.
    """
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50), unique=True)
    users = db.relationship('User', backref='role')

    def __str__(self):
        # ``name`` is nullable and __str__ must return a str, otherwise
        # str(role) raises TypeError.
        return self.name or ''


class User(db.Model, UserMixin):
    """An application user account, usable with flask-login.

    Accounts are created by ``save_user`` on first LDAP login.
    """
    # "user" is a reserved word in postgresql
    # so let's use another name
    __tablename__ = 'user_account'

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(50), unique=True)
    name = db.Column(db.String(100))
    email = db.Column(db.String(100))
    role_id = db.Column(db.Integer, db.ForeignKey('role.id'))

    def __init__(self, username, name, email, role='user'):
        """Create a user attached to the Role named *role*.

        If no Role with that name exists, ``self.role`` is set to None.
        """
        self.username = username
        self.role = Role.query.filter_by(name=role).first()
        self.name = name
        self.email = email

    def get_id(self):
        """Return the user id as unicode

        Required by flask-login
        """
        return str(self.id)

    @property
    def is_admin(self):
        """True if the user's role is named 'admin'.

        False when the user has no role at all (role_id is nullable, and
        __init__ stores None when the requested Role does not exist).
        """
        return self.role is not None and self.role.name == 'admin'

    def __str__(self):
        # ``name`` is nullable and __str__ must return a str; fall back to
        # the username (or '') instead of letting str(user) raise TypeError.
        if self.name is not None:
            return self.name
        return self.username or ''


class QRCodeMixin:
    """Shared ``id``/``name`` columns and QR-code rendering.

    Mixed into Action, Vendor, Model, Location and Status, each of which
    defines a class attribute ``code`` (a short prefix string) that is
    embedded in the QR code.
    """
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50), nullable=False, unique=True)

    def image(self):
        """Return the object built by ``qrcode.make`` for "<code>,<id>,<name>".

        Rendered with ``version=1`` and ``box_size=5``.
        """
        fields = (self.code, str(self.id), self.name)
        payload = ','.join(fields)
        return qrcode.make(payload, box_size=5, version=1)


class Action(QRCodeMixin, db.Model):
    """Named action (``id``/``name`` from QRCodeMixin); QR code prefix 'AC'."""
    code = 'AC'

class Vendor(QRCodeMixin, db.Model):
    """Vendor of Items; QR code prefix 'VD'."""
    code = 'VD'
    # Items whose vendor_id points here (inverse of Item.vendor)
    items = db.relationship('Item', back_populates='vendor')


class Model(QRCodeMixin, db.Model):
    """Model of an Item (not SQLAlchemy's db.Model); QR code prefix 'MO'."""
    code = 'MO'
    # Items whose model_id points here (inverse of Item.model)
    items = db.relationship('Item', back_populates='model')


class Location(QRCodeMixin, db.Model):
    """Location of Items; QR code prefix 'LO'."""
    code = 'LO'
    # Items whose location_id points here (inverse of Item.location)
    items = db.relationship('Item', back_populates='location')


class Status(QRCodeMixin, db.Model):
    """Status of Items; QR code prefix 'ST'."""
    code = 'ST'
    # Items whose status_id points here (inverse of Item.status)
    items = db.relationship('Item', back_populates='status')


class Item(db.Model):
    """An item with a required serial number and a unique GUID ``hash``.

    Optionally linked to a Vendor, Model, Location and Status, and to a
    parent Item (with the inverse ``children`` list).
    """
    id = db.Column(db.Integer, primary_key=True)
    # Set to SQL now() on INSERT
    _created = db.Column(db.DateTime, default=db.func.now())
    # Set to SQL now() on INSERT and refreshed on every UPDATE emitted by SQLAlchemy
    _updated = db.Column(db.DateTime, default=db.func.now(), onupdate=db.func.now())
    name = db.Column(db.String(100))
    serial_number = db.Column(db.String(100), nullable=False)
    # Unique GUID computed from serial_number in __init__ via utils.compute_hash
    hash = db.Column(GUID, unique=True)
    vendor_id = db.Column(db.Integer, db.ForeignKey('vendor.id'))
    model_id = db.Column(db.Integer, db.ForeignKey('model.id'))
    location_id = db.Column(db.Integer, db.ForeignKey('location.id'))
    status_id = db.Column(db.Integer, db.ForeignKey('status.id'))
    # Self-referential FK to the parent Item, if any
    parent_id = db.Column(db.Integer, db.ForeignKey('item.id'))

    vendor = db.relationship('Vendor', back_populates='items')
    model = db.relationship('Model', back_populates='items')
    location = db.relationship('Location', back_populates='items')
    status = db.relationship('Status', back_populates='items')
    # Adjacency list: ``children`` are the Items whose parent_id equals this
    # id; remote_side=[id] makes the ``parent`` backref the many-to-one side.
    children = db.relationship('Item', backref=db.backref('parent', remote_side=[id]))

    def __init__(self, name, serial_number, vendor, model, location, status):
        """Create an item and derive its ``hash`` from *serial_number*.

        :param name: display name (nullable column)
        :param serial_number: serial number (required column)
        :param vendor: related :class:`Vendor` (presumably may be None,
            since vendor_id is nullable — confirm with callers)
        :param model: related :class:`Model`
        :param location: related :class:`Location`
        :param status: related :class:`Status`
        """
        self.name = name
        self.serial_number = serial_number
        self.vendor = vendor
        self.model = model
        self.location = location
        self.status = status
        # NOTE(review): assumes utils.compute_hash returns something GUID can
        # bind (a uuid.UUID or a UUID string) — confirm in utils.
        self.hash = utils.compute_hash(serial_number)