Forked from
ICS Control System Infrastructure / csentry
357 commits behind the upstream repository.
-
Benjamin Bertrand authored
"type" field on host replaced with foreign key "machine_type_id" JIRA INFRA-281
Benjamin Bertrand authored"type" field on host replaced with foreign key "machine_type_id" JIRA INFRA-281
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
models.py 26.86 KiB
# -*- coding: utf-8 -*-
"""
app.models
~~~~~~~~~~
This module implements the models used in the app.
:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.
"""
import ipaddress
import string
import qrcode
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import validates
from sqlalchemy_continuum import make_versioned, version_class
from citext import CIText
from flask import current_app
from flask_login import UserMixin
from wtforms import ValidationError
from .extensions import db, login_manager, ldap_manager, cache
from .plugins import FlaskUserPlugin
from .validators import ICS_ID_RE, HOST_NAME_RE, VLAN_NAME_RE, MAC_ADDRESS_RE
from . import utils
# Enable SQLAlchemy-Continuum versioning with the project FlaskUserPlugin.
# This call is executed at import time, ahead of every model definition
# in this module.
make_versioned(plugins=[FlaskUserPlugin()])
# See http://docs.sqlalchemy.org/en/latest/core/compiler.html#utc-timestamp-function
class utcnow(sa.sql.expression.FunctionElement):
    """SQL function element used as default value of the created_at/updated_at columns

    Its rendering for the postgresql dialect is registered by pg_utcnow.
    """
    # Type of the value produced by the function
    type = sa.types.DateTime()
@sa.ext.compiler.compiles(utcnow, 'postgresql')
def pg_utcnow(element, compiler, **kw):
    """Compile the utcnow element for the postgresql dialect

    The arguments are not used: the same SQL expression is always returned.

    :returns: the SQL expression giving the current timestamp in UTC
    """
    return "TIMEZONE('utc', CURRENT_TIMESTAMP)"
def temporary_ics_ids():
    """Generator that yields every possible temporary ICS id

    An id is made of the TEMPORARY_ICS_ID prefix from the application
    config, followed by one uppercase ASCII letter and a number from
    000 to 999 (zero padded to three digits).
    """
    for letter in string.ascii_uppercase:
        for number in range(0, 1000):
            yield f'{current_app.config["TEMPORARY_ICS_ID"]}{letter}{number:0=3d}'
def used_temporary_ics_ids():
    """Return the set of temporary ICS ids already assigned to an item

    Only the items whose ICS id starts with the TEMPORARY_ICS_ID prefix
    from the application config are retrieved.
    """
    prefix = current_app.config['TEMPORARY_ICS_ID']
    query = Item.query.filter(Item.ics_id.startswith(prefix))
    return {item.ics_id for item in query.all()}
def get_temporary_ics_id():
    """Return the first temporary ICS id that is not in use

    :returns: an available temporary ICS id
    :raises: ValueError if all the temporary ICS ids are in use
    """
    taken = used_temporary_ics_ids()
    for candidate in temporary_ics_ids():
        if candidate in taken:
            continue
        return candidate
    # All the candidates were rejected
    raise ValueError('No temporary ICS id available')
@login_manager.user_loader
@cache.memoize(timeout=1800)
def load_user(user_id):
    """User loader callback for flask-login

    The result is memoized for 30 minutes (1800 seconds).

    :param str user_id: unicode ID of a user
    :returns: corresponding user object or None (also when the ID
              can't be converted to an integer)
    """
    try:
        primary_key = int(user_id)
    except (TypeError, ValueError):
        # flask-login expects None, not an exception, when the ID is not valid
        return None
    return User.query.get(primary_key)
