Newer
Older
# -*- coding: utf-8 -*-
"""
app.models
~~~~~~~~~~
This module implements the models used in the app.
:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.
"""
import ipaddress
import re

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import validates
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.compiler import compiles
from sqlalchemy_continuum import make_versioned, version_class
from citext import CIText
from wtforms import ValidationError

from .extensions import db, login_manager, ldap_manager
from .plugins import FlaskUserPlugin
# NOTE(review): current_app (Flask), UserMixin (Flask-Login) and utils (local
# package module) are used below but are not imported in this view — confirm.
# ICS id format: three uppercase letters followed by three digits (e.g. ABC123).
ICS_ID_RE = re.compile(r'[A-Z]{3}[0-9]{3}')
# Host/interface name format: 2 to 20 lowercase letters, digits or hyphens.
# Raw string: '\-' is not a valid escape sequence in a normal str literal
# (DeprecationWarning, SyntaxWarning on Python 3.12+); the regex is unchanged.
HOST_NAME_RE = re.compile(r'^[a-z0-9\-]{2,20}$')
# Set up SQLAlchemy-Continuum versioning. This must run before the versioned
# models are defined (e.g. ``Item``, whose ``history()`` reads ``self.versions``).
# FlaskUserPlugin comes from .plugins — presumably it records the Flask user
# responsible for each change; confirm there.
make_versioned(plugins=[FlaskUserPlugin()])
# Server-side "current UTC timestamp" SQL expression.
# See http://docs.sqlalchemy.org/en/latest/core/compiler.html#utc-timestamp-function
class utcnow(sa.sql.expression.FunctionElement):
    """SQL expression evaluating to the current timestamp in UTC.

    Used as a column ``default``/``onupdate`` so the value is computed by the
    database when the row is written.
    """
    type = sa.types.DateTime()
    # The construct holds no extra state, so it can reuse the parent class'
    # SQL compilation cache key (SQLAlchemy 1.4+; unused by older versions).
    inherit_cache = True


@compiles(utcnow, 'postgresql')
def pg_utcnow(element, compiler, **kw):
    """Render :class:`utcnow` for PostgreSQL."""
    return "TIMEZONE('utc', CURRENT_TIMESTAMP)"
@login_manager.user_loader
def load_user(user_id):
    """User loader callback for flask-login

    :param str user_id: unicode ID of a user
    :returns: corresponding user object or None
    """
    # flask-login expects None (not an exception) for an invalid ID, e.g. a
    # tampered or stale session value that is not an integer.
    try:
        user_pk = int(user_id)
    except (TypeError, ValueError):
        return None
    return User.query.get(user_pk)
@ldap_manager.save_user
def save_user(dn, username, data, memberships):
    """User saver for flask-ldap3-login

    This method is called whenever a LDAPLoginForm()
    successfully validates.

    The local user is created on first login, its groups are refreshed from
    LDAP on every login, and the user is returned so that flask-ldap3-login
    can use it as the authenticated user.

    :param str dn: distinguished name of the LDAP entry (not used here)
    :param str username: login name
    :param dict data: LDAP attributes of the user; ``cn`` and ``mail`` are read
    :param list memberships: LDAP group entries; the ``cn`` of each is read
    :returns: the created or updated :class:`User`
    """
    user = User.query.filter_by(username=username).first()
    if user is None:
        user = User(username=username,
                    name=utils.attribute_to_string(data['cn']),
                    email=utils.attribute_to_string(data['mail']))
    # Always update the user groups to keep them up-to-date
    user.groups = [utils.attribute_to_string(group['cn']) for group in memberships]
    db.session.add(user)
    db.session.commit()
    return user
# Association table backing the many-to-many User <-> Group relationship
# (used as ``secondary`` by ``User.grp``).
usergroups_table = db.Table(
    'usergroups',
    db.Column(
        'user_id',
        db.Integer,
        db.ForeignKey('user_account.id'),
    ),
    db.Column(
        'group_id',
        db.Integer,
        db.ForeignKey('group.id'),
    ),
)
class Group(db.Model):
    """LDAP group a user can belong to (see ``User.grp``)."""

    # Primary key: ``usergroups.group_id`` references ``group.id``, and a
    # mapped class cannot be declared without a primary key.
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.Text, nullable=False, unique=True)

    def __str__(self):
        return str(self.name)
