# -*- coding: utf-8 -*-
"""
app.models
~~~~~~~~~~

This module implements the models used in the app.

:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.
"""
import datetime
import ipaddress
import string
import qrcode
import itertools
import urllib.parse
import elasticsearch
import sqlalchemy as sa
from enum import Enum
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import validates
from sqlalchemy_continuum import make_versioned, version_class
from citext import CIText
from flask import current_app
from flask_login import UserMixin, current_user
from wtforms import ValidationError
from rq import Queue
from .extensions import db, login_manager, ldap_manager, cache
from .plugins import FlaskUserPlugin
from .validators import (
    ICS_ID_RE,
    HOST_NAME_RE,
    VLAN_NAME_RE,
    MAC_ADDRESS_RE,
    DEVICE_TYPE_RE,
    TAG_RE,
)
from . import utils, search


make_versioned(plugins=[FlaskUserPlugin()])


# See http://docs.sqlalchemy.org/en/latest/core/compiler.html#utc-timestamp-function
class utcnow(sa.sql.expression.FunctionElement):
    type = sa.types.DateTime()


@sa.ext.compiler.compiles(utcnow, "postgresql")
def pg_utcnow(element, compiler, **kw):
    return "TIMEZONE('utc', CURRENT_TIMESTAMP)"


def temporary_ics_ids():
    """Generator that returns the full list of temporary ICS ids

    Ids are the TEMPORARY_ICS_ID prefix followed by an uppercase letter
    and a zero-padded 3-digit number (A000 ... Z999).
    """
    return (
        f'{current_app.config["TEMPORARY_ICS_ID"]}{letter}{number:0=3d}'
        for letter in string.ascii_uppercase
        for number in range(0, 1000)
    )


def used_temporary_ics_ids():
    """Return a set with the temporary ICS ids used"""
    # Only select the ics_id column: loading full Item instances would also
    # pull in all their eagerly joined relationships (lazy="joined") for nothing.
    rows = Item.query.with_entities(Item.ics_id).filter(
        Item.ics_id.startswith(current_app.config["TEMPORARY_ICS_ID"])
    )
    return {ics_id for (ics_id,) in rows}


def get_temporary_ics_id():
    """Return a temporary ICS id that is available

    :raises ValueError: if all temporary ICS ids are already used
    """
    used_temp_ics_ids = used_temporary_ics_ids()
    for ics_id in temporary_ics_ids():
        if ics_id not in used_temp_ics_ids:
            return ics_id
    else:
        raise ValueError("No temporary ICS id available")


@login_manager.user_loader
@cache.memoize(timeout=1800)
def load_user(user_id):
    """User loader callback for flask-login

    :param str user_id: unicode ID of a user
    :returns: corresponding user object or None
    """
    return User.query.get(int(user_id))


@ldap_manager.save_user
def save_user(dn, username, data, memberships):
    """User saver for flask-ldap3-login

    This method is called whenever a LDAPLoginForm()
    successfully validates.
    """
    user = User.query.filter_by(username=username).first()
    if user is None:
        user = User(
            username=username,
            display_name=utils.attribute_to_string(data["cn"]),
            email=utils.attribute_to_string(data["mail"]),
        )
    # Always update the user groups to keep them up-to-date
    user.groups = sorted(
        [utils.attribute_to_string(group["cn"]) for group in memberships]
    )
    db.session.add(user)
    db.session.commit()
    return user


# Tables required for Many-to-Many relationships between users and favorites attributes
favorite_manufacturers_table = db.Table(
    "favorite_manufacturers",
    db.Column(
        "user_id", db.Integer, db.ForeignKey("user_account.id"), primary_key=True
    ),
    db.Column(
        "manufacturer_id",
        db.Integer,
        db.ForeignKey("manufacturer.id"),
        primary_key=True,
    ),
)

favorite_models_table = db.Table(
    "favorite_models",
    db.Column(
        "user_id", db.Integer, db.ForeignKey("user_account.id"), primary_key=True
    ),
    db.Column("model_id", db.Integer, db.ForeignKey("model.id"), primary_key=True),
)

favorite_locations_table = db.Table(
    "favorite_locations",
    db.Column(
        "user_id", db.Integer, db.ForeignKey("user_account.id"), primary_key=True
    ),
    db.Column(
        "location_id", db.Integer, db.ForeignKey("location.id"), primary_key=True
    ),
)

favorite_statuses_table = db.Table(
    "favorite_statuses",
    db.Column(
        "user_id", db.Integer, db.ForeignKey("user_account.id"), primary_key=True
    ),
    db.Column("status_id", db.Integer, db.ForeignKey("status.id"), primary_key=True),
)

favorite_actions_table = db.Table(
    "favorite_actions",
    db.Column(
        "user_id", db.Integer, db.ForeignKey("user_account.id"), primary_key=True
    ),
    db.Column("action_id", db.Integer, db.ForeignKey("action.id"), primary_key=True),
)


