Newer
Older
# -*- coding: utf-8 -*-
"""
app.models
~~~~~~~~~~
This module implements the models used in the app.
:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.
"""
import ipaddress
import re

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import validates
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy_continuum import make_versioned, version_class
from citext import CIText

from .extensions import db, login_manager, ldap_manager, jwt
from .plugins import FlaskUserPlugin
# ICS ids are three upper-case letters followed by three digits (e.g. "ABC123").
# The pattern carries no anchors; callers use fullmatch() to match the whole id.
ICS_ID_RE = re.compile(r'[A-Z]{3}[0-9]{3}')

# Turn on SQLAlchemy-Continuum versioning with the project's FlaskUserPlugin
# (NOTE(review): presumably records which user made each change; see .plugins).
make_versioned(plugins=[FlaskUserPlugin()])
@login_manager.user_loader
def load_user(user_id):
    """User loader callback for flask-login

    :param str user_id: unicode ID of a user
    :returns: corresponding user object or None
    """
    # flask-login hands back the string produced by User.get_id(); the primary
    # key is an integer, so convert before the lookup.
    return User.query.get(int(user_id))
@jwt.user_loader_callback_loader
def user_loader_callback(identity):
    """User loader callback for flask-jwt-extended

    :param str identity: identity from the token (user_id)
    :returns: corresponding user object or None
    """
    # The token identity is the user's primary key, possibly serialized as text.
    user_id = int(identity)
    return User.query.get(user_id)
@ldap_manager.save_user
def save_user(dn, username, data, memberships):
    """User saver for flask-ldap3-login

    This method is called whenever a LDAPLoginForm()
    successfully validates. The local User is created on first login and its
    group memberships are refreshed on every login.

    :param dn: distinguished name of the LDAP entry (not used here)
    :param username: login name, used to find the local User
    :param data: LDAP attributes of the user ('cn' and 'mail' are read)
    :param memberships: LDAP group entries ('cn' is read from each one)
    :returns: the saved User instance
    """
    user = User.query.filter_by(username=username).first()
    if user is None:
        user = User(username,
                    name=utils.attribute_to_string(data['cn']),
                    email=utils.attribute_to_string(data['mail']))
    # Always update the user groups to keep them up-to-date
    user.groups = [utils.attribute_to_string(group['cn']) for group in memberships]
    db.session.add(user)
    db.session.commit()
    # flask-ldap3-login keeps the callback's return value as the form's user
    # (LDAPLoginForm.user), so the saved user must be returned.
    return user
# Association table behind the many-to-many link between users and groups;
# User.grp uses it as its ``secondary`` table.
usergroups_table = db.Table(
    'usergroups',
    db.Column(
        'user_id',
        db.Integer,
        db.ForeignKey('user_account.id'),
    ),
    db.Column(
        'group_id',
        db.Integer,
        db.ForeignKey('group.id'),
    ),
)
class Group(db.Model):
    """Named user group; users are linked through the ``usergroups`` table."""

    # The usergroups table has a foreign key to group.id, and a declarative
    # model cannot be mapped without a primary key.
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.Text, nullable=False, unique=True)

    def __init__(self, name):
        self.name = name
def find_or_create_group(name):
    """Return the Group called *name*, or a new Group if none exists yet.

    The new Group is only constructed here, not added to the session.
    Used as the association-proxy creator for ``User.groups``.
    """
    existing = Group.query.filter_by(name=name).first()
    if existing is not None:
        return existing
    return Group(name=name)