def find_or_create_group(name):
    """Get the :class:`Group` called *name*, or build a new instance.

    The new instance is not added to the session here. This is the
    ``creator`` of the ``User.groups`` association proxy, so assigning plain
    group names reuses existing rows when they exist.
    """
    existing = Group.query.filter_by(name=name).first()
    if existing is not None:
        return existing
    return Group(name=name)
class User(db.Model, UserMixin):
    """Application user account, created/updated from LDAP by ``save_user``."""

    # "user" is a reserved word in postgresql
    # so let's use another name
    __tablename__ = 'user_account'

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.Text, unique=True)
    name = db.Column(db.Text)
    email = db.Column(db.Text)
    grp = db.relationship('Group', secondary=usergroups_table,
                          backref=db.backref('members', lazy='dynamic'))
    # Proxy the 'name' attribute from the 'grp' relationship
    # See http://docs.sqlalchemy.org/en/latest/orm/extensions/associationproxy.html
    groups = association_proxy('grp', 'name',
                               creator=find_or_create_group)
    tokens = db.relationship("Token", backref="user")
    comments = db.relationship('ItemComment', backref='user')

    def get_id(self):
        """Return the user id as unicode

        Required by flask-login
        """
        return str(self.id)

    @property
    def csentry_groups(self):
        """Return the CSEntry group keys the user belongs to

        ``CSENTRY_LDAP_GROUPS`` maps a CSEntry group key (e.g. ``'admin'``) to
        an LDAP group name; a key is returned when its LDAP group is one of the
        user's groups. Keys keep the order of the config mapping.
        """
        # Build the membership set once instead of scanning the proxied list
        # for every configured group.
        user_groups = set(self.groups)
        return [key
                for key, ldap_group in current_app.config['CSENTRY_LDAP_GROUPS'].items()
                if ldap_group in user_groups]

    @property
    def is_admin(self):
        """Return True if the user is member of the ``admin`` LDAP group"""
        return current_app.config['CSENTRY_LDAP_GROUPS']['admin'] in self.groups

    def is_member_of_one_group(self, groups):
        """Return True if the user is at least member of one of the given groups

        :param groups: CSEntry group keys; keys missing from
            ``CSENTRY_LDAP_GROUPS`` map to None and never match
        """
        names = [current_app.config['CSENTRY_LDAP_GROUPS'].get(group) for group in groups]
        return bool(set(self.groups) & set(names))

    def is_member_of_all_groups(self, groups):
        """Return True if the user is member of all the given groups

        :param groups: CSEntry group keys; a key missing from
            ``CSENTRY_LDAP_GROUPS`` maps to None, which makes the result False
        """
        names = [current_app.config['CSENTRY_LDAP_GROUPS'].get(group) for group in groups]
        return set(names).issubset(self.groups)
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
class Token(db.Model):
    """Table to store valid tokens"""

    id = db.Column(db.Integer, primary_key=True)
    jti = db.Column(postgresql.UUID, nullable=False)
    token_type = db.Column(db.Text, nullable=False)
    user_id = db.Column(db.Integer, db.ForeignKey('user_account.id'), nullable=False)
    issued_at = db.Column(db.DateTime, nullable=False)
    # expires can be set to None for tokens that never expire
    expires = db.Column(db.DateTime)
    description = db.Column(db.Text)

    # A given token id (jti) can appear only once per user.
    __table_args__ = (
        sa.UniqueConstraint(jti, user_id),
    )

    def __str__(self):
        # __str__ must return a str, but ``jti`` can be loaded as a uuid.UUID
        # (postgresql.UUID(as_uuid=True), the default in SQLAlchemy 2.0),
        # which would make str(token) raise TypeError.
        return str(self.jti)

    def to_dict(self):
        """Return the token as a dict (``issued_at`` is not included)."""
        return {
            'id': self.id,
            'jti': self.jti,
            'token_type': self.token_type,
            'user_id': self.user_id,
            'expires': self.expires,
            'description': self.description,
        }
