Forked from
ICS Control System Infrastructure / csentry
491 commits behind the upstream repository.
-
Authored by Benjamin Bertrand
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
models.py 21.73 KiB
# -*- coding: utf-8 -*-
"""
app.models
~~~~~~~~~~
This module implements the models used in the app.
:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.
"""
import ipaddress
import re

import qrcode
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import validates
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.compiler import compiles
from sqlalchemy_continuum import make_versioned, version_class
from citext import CIText
from flask import current_app
from flask_login import UserMixin
from wtforms import ValidationError

from .extensions import db, login_manager, ldap_manager
from .plugins import FlaskUserPlugin
from . import utils
# Compiled validation patterns (used with fullmatch by the validators below).
# Raw strings so that "\-" reaches the regex engine verbatim instead of being
# an invalid, deprecated string escape sequence.
ICS_ID_RE = re.compile(r'[A-Z]{3}[0-9]{3}')
HOST_NAME_RE = re.compile(r'^[a-z0-9\-]{2,20}$')
VLAN_NAME_RE = re.compile(r'^[A-Za-z0-9\-]{3,25}$')
# Enable SQLAlchemy-Continuum versioning with the project's user plugin.
# NOTE(review): presumably this must run before the versioned models below
# are declared — keep it above them.
make_versioned(plugins=[FlaskUserPlugin()])
class utcnow(sa.sql.expression.FunctionElement):
    """SQL expression that evaluates to the current UTC timestamp.

    Used as a server-evaluated column default. Only a PostgreSQL rendering is
    registered in this module (via ``@compiles``). Pattern taken from
    http://docs.sqlalchemy.org/en/latest/core/compiler.html#utc-timestamp-function
    """

    type = sa.types.DateTime()
@compiles(utcnow, 'postgresql')
def pg_utcnow(element, compiler, **kw):
    """Render ``utcnow()`` for PostgreSQL.

    Converts CURRENT_TIMESTAMP to the 'utc' time zone. ``compiles`` is
    imported explicitly from ``sqlalchemy.ext.compiler`` rather than reached
    through ``sa.ext.compiler``; the latter only resolves if something else
    has already imported that submodule.
    """
    return "TIMEZONE('utc', CURRENT_TIMESTAMP)"
@login_manager.user_loader
def load_user(user_id):
    """User loader callback for flask-login

    :param str user_id: unicode ID of a user
    :returns: corresponding user object or None
    """
    try:
        user_pk = int(user_id)
    except (TypeError, ValueError):
        # A malformed id (e.g. a tampered or stale session) must not raise:
        # flask-login expects None, which makes the user anonymous.
        return None
    return User.query.get(user_pk)
@ldap_manager.save_user
def save_user(dn, username, data, memberships):
    """User saver for flask-ldap3-login

    Called whenever a LDAPLoginForm() validates successfully. Looks up the
    local account by username and creates it from the LDAP 'cn' and 'mail'
    attributes if missing. Then rewrites its group list from the LDAP
    memberships and commits.

    :returns: the persisted User
    """
    user = User.query.filter_by(username=username).first()
    if user is None:
        display_name = utils.attribute_to_string(data['cn'])
        mail = utils.attribute_to_string(data['mail'])
        user = User(username=username, name=display_name, email=mail)
    # Group membership is refreshed on every login so it tracks LDAP;
    # name and email are only set when the account is first created.
    group_names = [utils.attribute_to_string(membership['cn'])
                   for membership in memberships]
    user.groups = group_names
    db.session.add(user)
    db.session.commit()
    return user