@ldap_manager.save_user
def save_user(dn, username, data, memberships):
    """User saver for flask-ldap3-login

    This method is called whenever a LDAPLoginForm()
    successfully validates.
    The user is created if no user with this username exists yet.

    :param dn: not used by this function
    :param username: username used to retrieve or create the user
    :param data: mapping providing the 'cn' and 'mail' attributes of a new user
    :param memberships: iterable of mappings with a 'cn' entry (one per group)
    :returns: the saved user
    """
    user = User.query.filter_by(username=username).first()
    if user is None:
        display_name = utils.attribute_to_string(data['cn'])
        email = utils.attribute_to_string(data['mail'])
        user = User(username=username, display_name=display_name, email=email)
    # Always update the user groups to keep them up-to-date
    group_names = [utils.attribute_to_string(group['cn']) for group in memberships]
    group_names.sort()
    user.groups = group_names
    session = db.session
    session.add(user)
    session.commit()
    return user
class User(db.Model, UserMixin):
    """Application user with the list of LDAP groups it belongs to"""
    # "user" is a reserved word in postgresql
    # so let's use another name
    __tablename__ = 'user_account'

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.Text, nullable=False, unique=True)
    display_name = db.Column(db.Text, nullable=False)
    email = db.Column(db.Text)
    groups = db.Column(postgresql.ARRAY(db.Text), default=[])

    tokens = db.relationship("Token", backref="user")

    def get_id(self):
        """Return the user id as unicode

        Required by flask-login
        """
        return str(self.id)

    @property
    def csentry_groups(self):
        """Return the CSEntry groups the user belongs to

        CSENTRY_LDAP_GROUPS (from the application config) maps each CSEntry
        group to a list of LDAP groups. A CSEntry group is returned (once)
        when the user is member of at least one of its LDAP groups.
        """
        groups = []
        for key, values in current_app.config['CSENTRY_LDAP_GROUPS'].items():
            # any() stops at the first match so that a key is never listed twice
            if any(value in self.groups for value in values):
                groups.append(key)
        return groups

    @property
    def is_admin(self):
        """Return True if the user is member of one of the admin LDAP groups"""
        admin_groups = current_app.config['CSENTRY_LDAP_GROUPS']['admin']
        return any(group in self.groups for group in admin_groups)

    def is_member_of_one_group(self, groups):
        """Return True if the user is at least member of one of the given groups

        :param groups: iterable of CSEntry group names (keys of CSENTRY_LDAP_GROUPS)
        :returns: True if the user belongs to one of the matching LDAP groups
        """
        names = []
        for group in groups:
            # A group missing from the config doesn't match any LDAP group
            names.extend(current_app.config['CSENTRY_LDAP_GROUPS'].get(group, []))
        return bool(set(self.groups) & set(names))

    def __str__(self):
        return self.username

    def to_dict(self):
        """Return the user as a dictionary (groups are the CSEntry groups)"""
        return {
            'id': self.id,
            'username': self.username,
            'display_name': self.display_name,
            'email': self.email,
            'groups': self.csentry_groups,
        }
class Token(db.Model):
    """Table to store valid tokens"""
    id = db.Column(db.Integer, primary_key=True)
    # Identifier of the token (stored as a postgresql UUID)
    jti = db.Column(postgresql.UUID, nullable=False)
    token_type = db.Column(db.Text, nullable=False)
    # User the token belongs to (available as token.user via the User.tokens backref)
    user_id = db.Column(db.Integer, db.ForeignKey('user_account.id'), nullable=False)
    issued_at = db.Column(db.DateTime, nullable=False)
    # expires can be set to None for tokens that never expire
    expires = db.Column(db.DateTime)
    description = db.Column(db.Text)

    __table_args__ = (
        # A jti can only be stored once for a given user
        sa.UniqueConstraint(jti, user_id),
    )

    def __str__(self):
        """Return the jti of the token"""
        return self.jti