class User(db.Model, UserMixin):
    """Application user account.

    Group memberships are stored through the ``usergroups`` association table
    and exposed as plain group names via ``groups``.
    """

    # "user" is a reserved word in postgresql
    # so let's use another name
    __tablename__ = 'user_account'
    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.Text, unique=True)
    name = db.Column(db.Text)
    email = db.Column(db.Text)
    grp = db.relationship('Group', secondary=usergroups_table,
                          backref=db.backref('members', lazy='dynamic'))
    # Proxy the 'name' attribute from the 'grp' relationship so that
    # ``user.groups`` reads and writes group names; names that are assigned
    # are turned into Group objects by find_or_create_group.
    # See http://docs.sqlalchemy.org/en/latest/orm/extensions/associationproxy.html
    groups = association_proxy('grp', 'name',
                               creator=find_or_create_group)

    def __init__(self, username, name, email):
        self.username = username
        self.name = name
        self.email = email

    def get_id(self):
        """Return the user id as unicode

        Required by flask-login
        """
        return str(self.id)

    @property
    def is_admin(self):
        """True if the user is in the group configured as CSENTRY_LDAP_GROUPS['admin']."""
        return current_app.config['CSENTRY_LDAP_GROUPS']['admin'] in self.groups

    def is_member_of_one_group(self, groups):
        """Return True if the user is at least member of one of the given groups

        :param groups: keys of the ``CSENTRY_LDAP_GROUPS`` config mapping;
            unknown keys map to None and therefore never match
        """
        names = [current_app.config['CSENTRY_LDAP_GROUPS'].get(group) for group in groups]
        return bool(set(self.groups) & set(names))

    def is_member_of_all_groups(self, groups):
        """Return True if the user is member of all the given groups

        :param groups: keys of the ``CSENTRY_LDAP_GROUPS`` config mapping;
            an unknown key maps to None, which makes the result False
        """
        names = [current_app.config['CSENTRY_LDAP_GROUPS'].get(group) for group in groups]
        return set(names).issubset(self.groups)

    def __str__(self):
        # ``name`` is a nullable column and __str__ must return a str
        # (returning None raises TypeError), so fall back to the username.
        if self.name is not None:
            return self.name
        return str(self.username)
class QRCodeMixin:
    """Mixin for the lookup models (Manufacturer, Location, Status, ...).

    Provides an integer primary key, a unique ``name`` stored as CIText
    (PostgreSQL citext, case-insensitive) and a dict representation that can
    optionally embed a QR code.
    """

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)

    def __init__(self, name=None):
        self.name = name
        # NOTE(review): everything from here to the end of this method looks like
        # the body of a missing ``def image(self):`` method. Its ``def`` line and
        # the ``return`` of the generated QR code image are absent. As written, the
        # string and ``data`` below are computed in __init__ and then discarded.
        # Meanwhile to_dict(qrcode=True) calls self.image(), which this class does
        # not define. Restore the method from version control.
        """Return a QRCode image to identify a record
The QRCode includes:
- the table name
- the name of the record
"""
        data = ':'.join(['CSE', self.__tablename__, self.name])

    def to_dict(self, qrcode=False):
        """Return {'id', 'name'}, plus a 'qrcode' entry built from self.image() when requested."""
        d = {'id': self.id, 'name': self.name}
        if qrcode:
            # utils.image_to_base64 presumably encodes the image for JSON output
            d['qrcode'] = utils.image_to_base64(self.image())
        return d
class Manufacturer(QRCodeMixin, db.Model):
    """Manufacturer lookup table referenced by items."""

    items = db.relationship('Item', back_populates='manufacturer')


class Model(QRCodeMixin, db.Model):
    """Model lookup table referenced by items, with a free-text description."""

    # NOTE(review): column type assumed to be Text; confirm against the
    # database schema / migrations.
    description = db.Column(db.Text)
    items = db.relationship('Item', back_populates='model')

    def to_dict(self, qrcode=False):
        """Extend the QRCodeMixin representation with the description."""
        d = super().to_dict(qrcode)
        d['description'] = self.description
        return d
class Location(QRCodeMixin, db.Model):
    """Location lookup table referenced by items and networks."""

    # One-to-many to Item; Item.location is the other side.
    items = db.relationship(
        'Item',
        back_populates='location',
    )
    # One-to-many to Network; the backref creates Network.location.
    networks = db.relationship(
        'Network',
        backref='location',
    )
class Status(QRCodeMixin, db.Model):
    """Status lookup table referenced by items."""

    # One-to-many to Item; Item.status is the other side.
    items = db.relationship(
        'Item',
        back_populates='status',
    )