# Association table backing the many-to-many User <-> Group relationship
# (User.grp / Group.members).
# NOTE(review): neither column is a primary key (unlike interfacetags), so the
# database itself does not reject duplicate (user_id, group_id) pairs.
usergroups_table = db.Table(
    'usergroups',
    db.Column('user_id', db.Integer,
              db.ForeignKey('user_account.id')),
    db.Column('group_id', db.Integer,
              db.ForeignKey('group.id')),
)
class Group(db.Model):
    """A named group; users are linked to groups through usergroups_table."""

    id = db.Column(db.Integer, primary_key=True)
    # Group names are unique, so a name identifies a group
    name = db.Column(db.Text, nullable=False, unique=True)

    def __str__(self):
        """Display a group by its name."""
        return self.name
def find_or_create_group(name):
    """Return the Group called *name*, building a new (unsaved) one if absent.

    Used as the association-proxy creator for User.groups so that assigning
    group names reuses existing Group rows.
    """
    existing = Group.query.filter_by(name=name).first()
    if existing is None:
        return Group(name=name)
    return existing
class User(db.Model, UserMixin):
    """Application user account (created or refreshed by save_user on LDAP login)."""

    # "user" is a reserved word in postgresql
    # so let's use another name
    __tablename__ = 'user_account'
    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.Text, unique=True)
    name = db.Column(db.Text)
    email = db.Column(db.Text)
    grp = db.relationship('Group', secondary=usergroups_table,
                          backref=db.backref('members', lazy='dynamic'))
    # Proxy the 'name' attribute from the 'grp' relationship: user.groups is a
    # list of group names, and assigning names goes through
    # find_or_create_group so existing Group rows are reused.
    # See http://docs.sqlalchemy.org/en/latest/orm/extensions/associationproxy.html
    groups = association_proxy('grp', 'name',
                               creator=find_or_create_group)
    tokens = db.relationship("Token", backref="user")
    comments = db.relationship('ItemComment', backref='user')

    def get_id(self):
        """Return the user id as unicode

        Required by flask-login
        """
        return str(self.id)

    @property
    def csentry_groups(self):
        """Return the CSENTRY_LDAP_GROUPS keys whose LDAP group the user is in"""
        ldap_groups = current_app.config['CSENTRY_LDAP_GROUPS']
        return [key for key, value in ldap_groups.items() if value in self.groups]

    @property
    def is_admin(self):
        """True if the user belongs to the LDAP group configured as 'admin'"""
        return current_app.config['CSENTRY_LDAP_GROUPS']['admin'] in self.groups

    def is_member_of_one_group(self, groups):
        """Return True if the user is at least member of one of the given groups

        :param groups: iterable of CSENTRY_LDAP_GROUPS keys (unknown keys are ignored)
        """
        names = [current_app.config['CSENTRY_LDAP_GROUPS'].get(group) for group in groups]
        return bool(set(self.groups) & set(names))

    def is_member_of_all_groups(self, groups):
        """Return True if the user is member of all the given groups

        :param groups: iterable of CSENTRY_LDAP_GROUPS keys (an unknown key
            makes the result False)
        """
        names = [current_app.config['CSENTRY_LDAP_GROUPS'].get(group) for group in groups]
        return set(names).issubset(self.groups)

    def __str__(self):
        # 'name' is a nullable column; returning None from __str__ would make
        # str(user) raise TypeError, so fall back to the username.
        if self.name is not None:
            return self.name
        return str(self.username)
class Token(db.Model):
    """Table to store valid tokens

    A (jti, user_id) pair is unique. ``expires`` may be None for tokens that
    never expire.
    """

    id = db.Column(db.Integer, primary_key=True)
    jti = db.Column(postgresql.UUID, nullable=False)
    token_type = db.Column(db.Text, nullable=False)
    user_id = db.Column(db.Integer, db.ForeignKey('user_account.id'), nullable=False)
    issued_at = db.Column(db.DateTime, nullable=False)
    # None means "never expires"
    expires = db.Column(db.DateTime)
    description = db.Column(db.Text)

    __table_args__ = (sa.UniqueConstraint(jti, user_id),)

    def __str__(self):
        return self.jti

    def to_dict(self):
        """Serialize the token (issued_at is not included)."""
        return dict(
            id=self.id,
            jti=self.jti,
            token_type=self.token_type,
            user_id=self.user_id,
            expires=self.expires,
            description=self.description,
        )