class QRCodeMixin:
    """Mixin for records with a unique name that can be identified by a QRCode"""
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)
    description = db.Column(db.Text)

    def image(self):
        """Return a QRCode image to identify a record

        The QRCode encodes, separated by a colon:
            - the CSE prefix
            - the table name
            - the name of the record
        """
        payload = ':'.join(('CSE', self.__tablename__, self.name))
        return qrcode.make(payload, version=1, box_size=5)

    @cache.memoize(timeout=0)
    def base64_image(self):
        """Return the QRCode image as base64 string (the result is memoized)"""
        qr_image = self.image()
        return utils.image_to_base64(qr_image)

    def __str__(self):
        return self.name

    def __repr__(self):
        # cache.memoize calls repr() on the arguments it receives to build the cache key:
        # the returned string shall uniquely identify the record
        # See https://flask-caching.readthedocs.io/en/latest/#memoization
        class_name = self.__class__.__name__
        return f'{class_name}(id={self.id}, name={self.name})'

    def to_dict(self):
        """Return the record as a dictionary, including its QRCode as base64 string"""
        result = {
            'id': self.id,
            'name': self.name,
            'description': self.description,
            'qrcode': self.base64_image(),
        }
        return result
class Action(QRCodeMixin, db.Model):
    """Action record

    No column or method of its own: everything is inherited
    from QRCodeMixin and db.Model.
    """
    pass
class Manufacturer(QRCodeMixin, db.Model):
    """Manufacturer that items can be linked to"""
    # Items linked to this manufacturer (other side of Item.manufacturer)
    items = db.relationship('Item', back_populates='manufacturer')
class Model(QRCodeMixin, db.Model):
    """Model that items can be linked to"""
    # Items linked to this model (other side of Item.model)
    items = db.relationship('Item', back_populates='model')
class Location(QRCodeMixin, db.Model):
    """Location that items can be linked to"""
    # Items linked to this location (other side of Item.location)
    items = db.relationship('Item', back_populates='location')
class Status(QRCodeMixin, db.Model):
    """Status that items can be linked to"""
    # Items linked to this status (other side of Item.status)
    items = db.relationship('Item', back_populates='status')
class CreatedMixin:
    """Mixin adding the created_at/updated_at timestamps and the associated user"""
    id = db.Column(db.Integer, primary_key=True)
    created_at = db.Column(db.DateTime, default=utcnow())
    updated_at = db.Column(db.DateTime, default=utcnow(), onupdate=utcnow())

    # ForeignKey and relationship can only be used in a mixin
    # via the @declared_attr decorator
    # See http://docs.sqlalchemy.org/en/latest/orm/extensions/declarative/mixins.html
    @declared_attr
    def user_id(cls):
        foreign_key = db.ForeignKey('user_account.id')
        return db.Column(db.Integer, foreign_key,
                         nullable=False, default=utils.fetch_current_user_id)

    @declared_attr
    def user(cls):
        return db.relationship('User')

    def __init__(self, **kwargs):
        # created_at/updated_at passed as string are
        # automatically converted to datetime object
        for key in ('created_at', 'updated_at'):
            value = kwargs.get(key)
            if isinstance(value, str):
                kwargs[key] = utils.parse_to_utc(value)
        super().__init__(**kwargs)

    def to_dict(self):
        """Return the fields defined by the mixin as a dictionary"""
        result = {'id': self.id}
        result['created_at'] = utils.format_field(self.created_at)
        result['updated_at'] = utils.format_field(self.updated_at)
        result['user'] = str(self.user)
        return result