class Item(db.Model):
    """Inventory item, versioned with SQLAlchemy-Continuum.

    An item can have a parent item and children, owns MAC addresses
    (``macs``), and refers to the Manufacturer, Model, Location and Status
    lookup tables.
    """

    # SQLAlchemy-Continuum options: these columns are excluded from versioning.
    __versioned__ = {
        'exclude': ['_created', 'ics_id', 'serial_number',
                    'manufacturer_id', 'model_id']
    }

    id = db.Column(db.Integer, primary_key=True)
    _created = db.Column(db.DateTime, default=db.func.now())
    _updated = db.Column(db.DateTime, default=db.func.now(), onupdate=db.func.now())
    ics_id = db.Column(db.Text, unique=True, index=True)
    serial_number = db.Column(db.Text, nullable=False)
    manufacturer_id = db.Column(db.Integer, db.ForeignKey('manufacturer.id'))
    model_id = db.Column(db.Integer, db.ForeignKey('model.id'))
    location_id = db.Column(db.Integer, db.ForeignKey('location.id'))
    status_id = db.Column(db.Integer, db.ForeignKey('status.id'))
    parent_id = db.Column(db.Integer, db.ForeignKey('item.id'))

    manufacturer = db.relationship('Manufacturer', back_populates='items')
    model = db.relationship('Model', back_populates='items')
    location = db.relationship('Location', back_populates='items')
    status = db.relationship('Status', back_populates='items')
    # Self-referential tree: remote_side=[id] makes the ``parent`` backref
    # the many-to-one side of the relationship.
    children = db.relationship('Item', backref=db.backref('parent', remote_side=[id]))
    macs = db.relationship('Mac', backref='item')

    def __init__(self, ics_id=None, serial_number=None, manufacturer=None,
                 model=None, location=None, status=None):
        # All arguments must be optional for this class to work with flask-admin!
        self.ics_id = ics_id
        # serial_number is a NOT NULL column, so it must be stored as well
        self.serial_number = serial_number
        # NOTE(review): convert_to_model presumably accepts an instance or a
        # name and returns the matching lookup row; see utils.convert_to_model.
        self.manufacturer = utils.convert_to_model(manufacturer, Manufacturer)
        self.model = utils.convert_to_model(model, Model)
        self.location = utils.convert_to_model(location, Location)
        self.status = utils.convert_to_model(status, Status)

    def __str__(self):
        return str(self.ics_id)

    @validates('ics_id')
    def validate_ics_id(self, key, string):
        """Ensure the ICS id field matches the required format

        None is accepted (the column is nullable).

        :raises: utils.CSEntryError (status 422) if the id is malformed
        """
        if string is not None:
            if ICS_ID_RE.fullmatch(string) is None:
                raise utils.CSEntryError('ICS id shall match [A-Z]{3}[0-9]{3}', status_code=422)
        return string
# NOTE(review): the opening of this method is missing from this view: the
# ``def to_dict(...):`` line (its signature cannot be recovered from here),
# the ``d = {`` that the keys below belong to, and the closing ``}`` after the
# 'parent' entry. Restore them from version control.
'ics_id': self.ics_id,
'serial_number': self.serial_number,
'manufacturer': utils.format_field(self.manufacturer),
'model': utils.format_field(self.model),
'location': utils.format_field(self.location),
'status': utils.format_field(self.status),
'updated': utils.format_field(self._updated),
'created': utils.format_field(self._created),
'parent': utils.format_field(self.parent),
# Related data: children, MAC addresses (Mac.to_dict(long=True) also embeds
# the host) and the version history built by history().
d['children'] = [utils.format_field(child) for child in self.children]
d['macs'] = [mac.to_dict(long=True) for mac in self.macs]
d['history'] = self.history()
return d
def history(self):
    """Return one summary dict per SQLAlchemy-Continuum version of this item.

    Each entry holds 'updated', 'location', 'status' and 'parent', all
    formatted with utils.format_field.
    """
    entries = []
    for snapshot in self.versions:
        # On a Continuum version object ``parent`` is used by the library itself
        # (it is an ItemVersion, unrelated to the item's parent_id), so the
        # parent item is looked up from parent_id explicitly.
        parent_item = (
            Item.query.get(snapshot.parent_id)
            if snapshot.parent_id is not None
            else None
        )
        entries.append({
            'updated': utils.format_field(snapshot._updated),
            'location': utils.format_field(snapshot.location),
            'status': utils.format_field(snapshot.status),
            'parent': utils.format_field(parent_item),
        })
    return entries