class QRCodeMixin:
    """Mixin giving a model an id, a unique CIText name and a QR code image."""

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)

    def image(self):
        """Return a QRCode image to identify a record

        The encoded text is "CSE:<table name>:<record name>".
        """
        payload = ':'.join(('CSE', self.__tablename__, self.name))
        return qrcode.make(payload, version=1, box_size=5)

    def __str__(self):
        return self.name

    def to_dict(self, qrcode=False):
        """Return id and name; add the QR code image (via utils.image_to_base64) if *qrcode*."""
        # NB: the 'qrcode' parameter shadows the module only inside this method;
        # image() still sees the qrcode module.
        result = {'id': self.id, 'name': self.name}
        if qrcode:
            result['qrcode'] = utils.image_to_base64(self.image())
        return result
class Action(QRCodeMixin, db.Model):
    """Named action record; all columns and methods come from QRCodeMixin."""
class Manufacturer(QRCodeMixin, db.Model):
    """Item manufacturer, identified by its unique name."""

    items = db.relationship('Item', back_populates='manufacturer')
class Model(QRCodeMixin, db.Model):
    """Item model: a unique name plus a free-text description."""

    description = db.Column(db.Text)
    items = db.relationship('Item', back_populates='model')

    def to_dict(self, qrcode=False):
        """Same as QRCodeMixin.to_dict with the description appended."""
        result = super().to_dict(qrcode)
        result.update(description=self.description)
        return result
class Location(QRCodeMixin, db.Model):
    """Place where items are located, identified by its unique name."""

    items = db.relationship('Item', back_populates='location')
class Status(QRCodeMixin, db.Model):
    """Item status value, identified by its unique name."""

    items = db.relationship('Item', back_populates='status')
class Item(db.Model):
    """Inventory item identified by an optional ICS id and a serial number.

    Versioned with SQLAlchemy-Continuum (see ``__versioned__``; the listed
    columns are excluded). ``history()`` reports location, status and parent
    for each stored version.
    """

    __versioned__ = {
        'exclude': ['_created', 'ics_id', 'serial_number',
                    'manufacturer_id', 'model_id']
    }
    id = db.Column(db.Integer, primary_key=True)
    _created = db.Column(db.DateTime, default=utcnow())
    _updated = db.Column(db.DateTime, default=utcnow(), onupdate=utcnow())
    ics_id = db.Column(db.Text, unique=True, index=True)
    serial_number = db.Column(db.Text, nullable=False)
    manufacturer_id = db.Column(db.Integer, db.ForeignKey('manufacturer.id'))
    model_id = db.Column(db.Integer, db.ForeignKey('model.id'))
    location_id = db.Column(db.Integer, db.ForeignKey('location.id'))
    status_id = db.Column(db.Integer, db.ForeignKey('status.id'))
    parent_id = db.Column(db.Integer, db.ForeignKey('item.id'))
    manufacturer = db.relationship('Manufacturer', back_populates='items')
    model = db.relationship('Model', back_populates='items')
    location = db.relationship('Location', back_populates='items')
    status = db.relationship('Status', back_populates='items')
    # Self-referential tree: item.children / item.parent
    children = db.relationship('Item', backref=db.backref('parent', remote_side=[id]))
    macs = db.relationship('Mac', backref='item')
    comments = db.relationship('ItemComment', backref='item')

    def __init__(self, **kwargs):
        # manufacturer/model/location/status may be given as strings: hand them
        # to utils.convert_to_model to get instances of the matching class.
        conversions = (
            ('manufacturer', Manufacturer),
            ('model', Model),
            ('location', Location),
            ('status', Status),
        )
        for field, model_cls in conversions:
            if field in kwargs:
                kwargs[field] = utils.convert_to_model(kwargs[field], model_cls)
        super().__init__(**kwargs)

    def __str__(self):
        return str(self.ics_id)

    @validates('ics_id')
    def validate_ics_id(self, key, string):
        """Ensure the ICS id field matches the required format (None is accepted)"""
        if string is not None and ICS_ID_RE.fullmatch(string) is None:
            raise ValidationError('ICS id shall match [A-Z]{3}[0-9]{3}')
        return string

    def to_dict(self, long=False):
        """Serialize the item; with *long*, add children, MACs and history."""
        fmt = utils.format_field
        result = {
            'id': self.id,
            'ics_id': self.ics_id,
            'serial_number': self.serial_number,
            'manufacturer': fmt(self.manufacturer),
            'model': fmt(self.model),
            'location': fmt(self.location),
            'status': fmt(self.status),
            'updated': fmt(self._updated),
            'created': fmt(self._created),
            'parent': fmt(self.parent),
        }
        if long:
            result.update(
                children=[fmt(child) for child in self.children],
                macs=[mac.to_dict(long=True) for mac in self.macs],
                history=self.history(),
            )
        return result

    def history(self):
        """Return one dict (updated, location, status, parent) per stored version"""
        entries = []
        for version in self.versions:
            # version.parent is SQLAlchemy-Continuum's own attribute (an
            # ItemVersion instance) and is unrelated to the item's parent_id,
            # so the parent Item is fetched explicitly.
            parent_id = version.parent_id
            parent = None if parent_id is None else Item.query.get(parent_id)
            entries.append({
                'updated': utils.format_field(version._updated),
                'location': utils.format_field(version.location),
                'status': utils.format_field(version.status),
                'parent': utils.format_field(parent),
            })
        return entries