class Item(CreatedMixin, db.Model):
    """Item identified by its ICS id (versioned with SQLAlchemy-Continuum)"""
    # Columns listed under 'exclude' are not versioned
    __versioned__ = {
        'exclude': ['created_at', 'user_id', 'ics_id', 'serial_number',
                    'manufacturer_id', 'model_id']
    }

    # WARNING! Inheriting id from CreatedMixin doesn't play well with
    # SQLAlchemy-Continuum. It has to be defined here.
    id = db.Column(db.Integer, primary_key=True)
    ics_id = db.Column(db.Text, unique=True, nullable=False,
                       index=True, default=get_temporary_ics_id)
    serial_number = db.Column(db.Text, nullable=False)
    quantity = db.Column(db.Integer, nullable=False, default=1)
    manufacturer_id = db.Column(db.Integer, db.ForeignKey('manufacturer.id'))
    model_id = db.Column(db.Integer, db.ForeignKey('model.id'))
    location_id = db.Column(db.Integer, db.ForeignKey('location.id'))
    status_id = db.Column(db.Integer, db.ForeignKey('status.id'))
    parent_id = db.Column(db.Integer, db.ForeignKey('item.id'))

    manufacturer = db.relationship('Manufacturer', back_populates='items')
    model = db.relationship('Model', back_populates='items')
    location = db.relationship('Location', back_populates='items')
    status = db.relationship('Status', back_populates='items')
    children = db.relationship('Item', backref=db.backref('parent', remote_side=[id]))
    macs = db.relationship('Mac', backref='item')
    host = db.relationship('Host', uselist=False, backref='item')
    comments = db.relationship('ItemComment', backref='item')

    def __init__(self, **kwargs):
        # manufacturer/model/location/status are passed through
        # utils.convert_to_model so that they can be given as a string
        for key, cls in (('manufacturer', Manufacturer),
                         ('model', Model),
                         ('location', Location),
                         ('status', Status)):
            try:
                value = kwargs[key]
            except KeyError:
                continue
            kwargs[key] = utils.convert_to_model(value, cls)
        super().__init__(**kwargs)

    def __str__(self):
        return str(self.ics_id)

    @validates('ics_id')
    def validate_ics_id(self, key, string):
        """Ensure the ICS id field matches the required format

        None is accepted and returned as is.

        :raises: ValidationError if the ICS id doesn't match ICS_ID_RE
        """
        if string is not None and ICS_ID_RE.fullmatch(string) is None:
            raise ValidationError('ICS id shall match [A-Z]{3}[0-9]{3}')
        return string

    def to_dict(self):
        """Return the item as a dictionary, including its history"""
        return {
            **super().to_dict(),
            'ics_id': self.ics_id,
            'serial_number': self.serial_number,
            'quantity': self.quantity,
            'manufacturer': utils.format_field(self.manufacturer),
            'model': utils.format_field(self.model),
            'location': utils.format_field(self.location),
            'status': utils.format_field(self.status),
            'parent': utils.format_field(self.parent),
            'children': [str(child) for child in self.children],
            'macs': [str(mac) for mac in self.macs],
            'history': self.history(),
            'comments': [str(comment) for comment in self.comments],
        }

    def history(self):
        """Return the list of versions of the item (one dictionary per version)"""
        versions = []
        for version in self.versions:
            # parent is an attribute used by SQLAlchemy-Continuum
            # version.parent refers to an ItemVersion instance (and has no link with
            # the item parent_id)
            # We need to retrieve the parent "manually"
            parent = None
            if version.parent_id is not None:
                parent = Item.query.get(version.parent_id)
            entry = {
                'updated_at': utils.format_field(version.updated_at),
                'quantity': version.quantity,
                'location': utils.format_field(version.location),
                'status': utils.format_field(version.status),
                'parent': utils.format_field(parent),
            }
            versions.append(entry)
        return versions
class ItemComment(CreatedMixin, db.Model):
    """Comment attached to an item"""
    body = db.Column(db.Text, nullable=False)
    item_id = db.Column(db.Integer, db.ForeignKey('item.id'), nullable=False)

    def __str__(self):
        return self.body

    def to_dict(self):
        """Return the comment as a dictionary"""
        return {
            **super().to_dict(),
            'body': self.body,
            'item': str(self.item),
        }