class QRCodeMixin:
    # Mixin for small named lookup tables (used by Manufacturer, Location,
    # Status and Tag): an integer primary key, a unique case-insensitive
    # (CIText) name, and QR-code helpers.
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)
    # NOTE(review): the lines below look like the body of an ``image(self)``
    # method whose ``def`` line (and final ``return`` of the generated image)
    # are missing from this view. ``self`` is not defined at class scope, so
    # the ``data = ...`` line would raise NameError when the class body runs,
    # and ``to_dict`` below calls ``self.image()``. Restore the method.
    """Return a QRCode image to identify a record
The QRCode includes:
- the table name
- the name of the record
"""
    data = ':'.join(['CSE', self.__tablename__, self.name])

    def to_dict(self, qrcode=False):
        """Return the record as a dict with ``id`` and ``name``.

        :param bool qrcode: when True, also add ``'qrcode'``, the result of
            ``utils.image_to_base64(self.image())``
        :returns: dict
        """
        d = {'id': self.id, 'name': self.name}
        if qrcode:
            d['qrcode'] = utils.image_to_base64(self.image())
        return d
class Manufacturer(QRCodeMixin, db.Model):
    """Manufacturer of items (reverse side of ``Item.manufacturer``)."""

    items = db.relationship('Item', back_populates='manufacturer')


class Model(QRCodeMixin, db.Model):
    """Item model (reverse side of ``Item.model``) with a free-text description."""

    # Read by to_dict below.
    description = db.Column(db.Text)
    items = db.relationship('Item', back_populates='model')

    def to_dict(self, qrcode=False):
        """Return the base record dict extended with ``description``.

        :param bool qrcode: forwarded to :meth:`QRCodeMixin.to_dict`
        """
        d = super().to_dict(qrcode)
        d['description'] = self.description
        return d
class Location(QRCodeMixin, db.Model):
    """Location an item can be assigned to (reverse side of ``Item.location``)."""

    items = db.relationship(
        'Item',
        back_populates='location',
    )
class Status(QRCodeMixin, db.Model):
    """Status an item can have (reverse side of ``Item.status``)."""

    items = db.relationship(
        'Item',
        back_populates='status',
    )
class Item(db.Model):
    """Tracked item with its manufacturer, model, location, status, MACs and
    parent/children hierarchy. Changes are versioned by SQLAlchemy-Continuum.
    """

    # SQLAlchemy-Continuum settings: these columns are not tracked in the
    # version history.
    __versioned__ = {
        'exclude': ['_created', 'ics_id', 'serial_number',
                    'manufacturer_id', 'model_id']
    }

    # Target of ForeignKey('item.id') and of remote_side=[id] below.
    id = db.Column(db.Integer, primary_key=True)
    _created = db.Column(db.DateTime, default=utcnow())
    _updated = db.Column(db.DateTime, default=utcnow(), onupdate=utcnow())
    ics_id = db.Column(db.Text, unique=True, index=True)
    serial_number = db.Column(db.Text, nullable=False)
    manufacturer_id = db.Column(db.Integer, db.ForeignKey('manufacturer.id'))
    model_id = db.Column(db.Integer, db.ForeignKey('model.id'))
    location_id = db.Column(db.Integer, db.ForeignKey('location.id'))
    status_id = db.Column(db.Integer, db.ForeignKey('status.id'))
    parent_id = db.Column(db.Integer, db.ForeignKey('item.id'))

    manufacturer = db.relationship('Manufacturer', back_populates='items')
    model = db.relationship('Model', back_populates='items')
    location = db.relationship('Location', back_populates='items')
    status = db.relationship('Status', back_populates='items')
    # Self-referential hierarchy: ``parent`` (backref) is the many-to-one
    # side, ``children`` the one-to-many side.
    children = db.relationship('Item', backref=db.backref('parent', remote_side=[id]))
    macs = db.relationship('Mac', backref='item')
    comments = db.relationship('ItemComment', backref='item')

    def __init__(self, **kwargs):
        # Automatically convert manufacturer/model/location/status to an
        # instance of their class if passed as a string
        for key, cls in [('manufacturer', Manufacturer),
                         ('model', Model),
                         ('location', Location),
                         ('status', Status)]:
            if key in kwargs:
                kwargs[key] = utils.convert_to_model(kwargs[key], cls)
        super().__init__(**kwargs)

    def __str__(self):
        return str(self.ics_id)

    @validates('ics_id')
    def validate_ics_id(self, key, string):
        """Ensure the ICS id field matches the required format

        None is accepted (the column is nullable).

        :raises: ValidationError if the id does not match ``ICS_ID_RE``
        """
        if string is not None:
            if ICS_ID_RE.fullmatch(string) is None:
                raise ValidationError('ICS id shall match [A-Z]{3}[0-9]{3}')
        return string

    def to_dict(self):
        """Return the item as a dict, including children, MACs and history"""
        d = {
            'id': self.id,
            'ics_id': self.ics_id,
            'serial_number': self.serial_number,
            'manufacturer': utils.format_field(self.manufacturer),
            'model': utils.format_field(self.model),
            'location': utils.format_field(self.location),
            'status': utils.format_field(self.status),
            'updated': utils.format_field(self._updated),
            'created': utils.format_field(self._created),
            'parent': utils.format_field(self.parent),
        }
        d['children'] = [utils.format_field(child) for child in self.children]
        d['macs'] = [mac.to_dict(long=True) for mac in self.macs]
        d['history'] = self.history()
        return d

    def history(self):
        """Return the item's versions as a list of dicts

        Each entry holds the ``updated`` timestamp, ``location``, ``status``
        and ``parent`` recorded in that version.
        """
        versions = []
        for version in self.versions:
            # parent is an attribute used by SQLAlchemy-Continuum
            # version.parent refers to an ItemVersion instance (and has no link with
            # the item parent_id)
            # We need to retrieve the parent "manually"
            if version.parent_id is None:
                parent = None
            else:
                parent = Item.query.get(version.parent_id)
            versions.append({
                'updated': utils.format_field(version._updated),
                'location': utils.format_field(version.location),
                'status': utils.format_field(version.status),
                'parent': utils.format_field(parent),
            })
        return versions