class ItemComment(db.Model):
    """Free-text comment written by a user about an item.

    ``timestamp`` defaults to the database-side UTC time (utcnow()).
    """

    id = db.Column(db.Integer, primary_key=True)
    timestamp = db.Column(db.DateTime, default=utcnow(), index=True)
    text = db.Column(db.Text, nullable=False)
    user_id = db.Column(db.Integer, db.ForeignKey('user_account.id'), nullable=False)
    item_id = db.Column(db.Integer, db.ForeignKey('item.id'), nullable=False)
class Network(db.Model):
    """IP network (VLAN) whose [first_ip, last_ip] range is assignable to interfaces."""

    id = db.Column(db.Integer, primary_key=True)
    vlan_name = db.Column(CIText, nullable=False, unique=True)
    vlan_id = db.Column(db.Integer, nullable=False, unique=True)
    address = db.Column(postgresql.CIDR, nullable=False, unique=True)
    first_ip = db.Column(postgresql.INET, nullable=False, unique=True)
    last_ip = db.Column(postgresql.INET, nullable=False, unique=True)
    gateway = db.Column(postgresql.INET)
    description = db.Column(db.Text)
    admin_only = db.Column(db.Boolean, nullable=False, default=False)
    scope_id = db.Column(db.Integer, db.ForeignKey('network_scope.id'), nullable=False)
    interfaces = db.relationship('Interface', backref='network')

    __table_args__ = (
        sa.CheckConstraint('first_ip < last_ip', name='first_ip_less_than_last_ip'),
        sa.CheckConstraint('first_ip << address', name='first_ip_in_network'),
        sa.CheckConstraint('last_ip << address', name='last_ip_in_network'),
    )

    # Attributes whose validators other validators depend on, in the order
    # they must be assigned: first_ip needs address, last_ip needs address and
    # first_ip, and interfaces need all three.
    _ASSIGN_FIRST = ('address', 'first_ip', 'last_ip')

    def __init__(self, **kwargs):
        # Automatically convert scope to an instance of NetworkScope if it was passed
        # as a string
        if 'scope' in kwargs:
            kwargs['scope'] = utils.convert_to_model(kwargs['scope'], NetworkScope, 'name')
        # The declarative constructor assigns keyword arguments in the order
        # given, firing the @validates hooks below as it goes. Assign the
        # attributes they read first so construction does not depend on the
        # caller's keyword order (other keys keep their relative order).
        ordered = {key: kwargs.pop(key) for key in self._ASSIGN_FIRST if key in kwargs}
        ordered.update(kwargs)
        super().__init__(**ordered)

    def __str__(self):
        return str(self.vlan_name)

    @property
    def network_ip(self):
        """The network address as an ipaddress network object"""
        return ipaddress.ip_network(self.address)

    @property
    def first(self):
        """first_ip as an ipaddress address object"""
        return ipaddress.ip_address(self.first_ip)

    @property
    def last(self):
        """last_ip as an ipaddress address object"""
        return ipaddress.ip_address(self.last_ip)

    def ip_range(self):
        """Return the list of IP addresses that can be assigned for this network

        The range is the network's hosts() between first and last IP (inclusive).
        """
        first, last = self.first, self.last
        return [addr for addr in self.network_ip.hosts()
                if first <= addr <= last]

    def used_ips(self):
        """Return the list of IP addresses in use

        The list is sorted
        """
        return sorted(interface.address for interface in self.interfaces)

    def available_ips(self):
        """Return the list of IP addresses available (ip_range minus used_ips)"""
        # Build the used set once instead of re-sorting used_ips() for every
        # candidate address.
        used = set(self.used_ips())
        return [addr for addr in self.ip_range() if addr not in used]

    @staticmethod
    def ip_in_network(ip, address):
        """Ensure the IP is in the network

        :param ip: IP address, in any form accepted by ipaddress.ip_address
        :param address: network, in any form accepted by ipaddress.ip_network
        :returns: a tuple (ip_address, ip_network) of ipaddress objects
        :raises: ValidationError if the IP is not in the network
        """
        addr = ipaddress.ip_address(ip)
        net = ipaddress.ip_network(address)
        if addr not in net:
            raise ValidationError(f'IP address {ip} is not in network {address}')
        return (addr, net)

    @validates('first_ip')
    def validate_first_ip(self, key, ip):
        """Ensure the first IP is in the network"""
        self.ip_in_network(ip, self.address)
        return ip

    @validates('last_ip')
    def validate_last_ip(self, key, ip):
        """Ensure the last IP is in the network and not less than first_ip"""
        addr = self.ip_in_network(ip, self.address)[0]
        if addr < self.first:
            raise ValidationError(f'Last IP address {ip} is less than the first address {self.first}')
        return ip

    @validates('interfaces')
    def validate_interfaces(self, key, interface):
        """Ensure the interface IP is in the network range"""
        addr = self.ip_in_network(interface.ip, self.address)[0]
        if addr < self.first or addr > self.last:
            raise ValidationError(f'IP address {interface.ip} is not in range {self.first} - {self.last}')
        return interface

    @validates('vlan_name')
    def validate_vlan_name(self, key, string):
        """Ensure the name matches the required format"""
        if string is None:
            return None
        if VLAN_NAME_RE.fullmatch(string) is None:
            # Raw string: same message text, without the invalid "\-" escape
            raise ValidationError(r'Vlan name shall match [A-Za-z0-9\-]{3,25}')
        return string

    def to_dict(self):
        return {
            'id': self.id,
            'vlan_name': self.vlan_name,
            'address': self.address,
            'first_ip': self.first_ip,
            'last_ip': self.last_ip,
            'gateway': self.gateway,
            'vlan_id': self.vlan_id,
            'description': self.description,
            'admin_only': self.admin_only,
            'scope': utils.format_field(self.scope),
        }