class Network(CreatedMixin, db.Model):
    """Network (vlan) with the range of IP addresses that can be assigned to interfaces"""
    vlan_name = db.Column(CIText, nullable=False, unique=True)
    vlan_id = db.Column(db.Integer, nullable=False, unique=True)
    address = db.Column(postgresql.CIDR, nullable=False, unique=True)
    first_ip = db.Column(postgresql.INET, nullable=False, unique=True)
    last_ip = db.Column(postgresql.INET, nullable=False, unique=True)
    description = db.Column(db.Text)
    admin_only = db.Column(db.Boolean, nullable=False, default=False)
    scope_id = db.Column(db.Integer, db.ForeignKey('network_scope.id'), nullable=False)
    domain_id = db.Column(db.Integer, db.ForeignKey('domain.id'), nullable=False)

    interfaces = db.relationship('Interface', backref='network')

    __table_args__ = (
        sa.CheckConstraint('first_ip < last_ip', name='first_ip_less_than_last_ip'),
        sa.CheckConstraint('first_ip << address', name='first_ip_in_network'),
        sa.CheckConstraint('last_ip << address', name='last_ip_in_network'),
    )

    def __init__(self, **kwargs):
        # Automatically convert scope to an instance of NetworkScope if it was passed
        # as a string
        if 'scope' in kwargs:
            kwargs['scope'] = utils.convert_to_model(kwargs['scope'], NetworkScope, 'name')
        # If domain_id is not passed, we set it to the network scope value
        # NOTE: this requires scope to be passed
        if 'domain_id' not in kwargs:
            kwargs['domain_id'] = kwargs['scope'].domain_id
        super().__init__(**kwargs)

    def __str__(self):
        return str(self.vlan_name)

    @property
    def network_ip(self):
        """Network address as an ipaddress network object"""
        return ipaddress.ip_network(self.address)

    @property
    def netmask(self):
        """Netmask of the network"""
        return self.network_ip.netmask

    @property
    def first(self):
        """First IP of the range as an ipaddress address object"""
        return ipaddress.ip_address(self.first_ip)

    @property
    def last(self):
        """Last IP of the range as an ipaddress address object"""
        return ipaddress.ip_address(self.last_ip)

    def ip_range(self):
        """Return the list of IP addresses that can be assigned for this network

        The range is defined by the first and last IP
        """
        return [addr for addr in self.network_ip.hosts()
                if self.first <= addr <= self.last]

    def used_ips(self):
        """Return the list of IP addresses in use

        The list is sorted
        """
        return sorted(interface.address for interface in self.interfaces)

    def available_ips(self):
        """Return the list of IP addresses available

        The addresses are returned in the order of the IP range.
        """
        # The used addresses are retrieved only once (and not for every
        # address of the range) and stored in a set for fast lookup
        used = set(self.used_ips())
        return [addr for addr in self.ip_range() if addr not in used]

    def gateway(self):
        """Return the network gateway

        :returns: the first interface with the 'gateway' tag or None
        """
        for interface in self.interfaces:
            if 'gateway' in [tag.name for tag in interface.tags]:
                return interface
        return None

    @staticmethod
    def ip_in_network(ip, address):
        """Ensure the IP is in the network

        :param ip: IP address to check
        :param address: address of the network
        :returns: a tuple with the IP and network as (IPv4Address, IPv4Network)
        :raises: ValidationError if the IP is not in the network
        """
        addr = ipaddress.ip_address(ip)
        net = ipaddress.ip_network(address)
        if addr not in net:
            raise ValidationError(f'IP address {ip} is not in network {address}')
        return (addr, net)

    @validates('first_ip')
    def validate_first_ip(self, key, ip):
        """Ensure the first IP is in the network

        The validation reads self.address.
        """
        self.ip_in_network(ip, self.address)
        return ip

    @validates('last_ip')
    def validate_last_ip(self, key, ip):
        """Ensure the last IP is in the network and greater than first_ip

        The validation reads self.address and self.first_ip.
        """
        addr, _ = self.ip_in_network(ip, self.address)
        if addr < self.first:
            raise ValidationError(f'Last IP address {ip} is less than the first address {self.first}')
        return ip

    @validates('interfaces')
    def validate_interfaces(self, key, interface):
        """Ensure the interface IP is in the network range"""
        addr, _ = self.ip_in_network(interface.ip, self.address)
        # Admin user can create IP outside the defined range
        user = utils.cse_current_user()
        if user is None or not user.is_admin:
            if addr < self.first or addr > self.last:
                raise ValidationError(f'IP address {interface.ip} is not in range {self.first} - {self.last}')
        return interface

    @validates('vlan_name')
    def validate_vlan_name(self, key, string):
        """Ensure the name matches the required format"""
        if string is None:
            return None
        if VLAN_NAME_RE.fullmatch(string) is None:
            # Raw string: "\-" is not a valid escape sequence
            raise ValidationError(r'Vlan name shall match [A-Za-z0-9\-]{3,25}')
        return string

    def to_dict(self):
        """Return the network as a dictionary"""
        d = super().to_dict()
        d.update({
            'vlan_name': self.vlan_name,
            'vlan_id': self.vlan_id,
            'address': self.address,
            'netmask': str(self.netmask),
            'first_ip': self.first_ip,
            'last_ip': self.last_ip,
            'description': self.description,
            'admin_only': self.admin_only,
            'scope': utils.format_field(self.scope),
            'domain': str(self.domain),
            'interfaces': [str(interface) for interface in self.interfaces],
        })
        return d