class User(db.Model, UserMixin):
    # "user" is a reserved word in postgresql
    # so let's use another name
    __tablename__ = "user_account"

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.Text, nullable=False, unique=True)
    display_name = db.Column(db.Text, nullable=False)
    email = db.Column(db.Text)
    groups = db.Column(postgresql.ARRAY(db.Text), default=[])
    tokens = db.relationship("Token", backref="user")
    tasks = db.relationship("Task", backref="user")
    # The favorites won't be accessed very often so we load them
    # only when necessary (lazy=True)
    favorite_manufacturers = db.relationship(
        "Manufacturer",
        secondary=favorite_manufacturers_table,
        lazy=True,
        backref=db.backref("favorite_users", lazy=True),
    )
    favorite_models = db.relationship(
        "Model",
        secondary=favorite_models_table,
        lazy=True,
        backref=db.backref("favorite_users", lazy=True),
    )
    favorite_locations = db.relationship(
        "Location",
        secondary=favorite_locations_table,
        lazy=True,
        backref=db.backref("favorite_users", lazy=True),
    )
    favorite_statuses = db.relationship(
        "Status",
        secondary=favorite_statuses_table,
        lazy=True,
        backref=db.backref("favorite_users", lazy=True),
    )
    favorite_actions = db.relationship(
        "Action",
        secondary=favorite_actions_table,
        lazy=True,
        backref=db.backref("favorite_users", lazy=True),
    )

    def get_id(self):
        """Return the user id as unicode

        Required by flask-login
        """
        return str(self.id)

    @property
    def csentry_groups(self):
        """Return the list of CSEntry groups the user belong to

        Groups are assigned based on the CSENTRY_LDAP_GROUPS mapping with LDAP groups.
        The result is computed once and cached on the instance.
        """
        if not hasattr(self, "_csentry_groups"):
            self._csentry_groups = []
            for csentry_group, ldap_groups in current_app.config[
                "CSENTRY_LDAP_GROUPS"
            ].items():
                if set(self.groups) & set(ldap_groups):
                    self._csentry_groups.append(csentry_group)
            # Add the network group based on CSENTRY_DOMAINS_LDAP_GROUPS
            network_ldap_groups = set(
                itertools.chain(
                    *current_app.config["CSENTRY_DOMAINS_LDAP_GROUPS"].values()
                )
            )
            if set(self.groups) & network_ldap_groups:
                self._csentry_groups.append("network")
        return self._csentry_groups

    @property
    def csentry_domains(self):
        """Return the list of CSEntry domains the user has access to

        Domains are assigned based on the CSENTRY_DOMAINS_LDAP_GROUPS mapping with LDAP groups.
        The result is computed once and cached on the instance.
        """
        if not hasattr(self, "_csentry_domains"):
            self._csentry_domains = []
            for domain, ldap_groups in current_app.config[
                "CSENTRY_DOMAINS_LDAP_GROUPS"
            ].items():
                if set(self.groups) & set(ldap_groups):
                    self._csentry_domains.append(domain)
        return self._csentry_domains

    @property
    def is_admin(self):
        """Return True if the user belongs to the "admin" CSEntry group"""
        return "admin" in self.csentry_groups

    def is_member_of_one_group(self, groups):
        """Return True if the user is at least member of one of the given CSEntry groups"""
        return bool(set(groups) & set(self.csentry_groups))

    def has_access_to_network(self, network):
        """Return True if the user has access to the network

        - admin users have access to all networks
        - normal users must have access to the network domain
        - normal users don't have access to admin_only networks (whatever the domain)
        - LOGIN_DISABLED can be set to True to turn off authentication check when testing.
          In this case, this function always returns True.
        """
        if current_app.config.get("LOGIN_DISABLED") or self.is_admin:
            return True
        if network is None or network.admin_only:
            # True is already returned for admin users
            return False
        return str(network.domain) in self.csentry_domains

    def favorite_attributes(self):
        """Return all user's favorite attributes"""
        favorites_list = [
            self.favorite_manufacturers,
            self.favorite_models,
            self.favorite_locations,
            self.favorite_statuses,
            self.favorite_actions,
        ]
        return [favorite for favorites in favorites_list for favorite in favorites]

    def launch_task(self, name, func, *args, **kwargs):
        """Launch a task in the background using RQ

        The task is added to the session but not committed.

        :param str name: name of the task
        :param str func: name of the function in the app.tasks module to enqueue
        :returns: the created Task instance
        """
        q = Queue()
        job = q.enqueue(f"app.tasks.{func}", *args, **kwargs)
        # The status will be set to QUEUED or DEFERRED
        task = Task(
            id=job.id,
            name=name,
            command=job.get_call_string(),
            status=JobStatus(job.status),
            user=self,
        )
        db.session.add(task)
        return task

    def get_tasks(self, all=False):
        """Return all tasks created by the current user

        If the user is admin and all is set to True, will return all tasks
        """
        if all and self.is_admin:
            return Task.query.order_by(Task.created_at).all()
        return Task.query.filter_by(user=self).order_by(Task.created_at).all()

    def get_tasks_in_progress(self, name):
        """Return all the <name> tasks not finished or failed"""
        return (
            Task.query.filter_by(name=name)
            .filter(~Task.status.in_([JobStatus.FINISHED, JobStatus.FAILED]))
            .all()
        )

    def get_task_started(self, name):
        """Return the <name> task currently running or None"""
        return Task.query.filter_by(name=name, status=JobStatus.STARTED).first()

    def is_task_waiting(self, name):
        """Return True if a <name> task is waiting

        Waiting means:
            - queued
            - deferred if not older than 30 minutes

        A deferred task can stay deferred forever if the task it depends on fails.
        """
        thirty_minutes_ago = datetime.datetime.utcnow() - datetime.timedelta(minutes=30)
        count = (
            Task.query.filter_by(name=name)
            .filter(
                (Task.status == JobStatus.QUEUED)
                | (
                    (Task.status == JobStatus.DEFERRED)
                    & (Task.created_at > thirty_minutes_ago)
                )
            )
            .count()
        )
        return count > 0

    def __str__(self):
        return self.username

    def to_dict(self, recursive=False):
        return {
            "id": self.id,
            "username": self.username,
            "display_name": self.display_name,
            "email": self.email,
            "groups": self.csentry_groups,
        }