# Association table backing the many-to-many Interface <-> Tag relationship
# (Interface.tags / Tag.interfaces). The composite primary key
# (tag_id, interface_id) prevents tagging an interface twice with one tag.
interfacetags_table = db.Table(
    'interfacetags',
    db.Column('tag_id', db.Integer,
              db.ForeignKey('tag.id'), primary_key=True),
    db.Column('interface_id', db.Integer,
              db.ForeignKey('interface.id'), primary_key=True),
)
class Tag(QRCodeMixin, db.Model):
    """Named tag attachable to interfaces (see interfacetags_table)."""
class Host(db.Model):
    """Named host, optionally linked to an Item, owning network interfaces."""

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.Text, nullable=False, unique=True)
    type = db.Column(db.Text)
    description = db.Column(db.Text)
    item_id = db.Column(db.Integer, db.ForeignKey('item.id'))
    interfaces = db.relationship('Interface', backref='host')

    def __str__(self):
        return str(self.name)

    @validates('name')
    def validate_name(self, key, string):
        """Ensure the name matches the required format

        The name is lowercased first; the lowercased value is what gets stored.
        """
        if string is None:
            return None
        # Force the string to lowercase
        lower_string = string.lower()
        if HOST_NAME_RE.fullmatch(lower_string) is None:
            # This validates a host, not an interface (message was copy-pasted
            # from Interface); raw string avoids the invalid "\-" escape.
            raise ValidationError(r'Host name shall match [a-z0-9\-]{2,20}')
        return lower_string