class ItemComment(db.Model):
    """Free-text comment written by a user about an item.

    The reverse sides are the ``user`` and ``item`` backrefs declared by
    ``User.comments`` and ``Item.comments``.
    """

    id = db.Column(db.Integer, primary_key=True)
    # Filled in by the database (UTC) when the row is inserted.
    timestamp = db.Column(db.DateTime, index=True, default=utcnow())
    text = db.Column(db.Text, nullable=False)
    # Both links are mandatory.
    user_id = db.Column(
        db.Integer,
        db.ForeignKey('user_account.id'),
        nullable=False,
    )
    item_id = db.Column(
        db.Integer,
        db.ForeignKey('item.id'),
        nullable=False,
    )
class Network(db.Model):
    """IP network (one VLAN) whose addresses are assigned to interfaces.

    Assignable addresses are limited to ``first_ip`` .. ``last_ip``, which
    must both lie inside ``address``.
    """

    id = db.Column(db.Integer, primary_key=True)
    vlan_name = db.Column(CIText, nullable=False, unique=True)
    vlan_id = db.Column(db.Integer, nullable=False, unique=True)
    address = db.Column(postgresql.CIDR, nullable=False, unique=True)
    first_ip = db.Column(postgresql.INET, nullable=False, unique=True)
    last_ip = db.Column(postgresql.INET, nullable=False, unique=True)
    description = db.Column(db.Text)
    admin_only = db.Column(db.Boolean, nullable=False, default=False)
    scope_id = db.Column(db.Integer, db.ForeignKey('network_scope.id'), nullable=False)
    interfaces = db.relationship('Interface', backref='network')

    # The same invariants, enforced by the database. ``<<`` is PostgreSQL's
    # "is contained within" operator for inet/cidr values.
    __table_args__ = (
        sa.CheckConstraint('first_ip < last_ip', name='first_ip_less_than_last_ip'),
        sa.CheckConstraint('first_ip << address', name='first_ip_in_network'),
        sa.CheckConstraint('last_ip << address', name='last_ip_in_network'),
    )

    def __init__(self, **kwargs):
        # Automatically convert scope to an instance of NetworkScope if it was passed
        # as a string
        if 'scope' in kwargs:
            kwargs['scope'] = utils.convert_to_model(kwargs['scope'], NetworkScope, 'name')
        super().__init__(**kwargs)

    @property
    def network_ip(self):
        """``address`` as an ipaddress network object"""
        return ipaddress.ip_network(self.address)

    @property
    def first(self):
        """``first_ip`` as an ipaddress address object"""
        return ipaddress.ip_address(self.first_ip)

    @property
    def last(self):
        """``last_ip`` as an ipaddress address object"""
        return ipaddress.ip_address(self.last_ip)

    def ip_range(self):
        """Return the list of IP addresses that can be assigned for this network

        The range is defined by the first and last IP (both included); only
        addresses yielded by ``network_ip.hosts()`` are considered.
        """
        # Compute the bounds once rather than on every comparison.
        first, last = self.first, self.last
        return [addr for addr in self.network_ip.hosts()
                if first <= addr <= last]

    def used_ips(self):
        """Return the list of IP addresses in use

        The list is sorted
        """
        return sorted(interface.address for interface in self.interfaces)

    def available_ips(self):
        """Return the list of IP addresses available"""
        # Build the set of used addresses once: calling used_ips() for every
        # candidate made this quadratic.
        used = set(self.used_ips())
        return [addr for addr in self.ip_range() if addr not in used]

    @staticmethod
    def ip_in_network(ip, address):
        """Ensure the IP is in the network

        :param str ip: IP address to check
        :param str address: network address in CIDR notation
        :returns: a tuple with the IP and network, e.g. (IPv4Address, IPv4Network)
        :raises: ValidationError if the IP is not in the network
        """
        addr = ipaddress.ip_address(ip)
        net = ipaddress.ip_network(address)
        if addr not in net:
            raise ValidationError(f'IP address {ip} is not in network {address}')
        return (addr, net)

    # NOTE(review): the validators below read self.address (and
    # validate_last_ip reads self.first), so ``address``/``first_ip`` must be
    # set before the validated attribute.
    @validates('first_ip')
    def validate_first_ip(self, key, ip):
        """Ensure the first IP is in the network"""
        self.ip_in_network(ip, self.address)
        # @validates must return the value to store.
        return ip

    @validates('last_ip')
    def validate_last_ip(self, key, ip):
        """Ensure the last IP is in the network and greater than first_ip"""
        addr, net = self.ip_in_network(ip, self.address)
        if addr < self.first:
            raise ValidationError(f'Last IP address {ip} is less than the first address {self.first}')
        return ip

    @validates('interfaces')
    def validate_interfaces(self, key, interface):
        """Ensure the interface IP is in the network range"""
        addr, net = self.ip_in_network(interface.ip, self.address)
        if addr < self.first or addr > self.last:
            raise ValidationError(f'IP address {interface.ip} is not in range {self.first} - {self.last}')
        return interface

    def to_dict(self):
        """Return the network as a dict"""
        return {
            'id': self.id,
            'vlan_name': self.vlan_name,
            'address': self.address,
            'first_ip': self.first_ip,
            'last_ip': self.last_ip,
            'vlan_id': self.vlan_id,
            'description': self.description,
            'admin_only': self.admin_only,
            'scope': utils.format_field(self.scope),
        }