class SearchableMixin(object):
    """Add search capability to a class

    The before_flush / after_flush_postexec / after_commit classmethods are
    session event handlers that keep the elasticsearch index in sync with the
    database. They are presumably registered on the session elsewhere in the
    application.
    """

    @classmethod
    def search(cls, query, page=1, per_page=20, sort=None):
        """Search the class index

        :returns: a tuple (query returning the matching instances in the
                  elasticsearch order, total number of hits)
        """
        try:
            ids, total = search.query_index(
                cls.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"],
                query,
                page,
                per_page,
                sort,
            )
        except elasticsearch.ElasticsearchException as e:
            # Invalid query
            current_app.logger.warning(e)
            return cls.query.filter_by(id=0), 0
        if total == 0:
            return cls.query.filter_by(id=0), 0
        # Keep the order returned by elasticsearch
        when = [(value, i) for i, value in enumerate(ids)]
        return (
            cls.query.filter(cls.id.in_(ids)).order_by(db.case(when, value=cls.id)),
            total,
        )

    @classmethod
    def before_flush(cls, session, flush_context, instances):
        """Save the new/modified/deleted objects"""
        # The session.new / dirty / deleted lists are empty in the after_flush_postexec event.
        # We need to record them here
        # NOTE(review): this resets the changes recorded by a previous (auto)flush
        # in the same transaction, so objects only touched by an earlier flush may
        # never reach the index - confirm whether accumulation is needed.
        session._changes = {"add_obj": [], "delete": []}
        for obj in itertools.chain(session.new, session.dirty):
            if isinstance(obj, SearchableMixin):
                index = (
                    obj.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"]
                )
                current_app.logger.debug(
                    f"object to add/update in the {index} index: {obj}"
                )
                session._changes["add_obj"].append((index, obj))
        for obj in session.deleted:
            if isinstance(obj, SearchableMixin):
                index = (
                    obj.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"]
                )
                current_app.logger.debug(
                    f"object to remove from the {index} index: {obj}"
                )
                session._changes["delete"].append((index, obj.id))

    @classmethod
    def after_flush_postexec(cls, session, flush_context):
        """Retrieve the new and updated objects representation"""
        # _changes is only set once before_flush ran on this session
        changes = getattr(session, "_changes", None)
        if changes is None:
            return
        # - We can't call obj.to_dict() in the before_flush event because the id
        # hasn't been allocated yet (for new objects) and other fields haven't been updated
        # (default values like created_at/updated_at and some relationships).
        # - We can't call obj.to_dict() in the after_commit event because it would raise:
        # sqlalchemy.exc.InvalidRequestError:
        # This session is in 'committed' state; no further SQL can be emitted within this transaction.
        changes["add"] = [
            (index, obj.to_dict(recursive=True)) for index, obj in changes["add_obj"]
        ]

    @classmethod
    def after_commit(cls, session):
        """Update the elasticsearch index"""
        # A commit on a session that never flushed has no _changes attribute
        changes = getattr(session, "_changes", None)
        if changes is None:
            return
        for index, body in changes["add"]:
            search.add_to_index(index, body)
        for index, obj_id in changes["delete"]:
            search.remove_from_index(index, obj_id)
        session._changes = None

    @classmethod
    def delete_index(cls, **kwargs):
        """Delete the index of the class"""
        current_app.logger.info(f"Delete the {cls.__tablename__} index")
        search.delete_index(
            cls.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"],
            **kwargs,
        )

    @classmethod
    def create_index(cls, **kwargs):
        """Create the index of the class"""
        if hasattr(cls, "__mapping__"):
            current_app.logger.info(f"Create the {cls.__tablename__} index")
            search.create_index(
                cls.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"],
                cls.__mapping__,
                **kwargs,
            )
        else:
            current_app.logger.info(
                f"No mapping defined for {cls.__tablename__}. No index created."
            )

    @classmethod
    def reindex(cls):
        """Force to reindex all instances of the class"""
        current_app.logger.info(f"Force to re-index all {cls.__tablename__} instances")
        # Ignore index_not_found_exception
        cls.delete_index(ignore_unavailable=True)
        cls.create_index()
        for obj in cls.query:
            search.add_to_index(
                cls.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"],
                obj.to_dict(recursive=True),
            )


class Token(db.Model):
    """Table to store valid tokens"""

    id = db.Column(db.Integer, primary_key=True)
    jti = db.Column(postgresql.UUID, nullable=False)
    token_type = db.Column(db.Text, nullable=False)
    user_id = db.Column(db.Integer, db.ForeignKey("user_account.id"), nullable=False)
    issued_at = db.Column(db.DateTime, nullable=False)
    # expires can be set to None for tokens that never expire
    expires = db.Column(db.DateTime)
    description = db.Column(db.Text)

    __table_args__ = (sa.UniqueConstraint(jti, user_id),)

    def __str__(self):
        return self.jti


class QRCodeMixin:
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)
    description = db.Column(db.Text)

    def image(self):
        """Return a QRCode image to identify a record

        The QRCode includes:
            - CSE prefix
            - the table name
            - the name of the record
        """
        data = ":".join(["CSE", self.__tablename__, self.name])
        return qrcode.make(data, version=1, box_size=5)

    @cache.memoize(timeout=0)
    def base64_image(self):
        """Return the QRCode image as base64 string"""
        return utils.image_to_base64(self.image())

    def is_user_favorite(self):
        """Return True if the attribute is part of the current user favorites"""
        return current_user in self.favorite_users

    def __str__(self):
        return self.name

    def __repr__(self):
        # The cache.memoize decorator performs a repr() on the passed in arguments
        # __repr__ is used as part of the cache key and shall be a uniquely identifying string
        # See https://flask-caching.readthedocs.io/en/latest/#memoization
        return f"{self.__class__.__name__}(id={self.id}, name={self.name})"

    def to_dict(self, recursive=False):
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "qrcode": self.base64_image(),
        }


class Action(QRCodeMixin, db.Model):
    pass


class Manufacturer(QRCodeMixin, db.Model):
    items = db.relationship("Item", back_populates="manufacturer")


class Model(QRCodeMixin, db.Model):
    items = db.relationship("Item", back_populates="model")


class Location(QRCodeMixin, db.Model):
    items = db.relationship("Item", back_populates="location")


class Status(QRCodeMixin, db.Model):
    items = db.relationship("Item", back_populates="status")


class CreatedMixin:
    id = db.Column(db.Integer, primary_key=True)
    created_at = db.Column(db.DateTime, default=utcnow())
    updated_at = db.Column(db.DateTime, default=utcnow(), onupdate=utcnow())

    # Using ForeignKey and relationship in mixin requires the @declared_attr decorator
    # See http://docs.sqlalchemy.org/en/latest/orm/extensions/declarative/mixins.html
    @declared_attr
    def user_id(cls):
        return db.Column(
            db.Integer,
            db.ForeignKey("user_account.id"),
            nullable=False,
            default=utils.fetch_current_user_id,
        )

    @declared_attr
    def user(cls):
        return db.relationship("User")

    def __init__(self, **kwargs):
        # Automatically convert created_at/updated_at strings
        # to datetime object
        for key in ("created_at", "updated_at"):
            if key in kwargs:
                if isinstance(kwargs[key], str):
                    kwargs[key] = utils.parse_to_utc(kwargs[key])
        super().__init__(**kwargs)

    def to_dict(self):
        return {
            "id": self.id,
            "created_at": utils.format_field(self.created_at),
            "updated_at": utils.format_field(self.updated_at),
            "user": str(self.user),
        }