# Table required for Many-to-Many relationships between interfaces and tags
# (used as "secondary" table by the Interface.tags relationship)
# Each row links one tag to one interface: both ids form the primary key
interfacetags_table = db.Table(
    'interfacetags',
    db.Column('tag_id', db.Integer, db.ForeignKey('tag.id'), primary_key=True),
    db.Column('interface_id', db.Integer, db.ForeignKey('interface.id'), primary_key=True)
)
class Tag(QRCodeMixin, db.Model):
    """Tag that can be associated to interfaces"""
    # Boolean flag (False by default)
    # NOTE(review): presumably restricts the tag to admin users - confirm against callers
    admin_only = db.Column(db.Boolean, nullable=False, default=False)
class MachineType(db.Model):
    """Type of machine that hosts refer to"""
    __tablename__ = 'machine_type'

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)

    hosts = db.relationship('Host', backref='machine_type')

    def __str__(self):
        return self.name

    def to_dict(self):
        """Return the machine type as a dictionary (hosts are given as strings)"""
        result = {'id': self.id, 'name': self.name}
        result['hosts'] = list(map(str, self.hosts))
        return result
class Host(CreatedMixin, db.Model):
    """Host of a given machine type, optionally linked to an item"""
    name = db.Column(db.Text, nullable=False, unique=True)
    description = db.Column(db.Text)
    machine_type_id = db.Column(db.Integer, db.ForeignKey('machine_type.id'), nullable=False)
    item_id = db.Column(db.Integer, db.ForeignKey('item.id'))

    interfaces = db.relationship('Interface', backref='host')

    def __init__(self, **kwargs):
        # Automatically convert machine_type as an instance of its class if passed as a string
        if 'machine_type' in kwargs:
            kwargs['machine_type'] = utils.convert_to_model(kwargs['machine_type'], MachineType)
        super().__init__(**kwargs)

    def __str__(self):
        return str(self.name)

    @validates('name')
    def validate_name(self, key, string):
        """Ensure the name matches the required format

        None is accepted and returned as is.

        :returns: the name converted to lowercase
        :raises: ValidationError if the name doesn't match HOST_NAME_RE
        """
        if string is None:
            return None
        # Force the string to lowercase
        lower_string = string.lower()
        if HOST_NAME_RE.fullmatch(lower_string) is None:
            # Raw string: "\-" is not a valid escape sequence
            raise ValidationError(r'Host name shall match [a-z0-9\-]{2,20}')
        return lower_string

    def to_dict(self):
        """Return the host as a dictionary"""
        d = super().to_dict()
        d.update({
            'name': self.name,
            'machine_type': str(self.machine_type),
            'description': self.description,
            'item': utils.format_field(self.item),
            'interfaces': [str(interface) for interface in self.interfaces],
        })
        return d