# Association table for the many-to-many Interface <-> Tag relationship
# (``Interface.tags``). The composite primary key rejects duplicate pairs.
interfacetags_table = db.Table(
    'interfacetags',
    db.Column('tag_id', db.Integer, db.ForeignKey('tag.id'),
              primary_key=True),
    db.Column('interface_id', db.Integer, db.ForeignKey('interface.id'),
              primary_key=True),
)
class Tag(QRCodeMixin, db.Model):
    """Label attachable to interfaces (``interfaces`` is the backref of ``Interface.tags``)."""
class Host(db.Model):
    """Host owning network interfaces (``Interface.host``), optionally linked to an item."""

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.Text, nullable=False, unique=True)
    type = db.Column(db.Text)
    description = db.Column(db.Text)
    item_id = db.Column(db.Integer, db.ForeignKey('item.id'))
    interfaces = db.relationship('Interface', backref='host')

    def __str__(self):
        return str(self.name)

    @validates('name')
    def validate_name(self, key, string):
        """Ensure the name matches the required format

        The name is lowercased before being checked and stored.

        :raises: ValidationError if the name does not match ``HOST_NAME_RE``
        """
        if string is None:
            return None
        # Force the string to lowercase
        lower_string = string.lower()
        if HOST_NAME_RE.fullmatch(lower_string) is None:
            # This validates a host name, not an interface name.
            raise ValidationError(r'Host name shall match [a-z0-9\-]{2,20}')
        return lower_string