class Item(CreatedMixin, SearchableMixin, db.Model):
    __versioned__ = {
        "exclude": [
            "created_at",
            "user_id",
            "ics_id",
            "serial_number",
            "manufacturer_id",
            "model_id",
        ]
    }
    __mapping__ = {
        "created_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"},
        "updated_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"},
        "user": {"type": "keyword"},
        "ics_id": {"type": "keyword"},
        "serial_number": {"type": "keyword"},
        "quantity": {"type": "long"},
        "manufacturer": {"type": "keyword"},
        "model": {"type": "keyword"},
        "location": {"type": "keyword"},
        "status": {"type": "keyword"},
        "parent": {"type": "keyword"},
        "children": {"type": "keyword"},
        "macs": {"type": "keyword"},
        "host": {"type": "keyword"},
        "stack_member": {"type": "keyword"},
        "history": {"enabled": False},
        "comments": {"type": "text"},
    }

    # WARNING! Inheriting id from CreatedMixin doesn't play well with
    # SQLAlchemy-Continuum. It has to be defined here.
    id = db.Column(db.Integer, primary_key=True)
    ics_id = db.Column(
        db.Text, unique=True, nullable=False, index=True, default=get_temporary_ics_id
    )
    serial_number = db.Column(db.Text, nullable=False)
    quantity = db.Column(db.Integer, nullable=False, default=1)
    manufacturer_id = db.Column(db.Integer, db.ForeignKey("manufacturer.id"))
    model_id = db.Column(db.Integer, db.ForeignKey("model.id"))
    location_id = db.Column(db.Integer, db.ForeignKey("location.id"))
    status_id = db.Column(db.Integer, db.ForeignKey("status.id"))
    parent_id = db.Column(db.Integer, db.ForeignKey("item.id"))
    host_id = db.Column(db.Integer, db.ForeignKey("host.id"))
    stack_member = db.Column(db.SmallInteger)

    manufacturer = db.relationship(
        "Manufacturer", back_populates="items", lazy="joined"
    )
    model = db.relationship("Model", back_populates="items", lazy="joined")
    location = db.relationship("Location", back_populates="items", lazy="joined")
    status = db.relationship("Status", back_populates="items", lazy="joined")
    children = db.relationship("Item", backref=db.backref("parent", remote_side=[id]))
    macs = db.relationship("Mac", backref="item", lazy="joined")
    comments = db.relationship(
        "ItemComment", backref="item", cascade="all, delete-orphan", lazy="joined"
    )

    __table_args__ = (
        sa.CheckConstraint(
            "stack_member >= 0 AND stack_member <=9", name="stack_member_range"
        ),
        sa.UniqueConstraint(host_id, stack_member, name="uq_item_host_id_stack_member"),
    )

    def __init__(self, **kwargs):
        # Automatically convert manufacturer/model/location/status to an
        # instance of their class if passed as a string
        for key, cls in [
            ("manufacturer", Manufacturer),
            ("model", Model),
            ("location", Location),
            ("status", Status),
        ]:
            if key in kwargs:
                kwargs[key] = utils.convert_to_model(kwargs[key], cls)
        super().__init__(**kwargs)

    def __str__(self):
        return str(self.ics_id)

    @validates("ics_id")
    def validate_ics_id(self, key, string):
        """Ensure the ICS id field matches the required format"""
        if string is not None:
            if ICS_ID_RE.fullmatch(string) is None:
                raise ValidationError("ICS id shall match [A-Z]{3}[0-9]{3}")
        return string

    def to_dict(self, recursive=False):
        d = super().to_dict()
        d.update(
            {
                "ics_id": self.ics_id,
                "serial_number": self.serial_number,
                "quantity": self.quantity,
                "manufacturer": utils.format_field(self.manufacturer),
                "model": utils.format_field(self.model),
                "location": utils.format_field(self.location),
                "status": utils.format_field(self.status),
                "parent": utils.format_field(self.parent),
                "children": [str(child) for child in self.children],
                "macs": [str(mac) for mac in self.macs],
                "host": utils.format_field(self.host),
                "stack_member": utils.format_field(self.stack_member),
                "history": self.history(),
                "comments": [str(comment) for comment in self.comments],
            }
        )
        return d

    def to_row_dict(self):
        """Convert to a dict that can easily be exported to an excel row

        All values should be a string
        """
        # to_dict() builds a new dict on each call, so it can be modified in place
        d = self.to_dict()
        d["children"] = " ".join(d["children"])
        d["macs"] = " ".join(d["macs"])
        d["comments"] = "\n\n".join(d["comments"])
        d["history"] = "\n".join([str(version) for version in d["history"]])
        return d

    def history(self):
        """Return the list of versions of the item as dicts

        Each dict contains the updated_at, quantity, location, status and
        parent values of that version.
        """
        versions = []
        for version in self.versions:
            # parent is an attribute used by SQLAlchemy-Continuum
            # version.parent refers to an ItemVersion instance (and has no link with
            # the item parent_id)
            # We need to retrieve the parent "manually"
            if version.parent_id is None:
                parent = None
            else:
                parent = Item.query.get(version.parent_id)
            versions.append(
                {
                    "updated_at": utils.format_field(version.updated_at),
                    "quantity": version.quantity,
                    "location": utils.format_field(version.location),
                    "status": utils.format_field(version.status),
                    "parent": utils.format_field(parent),
                }
            )
        return versions


class ItemComment(CreatedMixin, db.Model):
    body = db.Column(db.Text, nullable=False)
    item_id = db.Column(db.Integer, db.ForeignKey("item.id"), nullable=False)

    def __str__(self):
        return self.body

    def to_dict(self, recursive=False):
        d = super().to_dict()
        d.update({"body": self.body, "item": str(self.item)})
        return d