class Interface(CreatedMixin, db.Model):
    """Network interface with a unique IP and name in a network"""
    network_id = db.Column(db.Integer, db.ForeignKey('network.id'), nullable=False)
    ip = db.Column(postgresql.INET, nullable=False, unique=True)
    name = db.Column(db.Text, nullable=False, unique=True)
    mac_id = db.Column(db.Integer, db.ForeignKey('mac.id'))
    host_id = db.Column(db.Integer, db.ForeignKey('host.id'))

    # Add delete and delete-orphan options to automatically delete cnames when:
    # - deleting an interface
    # - de-associating a cname (removing it from the interface.cnames list)
    cnames = db.relationship('Cname', backref='interface',
                             cascade='all, delete, delete-orphan')
    tags = db.relationship('Tag', secondary=interfacetags_table, lazy='subquery',
                           backref=db.backref('interfaces', lazy=True))

    def __init__(self, **kwargs):
        # Always set self.network and not self.network_id to call validate_interfaces
        network_id = kwargs.pop('network_id', None)
        if network_id is not None:
            kwargs['network'] = Network.query.get(network_id)
        elif 'network' in kwargs:
            # Automatically convert network to an instance of Network if it was passed
            # as a string
            kwargs['network'] = utils.convert_to_model(kwargs['network'], Network, 'vlan_name')
        # WARNING! Setting self.network will call validate_interfaces in the Network class
        # For the validation to work, self.ip must be set before!
        # Ensure that ip is passed before network
        try:
            ip = kwargs.pop('ip')
        except KeyError:
            super().__init__(**kwargs)
        else:
            super().__init__(ip=ip, **kwargs)

    @validates('name')
    def validate_name(self, key, string):
        """Ensure the name matches the required format

        None is accepted and returned as is.

        :returns: the name converted to lowercase
        :raises: ValidationError if the name doesn't match HOST_NAME_RE
        """
        if string is None:
            return None
        # Force the string to lowercase
        lower_string = string.lower()
        if HOST_NAME_RE.fullmatch(lower_string) is None:
            # Raw string: "\-" is not a valid escape sequence
            raise ValidationError(r'Interface name shall match [a-z0-9\-]{2,20}')
        return lower_string

    @property
    def address(self):
        """IP of the interface as an ipaddress address object"""
        return ipaddress.ip_address(self.ip)

    def __str__(self):
        return str(self.name)

    def __repr__(self):
        return f'Interface(id={self.id}, network_id={self.network_id}, ip={self.ip}, name={self.name}, mac={self.mac})'

    def to_dict(self):
        """Return the interface as a dictionary (domain is the one of its network)"""
        d = super().to_dict()
        d.update({
            'network': str(self.network),
            'ip': self.ip,
            'name': self.name,
            'mac': utils.format_field(self.mac),
            'host': utils.format_field(self.host),
            'cnames': [str(cname) for cname in self.cnames],
            'domain': str(self.network.domain),
            'tags': [str(tag) for tag in self.tags],
        })
        return d
class Mac(db.Model):
    """MAC address, optionally linked to an item"""
    id = db.Column(db.Integer, primary_key=True)
    address = db.Column(postgresql.MACADDR, nullable=False, unique=True)
    item_id = db.Column(db.Integer, db.ForeignKey('item.id'))

    interfaces = db.relationship('Interface', backref='mac')

    def __str__(self):
        return str(self.address)

    @validates('address')
    def validate_address(self, key, string):
        """Ensure the address is a valid MAC address

        None is accepted and returned as is.

        :raises: ValidationError if the address doesn't match MAC_ADDRESS_RE
        """
        if string is not None and MAC_ADDRESS_RE.fullmatch(string) is None:
            raise ValidationError(f"'{string}' does not appear to be a MAC address")
        return string

    def to_dict(self):
        """Return the MAC address as a dictionary"""
        result = {'id': self.id, 'address': self.address}
        result['item'] = utils.format_field(self.item)
        result['interfaces'] = [str(interface) for interface in self.interfaces]
        return result