class Network(db.Model):
    """IP network: a CIDR prefix, an allocatable first..last range and its hosts."""

    id = db.Column(db.Integer, primary_key=True)
    label = db.Column(db.Text)
    prefix = db.Column(postgresql.CIDR, nullable=False, unique=True)
    first = db.Column(postgresql.INET, nullable=False, unique=True)
    last = db.Column(postgresql.INET, nullable=False, unique=True)
    gateway = db.Column(postgresql.INET)
    vlanid = db.Column(db.Integer, unique=True)
    location_id = db.Column(db.Integer, db.ForeignKey('location.id'))
    hosts = db.relationship('Host', backref='network')

    # Database-side counterparts of the validators below; ``<<`` is the
    # PostgreSQL "is contained within" operator for inet/cidr values.
    __table_args__ = (
        sa.CheckConstraint('first < last', name='first_less_than_last'),
        sa.CheckConstraint('first << prefix', name='first_in_prefix'),
        sa.CheckConstraint('last << prefix', name='last_in_prefix'),
    )

    def __str__(self):
        return str(self.prefix)

    @staticmethod
    def ip_in_network(ip, prefix):
        """Ensure the IP is in the network

        :param ip: IP address to check (anything ipaddress.ip_address accepts)
        :param prefix: network prefix (anything ipaddress.ip_network accepts)
        :returns: a tuple with the IP and network as (IPv4Address, IPv4Network),
                  or their IPv6 counterparts
        :raises: CSEntryError if the IP is not in the network;
                 ValueError (from ipaddress) if ip or prefix cannot be parsed
        """
        addr = ipaddress.ip_address(ip)
        net = ipaddress.ip_network(prefix)
        if addr not in net:
            raise utils.CSEntryError(f'IP address {ip} is not in network {prefix}', status_code=422)
        return (addr, net)

    @validates('first')
    def validate_first(self, key, ip):
        """Ensure the first IP is in the network"""
        # Reads self.prefix, so prefix must already be set when first is assigned
        self.ip_in_network(ip, self.prefix)
        return ip

    @validates('last')
    def validate_last(self, key, ip):
        """Ensure the last IP is in the network and not below the first IP"""
        # Reads self.prefix and self.first, so both must already be set
        addr, _ = self.ip_in_network(ip, self.prefix)
        if addr < ipaddress.ip_address(self.first):
            raise utils.CSEntryError(f'Last IP address {ip} is less than the first address {self.first}', status_code=422)
        return ip

    @validates('hosts')
    def validate_hosts(self, key, host):
        """Ensure the host IP is in the network range (first..last, inclusive)"""
        addr, _ = self.ip_in_network(host.ip, self.prefix)
        if addr < ipaddress.ip_address(self.first) or addr > ipaddress.ip_address(self.last):
            raise utils.CSEntryError(f'IP address {host.ip} is not in range {self.first} - {self.last}', status_code=422)
        return host

    def to_dict(self):
        """Return the network's columns plus its formatted location."""
        return {
            'id': self.id,
            'label': self.label,
            'prefix': self.prefix,
            'first': self.first,
            'last': self.last,
            'gateway': self.gateway,
            'vlanid': self.vlanid,
            'location': utils.format_field(self.location),
        }
class Host(db.Model):
    """Host with an IP address inside a Network and an optional MAC address."""

    id = db.Column(db.Integer, primary_key=True)
    network_id = db.Column(db.Integer, db.ForeignKey('network.id'), nullable=False)
    ip = db.Column(postgresql.INET, nullable=False, unique=True)
    name = db.Column(db.Text, unique=True)
    # Mac.host_id is unique, so a host has at most one Mac. Without
    # ``uselist=False`` this relationship is still list-valued, which is why
    # to_dict() unwraps it.
    mac = db.relationship('Mac', backref='host')

    def __str__(self):
        return str(self.ip)

    def to_dict(self, long=False):
        """Return the host's columns; with long=True also its MAC address (or None)."""
        d = {
            'id': self.id,
            'ip': self.ip,
            'name': self.name,
            'network_id': self.network_id,
        }
        if long:
            mac = self.mac
            # The relationship yields a list (InstrumentedList is a list
            # subclass). getattr() on the list itself always returned None,
            # so take its single element when present.
            if isinstance(mac, list):
                mac = mac[0] if mac else None
            d['mac'] = getattr(mac, 'address', None)
        return d
class Mac(db.Model):
    """MAC address belonging to an Item, optionally assigned to a Host."""

    id = db.Column(db.Integer, primary_key=True)
    address = db.Column(postgresql.MACADDR, nullable=False, unique=True)
    host_id = db.Column(db.Integer, db.ForeignKey('host.id'), unique=True)
    item_id = db.Column(db.Integer, db.ForeignKey('item.id'), nullable=False)

    def __str__(self):
        return str(self.address)

    def to_dict(self, long=False):
        """Return the MAC's columns; with long=True also the item's ICS id and the host."""
        d = {
            'id': self.id,
            'address': self.address,
            'host_id': self.host_id,
            'item_id': self.item_id,
        }
        if long:
            d['item_ics_id'] = self.item.ics_id
            # host_id is nullable, so a MAC may have no host. Test for None
            # explicitly instead of catching AttributeError, which would also
            # hide AttributeErrors raised inside Host.to_dict().
            host = self.host
            d['host'] = host.to_dict() if host is not None else None
        return d
# SQLAlchemy-Continuum generates its version classes when the mappers are
# configured, so configure them once, after every model above is declared.
sa.orm.configure_mappers()