class Network(CreatedMixin, db.Model):
    vlan_name = db.Column(CIText, nullable=False, unique=True)
    vlan_id = db.Column(db.Integer, nullable=False, unique=True)
    address = db.Column(postgresql.CIDR, nullable=False, unique=True)
    first_ip = db.Column(postgresql.INET, nullable=False, unique=True)
    last_ip = db.Column(postgresql.INET, nullable=False, unique=True)
    description = db.Column(db.Text)
    admin_only = db.Column(db.Boolean, nullable=False, default=False)
    scope_id = db.Column(db.Integer, db.ForeignKey("network_scope.id"), nullable=False)
    domain_id = db.Column(db.Integer, db.ForeignKey("domain.id"), nullable=False)

    interfaces = db.relationship(
        "Interface", backref=db.backref("network", lazy="joined"), lazy=True
    )

    __table_args__ = (
        sa.CheckConstraint("first_ip < last_ip", name="first_ip_less_than_last_ip"),
        sa.CheckConstraint("first_ip << address", name="first_ip_in_network"),
        sa.CheckConstraint("last_ip << address", name="last_ip_in_network"),
    )

    def __init__(self, **kwargs):
        # Automatically convert scope to an instance of NetworkScope if it was passed
        # as a string
        if "scope" in kwargs:
            kwargs["scope"] = utils.convert_to_model(
                kwargs["scope"], NetworkScope, "name"
            )
            # If domain_id is not passed, we set it to the network scope value
            if "domain_id" not in kwargs:
                kwargs["domain_id"] = kwargs["scope"].domain_id
        super().__init__(**kwargs)

    def __str__(self):
        return str(self.vlan_name)

    @property
    def network_ip(self):
        return ipaddress.ip_network(self.address)

    @property
    def netmask(self):
        return self.network_ip.netmask

    @property
    def first(self):
        return ipaddress.ip_address(self.first_ip)

    @property
    def last(self):
        return ipaddress.ip_address(self.last_ip)

    def ip_range(self):
        """Return the list of IP addresses that can be assigned for this network

        The range is defined by the first and last IP
        """
        return [
            addr for addr in self.network_ip.hosts() if self.first <= addr <= self.last
        ]

    def used_ips(self):
        """Return the list of IP addresses in use

        The list is sorted
        """
        return sorted(interface.address for interface in self.interfaces)

    def available_ips(self):
        """Return the list of IP addresses available"""
        # Build the lookup set once: calling used_ips() for every candidate
        # would re-sort all the interface addresses for each address in the range.
        used = set(self.used_ips())
        return [addr for addr in self.ip_range() if addr not in used]

    @property
    def gateway(self):
        """Return the network gateway IP

        The gateway is the last usable host address of the network.
        """
        return list(self.network_ip.hosts())[-1]

    @staticmethod
    def ip_in_network(ip, address):
        """Ensure the IP is in the network

        :param str ip: IP address to check
        :param str address: network address (CIDR notation)
        :returns: a tuple (ip_address, ip_network) of ipaddress objects
        :raises: ValidationError if the IP is not in the network
        """
        addr = ipaddress.ip_address(ip)
        net = ipaddress.ip_network(address)
        if addr not in net:
            raise ValidationError(f"IP address {ip} is not in network {address}")
        return (addr, net)

    @validates("first_ip")
    def validate_first_ip(self, key, ip):
        """Ensure the first IP is in the network"""
        self.ip_in_network(ip, self.address)
        return ip

    @validates("last_ip")
    def validate_last_ip(self, key, ip):
        """Ensure the last IP is in the network and greater than first_ip"""
        # NOTE(review): requires address and first_ip to be set before last_ip
        addr, net = self.ip_in_network(ip, self.address)
        if addr < self.first:
            raise ValidationError(
                f"Last IP address {ip} is less than the first address {self.first}"
            )
        return ip

    @validates("interfaces")
    def validate_interfaces(self, key, interface):
        """Ensure the interface IP is in the network range"""
        addr, net = self.ip_in_network(interface.ip, self.address)
        # Admin user can create IP outside the defined range
        try:
            # current_user is a local proxy and is not
            # valid outside of a request context.
            is_admin = current_user.is_admin
        except AttributeError:
            is_admin = False
        if not is_admin:
            if addr < self.first or addr > self.last:
                raise ValidationError(
                    f"IP address {interface.ip} is not in range {self.first} - {self.last}"
                )
        return interface

    @validates("vlan_name")
    def validate_vlan_name(self, key, string):
        """Ensure the name matches the required format"""
        if string is None:
            return None
        if VLAN_NAME_RE.fullmatch(string) is None:
            # Raw string: "\-" is not a valid escape sequence in a normal string
            raise ValidationError(r"Vlan name shall match [A-Za-z0-9\-]{3,25}")
        return string

    def to_dict(self, recursive=False):
        d = super().to_dict()
        d.update(
            {
                "vlan_name": self.vlan_name,
                "vlan_id": self.vlan_id,
                "address": self.address,
                "netmask": str(self.netmask),
                "first_ip": self.first_ip,
                "last_ip": self.last_ip,
                "gateway": str(self.gateway),
                "description": self.description,
                "admin_only": self.admin_only,
                "scope": utils.format_field(self.scope),
                "domain": str(self.domain),
                "interfaces": [str(interface) for interface in self.interfaces],
            }
        )
        return d


# Table required for Many-to-Many relationships between interfaces and tags
interfacetags_table = db.Table(
    "interfacetags",
    db.Column("tag_id", db.Integer, db.ForeignKey("tag.id"), primary_key=True),
    db.Column(
        "interface_id", db.Integer, db.ForeignKey("interface.id"), primary_key=True
    ),
)


class Tag(QRCodeMixin, db.Model):
    admin_only = db.Column(db.Boolean, nullable=False, default=False)

    @validates("name")
    def validate_name(self, key, string):
        """Ensure the name field matches the required format"""
        if string is not None:
            if TAG_RE.fullmatch(string) is None:
                raise ValidationError(f"'{string}' is an invalid tag name")
        return string


class DeviceType(db.Model):
    __tablename__ = "device_type"
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)

    hosts = db.relationship(
        "Host", backref=db.backref("device_type", lazy="joined"), lazy=True
    )

    @validates("name")
    def validate_name(self, key, string):
        """Ensure the name field matches the required format"""
        if string is not None:
            if DEVICE_TYPE_RE.fullmatch(string) is None:
                raise ValidationError(f"'{string}' is an invalid device type name")
        return string

    def __str__(self):
        return self.name

    def to_dict(self, recursive=False):
        return {
            "id": self.id,
            "name": self.name,
            "hosts": [str(host) for host in self.hosts],
        }


# Table required for Many-to-Many relationships between Ansible parent and child groups
ansible_groups_parent_child_table = db.Table(
    "ansible_groups_parent_child",
    db.Column(
        "parent_group_id",
        db.Integer,
        db.ForeignKey("ansible_group.id"),
        primary_key=True,
    ),
    db.Column(
        "child_group_id",
        db.Integer,
        db.ForeignKey("ansible_group.id"),
        primary_key=True,
    ),
)


# Table required for Many-to-Many relationships between Ansible groups and hosts
ansible_groups_hosts_table = db.Table(
    "ansible_groups_hosts",
    db.Column(
        "ansible_group_id",
        db.Integer,
        db.ForeignKey("ansible_group.id"),
        primary_key=True,
    ),
    db.Column("host_id", db.Integer, db.ForeignKey("host.id"), primary_key=True),
)


class AnsibleGroupType(Enum):
    STATIC = "STATIC"
    NETWORK_SCOPE = "NETWORK_SCOPE"
    NETWORK = "NETWORK"
    DEVICE_TYPE = "DEVICE_TYPE"

    def __str__(self):
        return self.name

    @classmethod
    def choices(cls):
        """Return the list of (member, name) tuples"""
        return [(item, item.name) for item in AnsibleGroupType]

    @classmethod
    def coerce(cls, value):
        """Return value as an AnsibleGroupType, looking it up by name if needed"""
        return value if isinstance(value, AnsibleGroupType) else AnsibleGroupType[value]