class Interface(db.Model):
    """Network interface: a unique name and IP address inside a :class:`Network`."""

    id = db.Column(db.Integer, primary_key=True)
    network_id = db.Column(db.Integer, db.ForeignKey('network.id'), nullable=False)
    ip = db.Column(postgresql.INET, nullable=False, unique=True)
    name = db.Column(db.Text, nullable=False, unique=True)
    mac_id = db.Column(db.Integer, db.ForeignKey('mac.id'))
    host_id = db.Column(db.Integer, db.ForeignKey('host.id'))
    cnames = db.relationship('Cname', backref='interface')
    tags = db.relationship('Tag', secondary=interfacetags_table, lazy='subquery',
                           backref=db.backref('interfaces', lazy=True))

    def __init__(self, **kwargs):
        # Automatically convert network to an instance of Network if it was passed
        # as a string (only when it was passed at all)
        if 'network' in kwargs:
            kwargs['network'] = utils.convert_to_model(kwargs['network'], Network, 'address')
        # WARNING! Setting self.network will call validate_interfaces in the Network class
        # For the validation to work, self.ip must be set before!
        # Ensure that ip is passed before network
        try:
            ip = kwargs.pop('ip')
        except KeyError:
            super().__init__(**kwargs)
        else:
            super().__init__(ip=ip, **kwargs)

    @validates('name')
    def validate_name(self, key, string):
        """Ensure the name matches the required format

        The name is lowercased before being checked and stored.

        :raises: ValidationError if the name does not match ``HOST_NAME_RE``
        """
        if string is None:
            return None
        # Force the string to lowercase
        lower_string = string.lower()
        if HOST_NAME_RE.fullmatch(lower_string) is None:
            raise ValidationError(r'Interface name shall match [a-z0-9\-]{2,20}')
        # @validates must return the value to store; without this the name
        # would be set to None.
        return lower_string

    @property
    def address(self):
        """``ip`` as an ipaddress address object"""
        return ipaddress.ip_address(self.ip)

    def __repr__(self):
        return f'Interface(id={self.id}, network_id={self.network_id}, ip={self.ip}, name={self.name}, mac={self.mac})'

    def to_dict(self, long=False):
        """Return the interface as a dict

        :param bool long: also include the MAC address (None when unset)
        """
        d = {
            'id': self.id,
            'ip': self.ip,
            'name': self.name,
            'network_id': self.network_id,
        }
        if long:
            d['mac'] = getattr(self.mac, 'address', None)
        return d
class Mac(db.Model):
    """MAC address, optionally attached to an item and used by interfaces."""

    id = db.Column(db.Integer, primary_key=True)
    address = db.Column(postgresql.MACADDR, nullable=False, unique=True)
    item_id = db.Column(db.Integer, db.ForeignKey('item.id'))
    interfaces = db.relationship('Interface', backref='mac')

    def __str__(self):
        return str(self.address)

    def to_dict(self, long=False):
        """Return the MAC as a dict

        :param bool long: also include the owning item's ICS id and the
            interfaces using this MAC
        """
        d = {
            'id': self.id,
            'address': self.address,
            'item_id': self.item_id,
        }
        if long:
            # item_id is nullable: report None instead of raising
            # AttributeError for a MAC not attached to any item.
            d['item_ics_id'] = getattr(self.item, 'ics_id', None)
            d['interfaces'] = [interface.to_dict() for interface in self.interfaces]
        return d
class Cname(db.Model):
    """Alternative name (presumably a DNS CNAME) for an interface."""

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.Text, nullable=False, unique=True)
    # NOTE(review): unique=True allows at most one Cname per interface, while
    # Interface.cnames is declared as a one-to-many list — confirm which is intended.
    interface_id = db.Column(db.Integer, db.ForeignKey('interface.id'), nullable=False, unique=True)

    def __str__(self):
        return str(self.name)

    def to_dict(self):
        """Return the alias as a dict."""
        # NOTE(review): this dict literal continues past the lines visible here.
        return {
            'id': self.id,
            'name': self.name,
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
}
class NetworkScope(db.Model):
    """Network scope: a VLAN id range and a subnet.

    ``networks`` lists the :class:`Network` rows attached to the scope
    (their backref is ``scope``).
    """

    __tablename__ = 'network_scope'

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)
    first_vlan = db.Column(db.Integer, nullable=False, unique=True)
    last_vlan = db.Column(db.Integer, nullable=False, unique=True)
    subnet = db.Column(postgresql.CIDR, nullable=False, unique=True)
    networks = db.relationship('Network', backref='scope')

    def __str__(self):
        return str(self.name)

    def to_dict(self):
        """Return the scope's columns as a dict keyed by column name."""
        exported = ('id', 'name', 'first_vlan', 'last_vlan', 'subnet')
        return {field: getattr(self, field) for field in exported}
# Configure all mappers now that every model is defined. SQLAlchemy-Continuum
# requires this explicit call: the version classes (which back the
# ``versions`` attribute read by Item.history) are built during mapper
# configuration.
sa.orm.configure_mappers()