class Interface(db.Model):
    """Network interface: a unique name and IP address within a Network."""

    id = db.Column(db.Integer, primary_key=True)
    network_id = db.Column(db.Integer, db.ForeignKey('network.id'), nullable=False)
    ip = db.Column(postgresql.INET, nullable=False, unique=True)
    name = db.Column(db.Text, nullable=False, unique=True)
    mac_id = db.Column(db.Integer, db.ForeignKey('mac.id'))
    host_id = db.Column(db.Integer, db.ForeignKey('host.id'))
    cnames = db.relationship('Cname', backref='interface')
    tags = db.relationship('Tag', secondary=interfacetags_table, lazy='subquery',
                           backref=db.backref('interfaces', lazy=True))

    def __init__(self, **kwargs):
        # Automatically convert network to an instance of Network if it was passed
        # as an address string
        if 'network' in kwargs:
            kwargs['network'] = utils.convert_to_model(kwargs['network'], Network, 'address')
        # WARNING! Setting self.network calls Network.validate_interfaces, which
        # reads self.ip. The declarative constructor assigns keyword arguments
        # in order, so make sure ip is assigned before network.
        ordered = {}
        if 'ip' in kwargs:
            ordered['ip'] = kwargs.pop('ip')
        ordered.update(kwargs)
        super().__init__(**ordered)

    @validates('name')
    def validate_name(self, key, string):
        """Ensure the name matches the required format

        The name is lowercased first; the lowercased value is what gets stored.
        """
        if string is None:
            return None
        # Force the string to lowercase
        lower_string = string.lower()
        if HOST_NAME_RE.fullmatch(lower_string) is None:
            # Raw string: same message text, without the invalid "\-" escape
            raise ValidationError(r'Interface name shall match [a-z0-9\-]{2,20}')
        return lower_string

    @property
    def address(self):
        """The IP as an ipaddress address object"""
        return ipaddress.ip_address(self.ip)

    def __str__(self):
        return str(self.name)

    def __repr__(self):
        return (f'Interface(id={self.id}, network_id={self.network_id}, '
                f'ip={self.ip}, name={self.name}, mac={self.mac})')

    def to_dict(self, long=False):
        """Serialize the interface; with *long*, add the MAC address (or None)."""
        d = {
            'id': self.id,
            'ip': self.ip,
            'name': self.name,
            'network_id': self.network_id,
        }
        if long:
            # mac_id is nullable, so the interface may have no MAC
            d['mac'] = getattr(self.mac, 'address', None)
        return d