class Cname(CreatedMixin, db.Model):
    """Cname attached to an interface"""
    name = db.Column(db.Text, nullable=False, unique=True)
    interface_id = db.Column(db.Integer, db.ForeignKey('interface.id'), nullable=False)

    def __str__(self):
        return str(self.name)

    def to_dict(self):
        """Return the cname as a dictionary"""
        return {
            **super().to_dict(),
            'name': self.name,
            'interface': str(self.interface),
        }
class Domain(CreatedMixin, db.Model):
    """Domain that network scopes and networks belong to"""
    name = db.Column(db.Text, nullable=False, unique=True)

    scopes = db.relationship('NetworkScope', backref='domain')
    networks = db.relationship('Network', backref='domain')

    def __str__(self):
        return str(self.name)

    def to_dict(self):
        """Return the domain as a dictionary (scopes and networks are given as strings)"""
        return {
            **super().to_dict(),
            'name': self.name,
            'scopes': [str(scope) for scope in self.scopes],
            'networks': [str(network) for network in self.networks],
        }
class NetworkScope(CreatedMixin, db.Model):
    """Network scope defining the supernet and the range of vlans of its networks"""
    __tablename__ = 'network_scope'

    name = db.Column(CIText, nullable=False, unique=True)
    first_vlan = db.Column(db.Integer, nullable=False, unique=True)
    last_vlan = db.Column(db.Integer, nullable=False, unique=True)
    supernet = db.Column(postgresql.CIDR, nullable=False, unique=True)
    domain_id = db.Column(db.Integer, db.ForeignKey('domain.id'), nullable=False)
    description = db.Column(db.Text)

    networks = db.relationship('Network', backref='scope')

    __table_args__ = (
        sa.CheckConstraint('first_vlan < last_vlan', name='first_vlan_less_than_last_vlan'),
    )

    def __str__(self):
        return str(self.name)

    @property
    def supernet_ip(self):
        """Supernet as an ipaddress network object"""
        return ipaddress.ip_network(self.supernet)

    def prefix_range(self):
        """Return the list of subnet prefix that can be used for this network scope

        The prefixes go from the supernet prefix length + 1 up to 30.
        """
        return list(range(self.supernet_ip.prefixlen + 1, 31))

    def vlan_range(self):
        """Return the list of vlan ids that can be assigned for this network scope

        The range is defined by the first and last vlan
        """
        return range(self.first_vlan, self.last_vlan + 1)

    def used_vlans(self):
        """Return the list of vlan ids in use

        The list is sorted
        """
        return sorted(network.vlan_id for network in self.networks)

    def available_vlans(self):
        """Return the list of vlan ids available"""
        # The used vlans are retrieved only once (and not for every
        # vlan of the range) and stored in a set for fast lookup
        used = set(self.used_vlans())
        return [vlan for vlan in self.vlan_range() if vlan not in used]

    def used_subnets(self):
        """Return the list of subnets in use

        The list is sorted
        """
        return sorted(network.network_ip for network in self.networks)

    def available_subnets(self, prefix):
        """Return the list of available subnets with the given prefix

        Subnets overlapping with an existing network are filtered out,
        whatever the prefix length of that network.

        :param prefix: prefix length of the subnets to return
        :returns: list of subnets as strings
        """
        # The used subnets are retrieved only once (and not for every candidate)
        used = self.used_subnets()
        return [str(subnet) for subnet in self.supernet_ip.subnets(new_prefix=prefix)
                if not any(subnet.overlaps(network) for network in used)]

    def to_dict(self):
        """Return the network scope as a dictionary"""
        d = super().to_dict()
        d.update({
            'name': self.name,
            'first_vlan': self.first_vlan,
            'last_vlan': self.last_vlan,
            'supernet': self.supernet,
            'description': self.description,
            'domain': str(self.domain),
            'networks': [str(network) for network in self.networks],
        })
        return d
# call configure_mappers after defining all the models
# required by sqlalchemy_continuum
sa.orm.configure_mappers()
# Version class associated to the Item model (returned by sqlalchemy_continuum)
ItemVersion = version_class(Item)