class AnsibleGroup(CreatedMixin, db.Model):
    __versioned__ = {}
    __tablename__ = "ansible_group"

    # Define id here so that it can be used in the primary and secondary join
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)
    vars = db.Column(postgresql.JSONB)
    type = db.Column(
        db.Enum(AnsibleGroupType, name="ansible_group_type"),
        default=AnsibleGroupType.STATIC,
        nullable=False,
    )

    children = db.relationship(
        "AnsibleGroup",
        secondary=ansible_groups_parent_child_table,
        primaryjoin=id == ansible_groups_parent_child_table.c.parent_group_id,
        secondaryjoin=id == ansible_groups_parent_child_table.c.child_group_id,
        backref=db.backref("parents"),
    )

    def __str__(self):
        return str(self.name)

    @property
    def is_dynamic(self):
        return self.type != AnsibleGroupType.STATIC

    @property
    def hosts(self):
        """Return the hosts of the group

        - STATIC groups return the explicitly linked hosts
        - dynamic groups query the hosts whose network scope, network (vlan name)
          or device type name matches the group name, ordered by host name
        """
        if self.type == AnsibleGroupType.STATIC:
            return self._hosts
        if self.type == AnsibleGroupType.NETWORK_SCOPE:
            return (
                Host.query.join(Host.interfaces)
                .join(Interface.network)
                .join(Network.scope)
                .filter(NetworkScope.name == self.name)
                .order_by(Host.name)
                .all()
            )
        if self.type == AnsibleGroupType.NETWORK:
            return (
                Host.query.join(Host.interfaces)
                .join(Interface.network)
                .filter(Network.vlan_name == self.name)
                .order_by(Host.name)
                .all()
            )
        if self.type == AnsibleGroupType.DEVICE_TYPE:
            return (
                Host.query.join(Host.device_type)
                .filter(DeviceType.name == self.name)
                .order_by(Host.name)
                .all()
            )

    @hosts.setter
    def hosts(self, value):
        # For dynamic group type, _hosts can only be set to []
        if self.is_dynamic and value:
            raise AttributeError("can't set dynamic hosts")
        self._hosts = value

    def to_dict(self, recursive=False):
        d = super().to_dict()
        d.update(
            {
                "name": self.name,
                "vars": self.vars,
                "type": self.type.name,
                "hosts": [host.fqdn for host in self.hosts],
                "children": [str(child) for child in self.children],
            }
        )
        return d


class Host(CreatedMixin, SearchableMixin, db.Model):
    __versioned__ = {}
    __mapping__ = {
        "created_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"},
        "updated_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"},
        "user": {"type": "keyword"},
        "name": {"type": "keyword"},
        "fqdn": {"type": "keyword"},
        "is_ioc": {"type": "boolean"},
        "device_type": {"type": "keyword"},
        "model": {"type": "keyword"},
        "description": {"type": "text"},
        "items": {"type": "keyword"},
        "interfaces": {
            "properties": {
                "id": {"enabled": False},
                "created_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"},
                "updated_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"},
                "user": {"type": "keyword"},
                "is_main": {"type": "boolean"},
                "network": {"type": "keyword"},
                "ip": {"type": "ip"},
                "netmask": {"enabled": False},
                "name": {"type": "keyword"},
                "mac": {"type": "keyword"},
                "host": {"type": "keyword"},
                "cnames": {"type": "keyword"},
                "domain": {"type": "keyword"},
                "tags": {"type": "keyword"},
                "device_type": {"type": "keyword"},
                "model": {"type": "keyword"},
            }
        },
        "ansible_vars": {"enabled": False},
        "ansible_groups": {"type": "keyword"},
    }

    # id shall be defined here to be used by SQLAlchemy-Continuum
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.Text, nullable=False, unique=True)
    description = db.Column(db.Text)
    device_type_id = db.Column(
        db.Integer, db.ForeignKey("device_type.id"), nullable=False
    )
    ansible_vars = db.Column(postgresql.JSONB)

    # Set cascade to all (to add delete) and delete-orphan to delete all interfaces
    # when deleting a host
    interfaces = db.relationship(
        "Interface",
        backref=db.backref("host", lazy="joined"),
        cascade="all, delete-orphan",
        lazy="joined",
    )
    items = db.relationship(
        "Item", backref=db.backref("host", lazy="joined"), lazy="joined"
    )
    ansible_groups = db.relationship(
        "AnsibleGroup",
        secondary=ansible_groups_hosts_table,
        lazy="joined",
        backref=db.backref("_hosts"),
    )

    def __init__(self, **kwargs):
        # Automatically convert device_type as an instance of its class if passed as a string
        if "device_type" in kwargs:
            kwargs["device_type"] = utils.convert_to_model(
                kwargs["device_type"], DeviceType
            )
        # Automatically convert items to a list of instances if passed as a list of ics_id
        if "items" in kwargs:
            kwargs["items"] = [
                utils.convert_to_model(item, Item, filter="ics_id")
                for item in kwargs["items"]
            ]
        # Automatically convert ansible groups to a list of instances if passed as a list of strings
        if "ansible_groups" in kwargs:
            kwargs["ansible_groups"] = [
                utils.convert_to_model(group, AnsibleGroup)
                for group in kwargs["ansible_groups"]
            ]
        super().__init__(**kwargs)

    @property
    def is_ioc(self):
        for interface in self.interfaces:
            if interface.is_ioc:
                return True
        return False

    @property
    def model(self):
        """Return the model of the first linked item"""
        try:
            return utils.format_field(self.items[0].model)
        except IndexError:
            return None

    @property
    def main_interface(self):
        """Return the host main interface

        The main interface is the one that has the same name as the host
        or the first one found
        """
        for interface in self.interfaces:
            if interface.name == self.name:
                return interface
        # No interface with the same name found...
        # Return the first one
        try:
            return self.interfaces[0]
        except IndexError:
            return None

    @property
    def main_network(self):
        """Return the host main interface network"""
        try:
            return self.main_interface.network
        except AttributeError:
            return None

    @property
    def fqdn(self):
        """Return the host fully qualified domain name

        The domain is based on the main interface
        """
        if self.main_interface:
            return f"{self.name}.{self.main_interface.network.domain}"
        else:
            return self.name

    def __str__(self):
        return str(self.name)

    @validates("name")
    def validate_name(self, key, string):
        """Ensure the name matches the required format"""
        if string is None:
            return None
        # Force the string to lowercase
        lower_string = string.lower()
        if HOST_NAME_RE.fullmatch(lower_string) is None:
            # Raw string: "\-" is not a valid escape sequence in a normal string
            raise ValidationError(r"Host name shall match [a-z0-9\-]{2,20}")
        existing_cname = Cname.query.filter_by(name=lower_string).first()
        if existing_cname:
            raise ValidationError("Host name matches an existing cname")
        existing_interface = Interface.query.filter(
            Interface.name == lower_string, Interface.host_id != self.id
        ).first()
        if existing_interface:
            raise ValidationError("Host name matches an existing interface")
        return lower_string

    def stack_members(self):
        """Return all items part of the stack sorted by stack member number"""
        members = [item for item in self.items if item.stack_member is not None]
        return sorted(members, key=lambda x: x.stack_member)

    def stack_members_numbers(self):
        """Return the list of stack member numbers"""