class Mac(db.Model):
    """MAC address, optionally attached to an Item, used by interfaces."""

    id = db.Column(db.Integer, primary_key=True)
    address = db.Column(postgresql.MACADDR, nullable=False, unique=True)
    item_id = db.Column(db.Integer, db.ForeignKey('item.id'))
    interfaces = db.relationship('Interface', backref='mac')

    def __str__(self):
        return str(self.address)

    def to_dict(self, long=False):
        """Serialize the MAC; with *long*, add the item's ICS id and interfaces."""
        d = {
            'id': self.id,
            'address': self.address,
            'item_id': self.item_id,
        }
        if long:
            # item_id is nullable: a MAC without an item must not raise
            # AttributeError (mirrors Interface.to_dict's handling of mac).
            d['item_ics_id'] = getattr(self.item, 'ics_id', None)
            d['interfaces'] = [interface.to_dict() for interface in self.interfaces]
        return d
class Cname(db.Model):
    """DNS-style alias name attached to an interface (one cname per interface)."""

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.Text, nullable=False, unique=True)
    # unique=True: an interface can carry at most one Cname row
    interface_id = db.Column(db.Integer, db.ForeignKey('interface.id'), nullable=False, unique=True)

    def __str__(self):
        return str(self.name)

    def to_dict(self):
        return dict(
            id=self.id,
            name=self.name,
            interface_id=self.interface_id,
        )
class NetworkScope(db.Model):
    """A supernet and VLAN id range from which Networks are allocated."""

    __tablename__ = 'network_scope'
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)
    first_vlan = db.Column(db.Integer, nullable=False, unique=True)
    last_vlan = db.Column(db.Integer, nullable=False, unique=True)
    supernet = db.Column(postgresql.CIDR, nullable=False, unique=True)
    networks = db.relationship('Network', backref='scope')

    __table_args__ = (
        sa.CheckConstraint('first_vlan < last_vlan', name='first_vlan_less_than_last_vlan'),
    )

    def __str__(self):
        return str(self.name)

    @property
    def supernet_ip(self):
        """The supernet as an ipaddress network object"""
        return ipaddress.ip_network(self.supernet)

    def prefix_range(self):
        """Return the list of subnet prefix that can be used for this network scope

        From one bit longer than the supernet prefix up to /30.
        """
        return list(range(self.supernet_ip.prefixlen + 1, 31))

    def vlan_range(self):
        """Return the range of vlan ids that can be assigned for this network scope

        The range is defined by the first and last vlan (inclusive)
        """
        return range(self.first_vlan, self.last_vlan + 1)

    def used_vlans(self):
        """Return the list of vlan ids in use

        The list is sorted
        """
        return sorted(network.vlan_id for network in self.networks)

    def available_vlans(self):
        """Return the list of vlan ids available"""
        # Compute the used ids once rather than once per candidate vlan
        used = set(self.used_vlans())
        return [vlan for vlan in self.vlan_range() if vlan not in used]

    def used_subnets(self):
        """Return the list of subnets in use

        The list is sorted
        """
        return sorted(network.network_ip for network in self.networks)

    def available_subnets(self, prefix):
        """Return the list of available subnets with the given prefix

        A candidate is available only if it does not overlap any subnet in
        use: an exact match, a used subnet containing it, or one it contains.
        Subnets are returned as strings.
        """
        # Computed once instead of once per candidate subnet
        used = self.used_subnets()
        return [str(subnet) for subnet in self.supernet_ip.subnets(new_prefix=prefix)
                if not any(subnet.overlaps(net) for net in used)]

    def to_dict(self):
        return {
            'id': self.id,
            'name': self.name,
            'first_vlan': self.first_vlan,
            'last_vlan': self.last_vlan,
            'supernet': self.supernet,
        }
# Configure all mappers now that every model above is defined. This is
# required by sqlalchemy_continuum, so this must stay at the end of the
# module, after the last model definition.
sa.orm.configure_mappers()
# Version class that sqlalchemy_continuum generated for Item.
# NOTE(review): presumably the class of the objects iterated via
# item.versions in Item.history() — confirm.
ItemVersion = version_class(Item)