return [item.stack_member for item in self.stack_members()] def free_stack_members(self): """Return the list of free stack member numbers""" return [nb for nb in range(0, 10) if nb not in self.stack_members_numbers()] def to_dict(self, recursive=False): d = super().to_dict() d.update( { "name": self.name, "fqdn": self.fqdn, "is_ioc": self.is_ioc, "device_type": str(self.device_type), "model": self.model, "description": self.description, "items": [str(item) for item in self.items], "interfaces": [str(interface) for interface in self.interfaces], "ansible_vars": self.ansible_vars, "ansible_groups": [str(group) for group in self.ansible_groups], } ) if recursive: # Replace the list of interface names by the full representation # so that we can index everything in elasticsearch d["interfaces"] = [interface.to_dict() for interface in self.interfaces] return d class Interface(CreatedMixin, db.Model): network_id = db.Column(db.Integer, db.ForeignKey("network.id"), nullable=False) ip = db.Column(postgresql.INET, nullable=False, unique=True) name = db.Column(db.Text, nullable=False, unique=True) mac = db.Column(postgresql.MACADDR, nullable=True, unique=True) host_id = db.Column(db.Integer, db.ForeignKey("host.id"), nullable=False) # Add delete and delete-orphan options to automatically delete cnames when: # - deleting an interface # - de-associating a cname (removing it from the interface.cnames list) cnames = db.relationship( "Cname", backref=db.backref("interface", lazy="joined"), cascade="all, delete, delete-orphan", lazy="joined", ) tags = db.relationship( "Tag", secondary=interfacetags_table, lazy="subquery", backref=db.backref("interfaces", lazy=True), ) def __init__(self, **kwargs): # Always set self.host and not self.host_id to call validate_name host_id = kwargs.pop("host_id", None) if host_id is not None: host = Host.query.get(host_id) elif "host" in kwargs: # Automatically convert host to an instance of Host if it was passed # as a string host = 
utils.convert_to_model(kwargs.pop("host"), Host, "name") else: host = None # Always set self.network and not self.network_id to call validate_interfaces network_id = kwargs.pop("network_id", None) if network_id is not None: kwargs["network"] = Network.query.get(network_id) elif "network" in kwargs: # Automatically convert network to an instance of Network if it was passed # as a string kwargs["network"] = utils.convert_to_model( kwargs["network"], Network, "vlan_name" ) # WARNING! Setting self.network will call validate_interfaces in the Network class # For the validation to work, self.ip must be set before! # Ensure that ip is passed before network try: ip = kwargs.pop("ip") except KeyError: super().__init__(host=host, **kwargs) else: super().__init__(host=host, ip=ip, **kwargs) @validates("name") def validate_name(self, key, string): """Ensure the name matches the required format""" if string is None: return None # Force the string to lowercase lower_string = string.lower() if HOST_NAME_RE.fullmatch(lower_string) is None: raise ValidationError("Interface name shall match [a-z0-9\-]{2,20}") if self.host and not lower_string.startswith(self.host.name): raise ValidationError( f"Interface name shall start with the host name '{self.host}'" ) existing_cname = Cname.query.filter_by(name=lower_string).first() if existing_cname: raise ValidationError(f"Interface name matches an existing cname") existing_host = Host.query.filter( Host.name == lower_string, Host.id != self.host.id ).first() if existing_host: raise ValidationError(f"Interface name matches an existing host") return lower_string @validates("mac") def validate_mac(self, key, string): """Ensure the mac is a valid MAC address""" if not string: return None if MAC_ADDRESS_RE.fullmatch(string) is None: raise ValidationError(f"'{string}' does not appear to be a MAC address") return string @validates("cnames") def validate_cnames(self, key, cname): """Ensure the cname is unique by domain""" existing_cnames = 
Cname.query.filter_by(name=cname.name).all() for existing_cname in existing_cnames: if existing_cname.domain == str(self.network.domain): raise ValidationError( f"Duplicate cname on the {self.network.domain} domain" ) return cname @property def address(self): return ipaddress.ip_address(self.ip) @property def is_ioc(self): for tag in self.tags: if tag.name == "IOC": return True return False @property def is_main(self): return self.name == self.host.main_interface.name def __str__(self): return str(self.name) def __repr__(self): return f"Interface(id={self.id}, network_id={self.network_id}, ip={self.ip}, name={self.name}, mac={self.mac})" def to_dict(self, recursive=False): d = super().to_dict() d.update( { "is_main": self.is_main, "network": str(self.network), "ip": self.ip, "netmask": str(self.network.netmask), "name": self.name, "mac": utils.format_field(self.mac), "host": utils.format_field(self.host), "cnames": [str(cname) for cname in self.cnames], "domain": str(self.network.domain), "tags": [str(tag) for tag in self.tags], } ) if self.host: d["device_type"] = str(self.host.device_type) d["model"] = utils.format_field(self.host.model) else: d["device_type"] = None d["model"] = None return d class Mac(db.Model): id = db.Column(db.Integer, primary_key=True) address = db.Column(postgresql.MACADDR, nullable=False, unique=True) item_id = db.Column(db.Integer, db.ForeignKey("item.id")) def __str__(self): return str(self.address) @validates("address") def validate_address(self, key, string): """Ensure the address is a valid MAC address""" if string is None: return None if MAC_ADDRESS_RE.fullmatch(string) is None: raise ValidationError(f"'{string}' does not appear to be a MAC address") return string def to_dict(self, recursive=False): return { "id": self.id, "address": self.address, "item": utils.format_field(self.item), } class Cname(CreatedMixin, db.Model): name = db.Column(db.Text, nullable=False) interface_id = db.Column(db.Integer, db.ForeignKey("interface.id"), 
nullable=False) def __init__(self, **kwargs): # Always set self.interface and not self.interface_id to call validate_cnames interface_id = kwargs.pop("interface_id", None) if interface_id is not None: kwargs["interface"] = Interface.query.get(interface_id) super().__init__(**kwargs) def __str__(self): return str(self.name) @property def domain(self): """Return the cname domain name""" return str(self.interface.network.domain) @property def fqdn(self): """Return the cname fully qualified domain name""" return f"{self.name}.{self.domain}" @validates("name") def validate_name(self, key, string): """Ensure the name matches the required format""" if string is None: return None # Force the string to lowercase lower_string = string.lower() if HOST_NAME_RE.fullmatch(lower_string) is None: raise ValidationError("cname shall match [a-z0-9\-]{2,20}") existing_interface = Interface.query.filter_by(name=lower_string).first() if existing_interface: raise ValidationError(f"cname matches an existing interface") existing_host = Host.query.filter_by(name=lower_string).first() if existing_host: raise ValidationError(f"cname matches an existing host") return lower_string def to_dict(self, recursive=False): d = super().to_dict() d.update({"name": self.name, "interface": str(self.interface)}) return d class Domain(CreatedMixin, db.Model): name = db.Column(db.Text, nullable=False, unique=True) scopes = db.relationship( "NetworkScope", backref=db.backref("domain", lazy="joined"), lazy=True ) networks = db.relationship( "Network", backref=db.backref("domain", lazy="joined"), lazy=True ) def __str__(self): return str(self.name) def to_dict(self, recursive=False): d = super().to_dict() d.update( { "name": self.name, "scopes": [str(scope) for scope in self.scopes], "networks": [str(network) for network in self.networks], } ) return d class NetworkScope(CreatedMixin, db.Model): __tablename__ = "network_scope" name = db.Column(CIText, nullable=False, unique=True) first_vlan = 
db.Column(db.Integer, nullable=False, unique=True) last_vlan = db.Column(db.Integer, nullable=False, unique=True) supernet = db.Column(postgresql.CIDR, nullable=False, unique=True) domain_id = db.Column(db.Integer, db.ForeignKey("domain.id"), nullable=False) description = db.Column(db.Text) networks = db.relationship( "Network", backref=db.backref("scope", lazy="joined"), lazy=True ) __table_args__ = ( sa.CheckConstraint( "first_vlan < last_vlan", name="first_vlan_less_than_last_vlan" ), ) def __str__(self): return str(self.name) @property def supernet_ip(self): return ipaddress.ip_network(self.supernet) def prefix_range(self): """Return the list of subnet prefix that can be used for this network scope""" return list(range(self.supernet_ip.prefixlen + 1, 31)) def vlan_range(self): """Return the list of vlan ids that can be assigned for this network scope The range is defined by the first and last vlan """ return range(self.first_vlan, self.last_vlan + 1) def used_vlans(self): """Return the list of vlan ids in use The list is sorted """ return sorted(network.vlan_id for network in self.networks) def available_vlans(self): """Return the list of vlan ids available""" return [vlan for vlan in self.vlan_range() if vlan not in self.used_vlans()] def used_subnets(self): """Return the list of subnets in use The list is sorted """ return sorted(network.network_ip for network in self.networks) def available_subnets(self, prefix): """Return the list of available subnets with the given prefix""" return [ str(subnet) for subnet in self.supernet_ip.subnets(new_prefix=prefix) if subnet not in self.used_subnets() ] def to_dict(self, recursive=False): d = super().to_dict() d.update( { "name": self.name, "first_vlan": self.first_vlan, "last_vlan": self.last_vlan, "supernet": self.supernet, "description": self.description, "domain": str(self.domain), "networks": [str(network) for network in self.networks], } ) return d # Define RQ JobStatus as a Python enum # We can't use the one 
defined in rq/job.py as it's # not a real enum (it's a custom one) and is not # compatible with sqlalchemy class JobStatus(Enum): QUEUED = "queued" FINISHED = "finished" FAILED = "failed" STARTED = "started" DEFERRED = "deferred" class Task(db.Model): # Use job id generated by RQ id = db.Column(postgresql.UUID, primary_key=True) created_at = db.Column(db.DateTime, default=utcnow()) ended_at = db.Column(db.DateTime) name = db.Column(db.Text, nullable=False, index=True) command = db.Column(db.Text) status = db.Column(db.Enum(JobStatus, name="job_status")) awx_job_id = db.Column(db.Integer) exception = db.Column(db.Text) user_id = db.Column( db.Integer, db.ForeignKey("user_account.id"), nullable=False, default=utils.fetch_current_user_id, ) @property def awx_job_url(self): if self.awx_job_id is None: return None return urllib.parse.urljoin( current_app.config["AWX_URL"], f"/#/jobs/{self.awx_job_id}" ) def __str__(self): return str(self.id) def to_dict(self, recursive=False): return { "id": self.id, "name": self.name, "created_at": utils.format_field(self.created_at), "ended_at": utils.format_field(self.ended_at), "status": self.status.name, "awx_job_id": self.awx_job_id, "awx_job_url": self.awx_job_url, "command": self.command, "exception": self.exception, "user": str(self.user), } @sa.event.listens_for(db.session, "before_flush") def before_flush(session, flush_context, instances): """Before flush hook Used to trigger core services update on any Interface modification. See http://docs.sqlalchemy.org/en/latest/orm/session_events.html#before-flush """ # In session.dirty, we need to check session.is_modified(instance) because when updating a Host, # the interface is added to the session even if not modified. # In session.deleted, session.is_modified(instance) is usually False (we shouldn't check it). # In session.new, it will always be True and we don't need to check it. 
for kind in ("new", "dirty", "deleted"): for instance in getattr(session, kind): if isinstance(instance, Interface) and ( (kind == "dirty" and session.is_modified(instance)) or (kind in ("new", "deleted")) ): utils.trigger_core_services_update() return # call configure_mappers after defining all the models # required by sqlalchemy_continuum sa.orm.configure_mappers() ItemVersion = version_class(Item) # Set SQLAlchemy event listeners db.event.listen(db.session, "before_flush", SearchableMixin.before_flush) db.event.listen( db.session, "after_flush_postexec", SearchableMixin.after_flush_postexec ) db.event.listen(db.session, "after_commit", SearchableMixin.after_commit)