# -*- coding: utf-8 -*-
"""
app.models
~~~~~~~~~~

This module implements the models used in the app.

:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.
"""
import ipaddress
import string
import qrcode
import itertools
import urllib.parse
import elasticsearch
import sqlalchemy as sa
from enum import Enum
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import validates
from sqlalchemy_continuum import make_versioned, version_class
from citext import CIText
from flask import current_app
from flask_login import UserMixin, current_user
from wtforms import ValidationError
from rq import Queue
from .extensions import db, login_manager, ldap_manager, cache
from .plugins import FlaskUserPlugin
from .validators import (
    ICS_ID_RE,
    HOST_NAME_RE,
    VLAN_NAME_RE,
    MAC_ADDRESS_RE,
    DEVICE_TYPE_RE,
)
from . import utils, search

# Number of minutes to wait before considering a deferred job lost
WAITING_DELAY = 30

make_versioned(plugins=[FlaskUserPlugin()])


# See http://docs.sqlalchemy.org/en/latest/core/compiler.html#utc-timestamp-function
class utcnow(sa.sql.expression.FunctionElement):
    type = sa.types.DateTime()


@sa.ext.compiler.compiles(utcnow, "postgresql")
def pg_utcnow(element, compiler, **kw):
    return "TIMEZONE('utc', CURRENT_TIMESTAMP)"


def temporary_ics_ids():
    """Return a generator over the full list of temporary ICS ids"""
    return (
        f'{current_app.config["TEMPORARY_ICS_ID"]}{letter}{number:0=3d}'
        for letter in string.ascii_uppercase
        for number in range(0, 1000)
    )


def used_temporary_ics_ids():
    """Return the set of temporary ICS ids already in use"""
    temporary_items = Item.query.filter(
        Item.ics_id.startswith(current_app.config["TEMPORARY_ICS_ID"])
    ).all()
    return {item.ics_id for item in temporary_items}


def get_temporary_ics_id():
    """Return a temporary ICS id that is available (usage sketch at the end of this module)"""
    used_temp_ics_ids = used_temporary_ics_ids()
    for ics_id in temporary_ics_ids():
        if ics_id not in used_temp_ics_ids:
            return ics_id
    else:
        raise ValueError("No temporary ICS id available")


@login_manager.user_loader
@cache.memoize(timeout=1800)
def load_user(user_id):
    """User loader callback for flask-login

    :param str user_id: unicode ID of a user
    :returns: corresponding user object or None
    """
    return User.query.get(int(user_id))


@ldap_manager.save_user
def save_user(dn, username, data, memberships):
    """User saver for flask-ldap3-login

    This method is called whenever an LDAPLoginForm() successfully validates.
""" user = User.query.filter_by(username=username).first() if user is None: user = User( username=username, display_name=utils.attribute_to_string(data["cn"]), email=utils.attribute_to_string(data["mail"]), ) # Always update the user groups to keep them up-to-date user.groups = sorted( [utils.attribute_to_string(group["cn"]) for group in memberships] ) db.session.add(user) db.session.commit() return user # Tables required for Many-to-Many relationships between users and favorites attributes favorite_manufacturers_table = db.Table( "favorite_manufacturers", db.Column( "user_id", db.Integer, db.ForeignKey("user_account.id"), primary_key=True ), db.Column( "manufacturer_id", db.Integer, db.ForeignKey("manufacturer.id"), primary_key=True, ), ) favorite_models_table = db.Table( "favorite_models", db.Column( "user_id", db.Integer, db.ForeignKey("user_account.id"), primary_key=True ), db.Column("model_id", db.Integer, db.ForeignKey("model.id"), primary_key=True), ) favorite_locations_table = db.Table( "favorite_locations", db.Column( "user_id", db.Integer, db.ForeignKey("user_account.id"), primary_key=True ), db.Column( "location_id", db.Integer, db.ForeignKey("location.id"), primary_key=True ), ) favorite_statuses_table = db.Table( "favorite_statuses", db.Column( "user_id", db.Integer, db.ForeignKey("user_account.id"), primary_key=True ), db.Column("status_id", db.Integer, db.ForeignKey("status.id"), primary_key=True), ) favorite_actions_table = db.Table( "favorite_actions", db.Column( "user_id", db.Integer, db.ForeignKey("user_account.id"), primary_key=True ), db.Column("action_id", db.Integer, db.ForeignKey("action.id"), primary_key=True), ) class User(db.Model, UserMixin): # "user" is a reserved word in postgresql # so let's use another name __tablename__ = "user_account" id = db.Column(db.Integer, primary_key=True) username = db.Column(db.Text, nullable=False, unique=True) display_name = db.Column(db.Text, nullable=False) email = db.Column(db.Text) groups = db.Column(postgresql.ARRAY(db.Text), default=[]) tokens = db.relationship("Token", backref="user") tasks = db.relationship("Task", backref="user") # The favorites won't be accessed very often so we load them # only when necessary (lazy=True) favorite_manufacturers = db.relationship( "Manufacturer", secondary=favorite_manufacturers_table, lazy=True, backref=db.backref("favorite_users", lazy=True), ) favorite_models = db.relationship( "Model", secondary=favorite_models_table, lazy=True, backref=db.backref("favorite_users", lazy=True), ) favorite_locations = db.relationship( "Location", secondary=favorite_locations_table, lazy=True, backref=db.backref("favorite_users", lazy=True), ) favorite_statuses = db.relationship( "Status", secondary=favorite_statuses_table, lazy=True, backref=db.backref("favorite_users", lazy=True), ) favorite_actions = db.relationship( "Action", secondary=favorite_actions_table, lazy=True, backref=db.backref("favorite_users", lazy=True), ) def get_id(self): """Return the user id as unicode Required by flask-login """ return str(self.id) @property def csentry_groups(self): """Return the list of CSEntry groups the user belong to Groups are assigned based on the CSENTRY_LDAP_GROUPS mapping with LDAP groups """ if not hasattr(self, "_csentry_groups"): self._csentry_groups = [] for csentry_group, ldap_groups in current_app.config[ "CSENTRY_LDAP_GROUPS" ].items(): if set(self.groups) & set(ldap_groups): self._csentry_groups.append(csentry_group) # Add the network group based on CSENTRY_DOMAINS_LDAP_GROUPS 
network_ldap_groups = set( itertools.chain( *current_app.config["CSENTRY_DOMAINS_LDAP_GROUPS"].values() ) ) if set(self.groups) & network_ldap_groups: self._csentry_groups.append("network") return self._csentry_groups @property def csentry_domains(self): """Return the list of CSEntry domains the user has access to Domains are assigned based on the CSENTRY_DOMAINS_LDAP_GROUPS mapping with LDAP groups """ if not hasattr(self, "_csentry_domains"): self._csentry_domains = [] for domain, ldap_groups in current_app.config[ "CSENTRY_DOMAINS_LDAP_GROUPS" ].items(): if set(self.groups) & set(ldap_groups): self._csentry_domains.append(domain) return self._csentry_domains @property def is_admin(self): return "admin" in self.csentry_groups def is_member_of_one_group(self, groups): """Return True if the user is at least member of one of the given CSEntry groups""" return bool(set(groups) & set(self.csentry_groups)) def has_access_to_network(self, network): """Return True if the user has access to the network - admin users have access to all networks - normal users must have access to the network domain - normal users don't have access to admin_only networks (whatever the domain) - LOGIN_DISABLED can be set to True to turn off authentication check when testing. In this case, this function always returns True. """ if current_app.config.get("LOGIN_DISABLED") or self.is_admin: return True if network is None or network.admin_only: # True is already returned for admin users return False return str(network.domain) in self.csentry_domains def can_create_vm(self, host): """Return True if the user can create the VM - host.device_type shall be VirtualMachine - admin users can create anything - normal users must have access to the network to create VIOC - normal users can only create a VM if the host is in one of the allowed domains - LOGIN_DISABLED can be set to True to turn off authentication check when testing. In this case, this function always returns True. """ if str(host.device_type) != "VirtualMachine": return False if current_app.config.get("LOGIN_DISABLED") or self.is_admin: return True if not self.has_access_to_network(host.main_network): # True is already returned for admin users return False if host.is_ioc: # VIOC can be created by anyone having access to the network return True # VM can only be created if the domain is allowed return ( str(host.main_interface.network.domain) in current_app.config["ALLOWED_VM_CREATION_DOMAINS"] ) def favorite_attributes(self): """Return all user's favorite attributes""" favorites_list = [ self.favorite_manufacturers, self.favorite_models, self.favorite_locations, self.favorite_statuses, self.favorite_actions, ] return [favorite for favorites in favorites_list for favorite in favorites] def launch_task(self, name, func, queue_name="normal", **kwargs): """Launch a task in the background using RQ The task is added to the session but not committed. 
""" q = Queue(queue_name) job = q.enqueue(f"app.tasks.{func}", **kwargs) # The status will be set to QUEUED or DEFERRED task = Task( id=job.id, name=name, awx_resource=kwargs.get("resource", None), command=job.get_call_string(), status=JobStatus(job.status), user=self, ) db.session.add(task) return task def get_tasks(self, all=False): """Return all tasks created by the current user If the user is admin and all is set to True, will return all tasks """ if all and self.is_admin: return Task.query.order_by(Task.created_at).all() return Task.query.filter_by(user=self).order_by(Task.created_at).all() def get_tasks_in_progress(self, name): """Return all the <name> tasks not finished or failed""" return ( Task.query.filter_by(name=name) .filter(~Task.status.in_([JobStatus.FINISHED, JobStatus.FAILED])) .all() ) def get_task_started(self, name): """Return the <name> task currently running or None""" return Task.query.filter_by(name=name, status=JobStatus.STARTED).first() def is_task_waiting(self, name): """Return True if a <name> task is waiting Waiting means: - queued - deferred if not older than WAITING_DELAY minutes A deferred task can stay deferred forever if the task it depends on fails. """ count = ( Task.query.filter_by(name=name) .filter( (Task.status == JobStatus.QUEUED) | ( (Task.status == JobStatus.DEFERRED) & (Task.created_at > utils.minutes_ago(WAITING_DELAY)) ) ) .count() ) return count > 0 def get_task_waiting(self, name): """Return the latest <name> task currently waiting or None Waiting means: - queued - deferred if not older than WAITING_DELAY minutes A deferred task can stay deferred forever if the task it depends on fails. """ return ( Task.query.filter_by(name=name) .filter( (Task.status == JobStatus.QUEUED) | ( (Task.status == JobStatus.DEFERRED) & (Task.created_at > utils.minutes_ago(WAITING_DELAY)) ) ) .order_by(Task.created_at.desc()) .first() ) def __str__(self): return self.username def to_dict(self, recursive=False): return { "id": self.id, "username": self.username, "display_name": self.display_name, "email": self.email, "groups": self.csentry_groups, } class SearchableMixin(object): """Add search capability to a class""" @classmethod def search(cls, query, page=1, per_page=20, sort=None): try: ids, total = search.query_index( cls.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"], query, page, per_page, sort, ) except elasticsearch.ElasticsearchException as e: # Invalid query current_app.logger.warning(e) return cls.query.filter_by(id=0), 0 if total == 0: return cls.query.filter_by(id=0), 0 when = [(value, i) for i, value in enumerate(ids)] return ( cls.query.filter(cls.id.in_(ids)).order_by(db.case(when, value=cls.id)), total, ) @classmethod def before_flush(cls, session, flush_context, instances): """Save the new/modified/deleted objects""" # The session.new / dirty / deleted lists are empty in the after_flush_postexec event. 
# We need to record them here session._changes = {"add_obj": [], "delete": []} for obj in itertools.chain(session.new, session.dirty): if isinstance(obj, SearchableMixin): index = ( obj.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"] ) current_app.logger.debug( f"object to add/update in the {index} index: {obj}" ) session._changes["add_obj"].append((index, obj)) for obj in session.deleted: if isinstance(obj, SearchableMixin): index = ( obj.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"] ) current_app.logger.debug( f"object to remove from the {index} index: {obj}" ) session._changes["delete"].append((index, obj.id)) @classmethod def after_flush_postexec(cls, session, flush_context): """Retrieve the new and updated objects representation""" if session._changes is None: return # - We can't call obj.to_dict() in the before_flush event because the id # hasn't been allocated yet (for new objects) and other fields haven't been updated # (default values like created_at/updated_at and some relationships). # - We can't call obj.to_dict() in the after_commit event because it would raise: # sqlalchemy.exc.InvalidRequestError: # This session is in 'committed' state; no further SQL can be emitted within this transaction. session._changes["add"] = [ (index, obj.to_dict(recursive=True)) for index, obj in session._changes["add_obj"] ] @classmethod def after_commit(cls, session): """Update the elasticsearch index""" if session._changes is None: return for index, body in session._changes["add"]: search.add_to_index(index, body) for index, id in session._changes["delete"]: search.remove_from_index(index, id) session._changes = None @classmethod def delete_index(cls, **kwargs): """Delete the index of the class""" current_app.logger.info(f"Delete the {cls.__tablename__} index") search.delete_index( cls.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"], **kwargs, ) @classmethod def create_index(cls, **kwargs): """Create the index of the class""" if hasattr(cls, "__mapping__"): current_app.logger.info(f"Create the {cls.__tablename__} index") search.create_index( cls.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"], cls.__mapping__, **kwargs, ) else: current_app.logger.info( f"No mapping defined for {cls.__tablename__}. No index created." 
) @classmethod def reindex(cls): """Force to reindex all instances of the class""" current_app.logger.info(f"Force to re-index all {cls.__tablename__} instances") # Ignore index_not_found_exception cls.delete_index(ignore_unavailable=True) cls.create_index() for obj in cls.query: search.add_to_index( cls.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"], obj.to_dict(recursive=True), ) class Token(db.Model): """Table to store valid tokens""" id = db.Column(db.Integer, primary_key=True) jti = db.Column(postgresql.UUID, nullable=False) token_type = db.Column(db.Text, nullable=False) user_id = db.Column(db.Integer, db.ForeignKey("user_account.id"), nullable=False) issued_at = db.Column(db.DateTime, nullable=False) # expires can be set to None for tokens that never expire expires = db.Column(db.DateTime) description = db.Column(db.Text) __table_args__ = (sa.UniqueConstraint(jti, user_id),) def __str__(self): return self.jti class QRCodeMixin: id = db.Column(db.Integer, primary_key=True) name = db.Column(CIText, nullable=False, unique=True) description = db.Column(db.Text) def image(self): """Return a QRCode image to identify a record The QRCode includes: - CSE prefix - the table name - the name of the record """ data = ":".join(["CSE", self.__tablename__, self.name]) return qrcode.make(data, version=1, box_size=5) @cache.memoize(timeout=0) def base64_image(self): """Return the QRCode image as base64 string""" return utils.image_to_base64(self.image()) def is_user_favorite(self): """Return True if the attribute is part of the current user favorites""" return current_user in self.favorite_users def __str__(self): return self.name def __repr__(self): # The cache.memoize decorator performs a repr() on the passed in arguments # __repr__ is used as part of the cache key and shall be a uniquely identifying string # See https://flask-caching.readthedocs.io/en/latest/#memoization return f"{self.__class__.__name__}(id={self.id}, name={self.name})" def to_dict(self, recursive=False): return { "id": self.id, "name": self.name, "description": self.description, "qrcode": self.base64_image(), } class Action(QRCodeMixin, db.Model): pass class Manufacturer(QRCodeMixin, db.Model): items = db.relationship("Item", back_populates="manufacturer") class Model(QRCodeMixin, db.Model): items = db.relationship("Item", back_populates="model") class Location(QRCodeMixin, db.Model): items = db.relationship("Item", back_populates="location") class Status(QRCodeMixin, db.Model): items = db.relationship("Item", back_populates="status") class CreatedMixin: id = db.Column(db.Integer, primary_key=True) created_at = db.Column(db.DateTime, default=utcnow()) updated_at = db.Column(db.DateTime, default=utcnow(), onupdate=utcnow()) # Using ForeignKey and relationship in mixin requires the @declared_attr decorator # See http://docs.sqlalchemy.org/en/latest/orm/extensions/declarative/mixins.html @declared_attr def user_id(cls): return db.Column( db.Integer, db.ForeignKey("user_account.id"), nullable=False, default=utils.fetch_current_user_id, ) @declared_attr def user(cls): return db.relationship("User") def __init__(self, **kwargs): # Automatically convert created_at/updated_at strings # to datetime object for key in ("created_at", "updated_at"): if key in kwargs: if isinstance(kwargs[key], str): kwargs[key] = utils.parse_to_utc(kwargs[key]) super().__init__(**kwargs) def to_dict(self): return { "id": self.id, "created_at": utils.format_field(self.created_at), "updated_at": utils.format_field(self.updated_at), "user": 
str(self.user), } class Item(CreatedMixin, SearchableMixin, db.Model): __versioned__ = { "exclude": [ "created_at", "user_id", "ics_id", "serial_number", "manufacturer_id", "model_id", ] } __mapping__ = { "created_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"}, "updated_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"}, "user": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "ics_id": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "serial_number": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "quantity": {"type": "long"}, "manufacturer": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "model": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "location": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "status": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "parent": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "children": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "macs": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "host": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "stack_member": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "history": {"enabled": False}, "comments": {"type": "text"}, } # WARNING! Inheriting id from CreatedMixin doesn't play well with # SQLAlchemy-Continuum. It has to be defined here. id = db.Column(db.Integer, primary_key=True) ics_id = db.Column( db.Text, unique=True, nullable=False, index=True, default=get_temporary_ics_id ) serial_number = db.Column(db.Text, nullable=False) quantity = db.Column(db.Integer, nullable=False, default=1) manufacturer_id = db.Column(db.Integer, db.ForeignKey("manufacturer.id")) model_id = db.Column(db.Integer, db.ForeignKey("model.id")) location_id = db.Column(db.Integer, db.ForeignKey("location.id")) status_id = db.Column(db.Integer, db.ForeignKey("status.id")) parent_id = db.Column(db.Integer, db.ForeignKey("item.id")) host_id = db.Column(db.Integer, db.ForeignKey("host.id")) stack_member = db.Column(db.SmallInteger) manufacturer = db.relationship( "Manufacturer", back_populates="items", lazy="joined" ) model = db.relationship("Model", back_populates="items", lazy="joined") location = db.relationship("Location", back_populates="items", lazy="joined") status = db.relationship("Status", back_populates="items", lazy="joined") children = db.relationship("Item", backref=db.backref("parent", remote_side=[id])) macs = db.relationship("Mac", backref="item", lazy="joined") comments = db.relationship( "ItemComment", backref="item", cascade="all, delete-orphan", lazy="joined" ) __table_args__ = ( sa.CheckConstraint( "stack_member >= 0 AND stack_member <=9", name="stack_member_range" ), sa.UniqueConstraint(host_id, stack_member, name="uq_item_host_id_stack_member"), ) def __init__(self, **kwargs): # Automatically convert manufacturer/model/location/status to an # instance of their class if passed as a string for key, cls in [ ("manufacturer", Manufacturer), ("model", Model), ("location", Location), ("status", Status), ]: if key in kwargs: kwargs[key] = utils.convert_to_model(kwargs[key], cls) super().__init__(**kwargs) def __str__(self): return str(self.ics_id) @validates("ics_id") def validate_ics_id(self, key, string): """Ensure the ICS id field matches the required format""" if string is not None: if ICS_ID_RE.fullmatch(string) is None: raise ValidationError("ICS id shall match [A-Z]{3}[0-9]{3}") return string def to_dict(self, recursive=False): d = 
super().to_dict() d.update( { "ics_id": self.ics_id, "serial_number": self.serial_number, "quantity": self.quantity, "manufacturer": utils.format_field(self.manufacturer), "model": utils.format_field(self.model), "location": utils.format_field(self.location), "status": utils.format_field(self.status), "parent": utils.format_field(self.parent), "children": [str(child) for child in self.children], "macs": [str(mac) for mac in self.macs], "host": utils.format_field(self.host), "stack_member": utils.format_field(self.stack_member), "history": self.history(), "comments": [str(comment) for comment in self.comments], } ) return d def to_row_dict(self): """Convert to a dict that can easily be exported to an excel row All values should be a string """ d = self.to_dict().copy() d["children"] = " ".join(d["children"]) d["macs"] = " ".join(d["macs"]) d["comments"] = "\n\n".join(d["comments"]) d["history"] = "\n".join([str(version) for version in d["history"]]) return d def history(self): versions = [] for version in self.versions: # parent is an attribute used by SQLAlchemy-Continuum # version.parent refers to an ItemVersion instance (and has no link with # the item parent_id) # We need to retrieve the parent "manually" if version.parent_id is None: parent = None else: parent = Item.query.get(version.parent_id) versions.append( { "updated_at": utils.format_field(version.updated_at), "quantity": version.quantity, "location": utils.format_field(version.location), "status": utils.format_field(version.status), "parent": utils.format_field(parent), } ) return versions class ItemComment(CreatedMixin, db.Model): body = db.Column(db.Text, nullable=False) item_id = db.Column(db.Integer, db.ForeignKey("item.id"), nullable=False) def __str__(self): return self.body def to_dict(self, recursive=False): d = super().to_dict() d.update({"body": self.body, "item": str(self.item)}) return d class Network(CreatedMixin, db.Model): vlan_name = db.Column(CIText, nullable=False, unique=True) vlan_id = db.Column(db.Integer, nullable=False, unique=True) address = db.Column(postgresql.CIDR, nullable=False, unique=True) first_ip = db.Column(postgresql.INET, nullable=False, unique=True) last_ip = db.Column(postgresql.INET, nullable=False, unique=True) gateway = db.Column(postgresql.INET, nullable=False, unique=True) description = db.Column(db.Text) admin_only = db.Column(db.Boolean, nullable=False, default=False) scope_id = db.Column(db.Integer, db.ForeignKey("network_scope.id"), nullable=False) domain_id = db.Column(db.Integer, db.ForeignKey("domain.id"), nullable=False) interfaces = db.relationship( "Interface", backref=db.backref("network", lazy="joined"), lazy=True ) __table_args__ = ( sa.CheckConstraint("first_ip < last_ip", name="first_ip_less_than_last_ip"), sa.CheckConstraint("first_ip << address", name="first_ip_in_network"), sa.CheckConstraint("last_ip << address", name="last_ip_in_network"), sa.CheckConstraint("gateway << address", name="gateway_in_network"), ) def __init__(self, **kwargs): # Automatically convert scope to an instance of NetworkScope if it was passed # as a string if "scope" in kwargs: kwargs["scope"] = utils.convert_to_model( kwargs["scope"], NetworkScope, "name" ) # If domain_id is not passed, we set it to the network scope value if "domain_id" not in kwargs: kwargs["domain_id"] = kwargs["scope"].domain_id super().__init__(**kwargs) def __str__(self): return str(self.vlan_name) @property def network_ip(self): return ipaddress.ip_network(self.address) @property def netmask(self): return 
self.network_ip.netmask

    @property
    def first(self):
        return ipaddress.ip_address(self.first_ip)

    @property
    def last(self):
        return ipaddress.ip_address(self.last_ip)

    def ip_range(self):
        """Return the list of IP addresses that can be assigned for this network

        The range is defined by the first and last IP
        """
        return [
            addr for addr in self.network_ip.hosts() if self.first <= addr <= self.last
        ]

    def used_ips(self):
        """Return the list of IP addresses in use

        The list is sorted
        """
        return sorted(interface.address for interface in self.interfaces)

    def available_ips(self):
        """Return the list of IP addresses available (usage sketch at the end of this module)"""
        return [addr for addr in self.ip_range() if addr not in self.used_ips()]

    @staticmethod
    def ip_in_network(ip, address):
        """Ensure the IP is in the network

        :param str ip: IP address to check
        :param str address: network address in CIDR notation
        :returns: a tuple with the IP and network as (IPv4Address, IPv4Network)
        :raises: ValidationError if the IP is not in the network
        """
        addr = ipaddress.ip_address(ip)
        net = ipaddress.ip_network(address)
        if addr not in net:
            raise ValidationError(f"IP address {ip} is not in network {address}")
        return (addr, net)

    @validates("first_ip")
    def validate_first_ip(self, key, ip):
        """Ensure the first IP is in the network"""
        self.ip_in_network(ip, self.address)
        return ip

    @validates("last_ip")
    def validate_last_ip(self, key, ip):
        """Ensure the last IP is in the network and greater than first_ip"""
        addr, net = self.ip_in_network(ip, self.address)
        if addr < self.first:
            raise ValidationError(
                f"Last IP address {ip} is less than the first address {self.first}"
            )
        return ip

    @validates("interfaces")
    def validate_interfaces(self, key, interface):
        """Ensure the interface IP is in the network range"""
        addr, net = self.ip_in_network(interface.ip, self.address)
        # Admin users can create IPs outside the defined range
        try:
            # current_user is a local proxy and is not
            # valid outside of a request context.
is_admin = current_user.is_admin except AttributeError: is_admin = False if not is_admin: if addr < self.first or addr > self.last: raise ValidationError( f"IP address {interface.ip} is not in range {self.first} - {self.last}" ) return interface @validates("vlan_name") def validate_vlan_name(self, key, string): """Ensure the name matches the required format""" if string is None: return None if VLAN_NAME_RE.fullmatch(string) is None: raise ValidationError("Vlan name shall match [A-Za-z0-9\-]{3,25}") return string def to_dict(self, recursive=False): d = super().to_dict() d.update( { "vlan_name": self.vlan_name, "vlan_id": self.vlan_id, "address": self.address, "netmask": str(self.netmask), "first_ip": self.first_ip, "last_ip": self.last_ip, "gateway": self.gateway, "description": self.description, "admin_only": self.admin_only, "scope": utils.format_field(self.scope), "domain": str(self.domain), "interfaces": [str(interface) for interface in self.interfaces], } ) return d class DeviceType(db.Model): __tablename__ = "device_type" id = db.Column(db.Integer, primary_key=True) name = db.Column(CIText, nullable=False, unique=True) hosts = db.relationship( "Host", backref=db.backref("device_type", lazy="joined"), lazy=True ) @validates("name") def validate_name(self, key, string): """Ensure the name field matches the required format""" if string is not None: if DEVICE_TYPE_RE.fullmatch(string) is None: raise ValidationError(f"'{string}' is an invalid device type name") return string def __str__(self): return self.name def to_dict(self, recursive=False): return { "id": self.id, "name": self.name, "hosts": [str(host) for host in self.hosts], } # Table required for Many-to-Many relationships between Ansible parent and child groups ansible_groups_parent_child_table = db.Table( "ansible_groups_parent_child", db.Column( "parent_group_id", db.Integer, db.ForeignKey("ansible_group.id"), primary_key=True, ), db.Column( "child_group_id", db.Integer, db.ForeignKey("ansible_group.id"), primary_key=True, ), ) # Table required for Many-to-Many relationships between Ansible groups and hosts ansible_groups_hosts_table = db.Table( "ansible_groups_hosts", db.Column( "ansible_group_id", db.Integer, db.ForeignKey("ansible_group.id"), primary_key=True, ), db.Column("host_id", db.Integer, db.ForeignKey("host.id"), primary_key=True), ) class AnsibleGroupType(Enum): STATIC = "STATIC" NETWORK_SCOPE = "NETWORK_SCOPE" NETWORK = "NETWORK" DEVICE_TYPE = "DEVICE_TYPE" IOC = "IOC" def __str__(self): return self.name @classmethod def choices(cls): return [(item, item.name) for item in AnsibleGroupType] @classmethod def coerce(cls, value): return value if type(value) == AnsibleGroupType else AnsibleGroupType[value] class AnsibleGroup(CreatedMixin, db.Model): __versioned__ = {} __tablename__ = "ansible_group" # Define id here so that it can be used in the primary and secondary join id = db.Column(db.Integer, primary_key=True) name = db.Column(CIText, nullable=False, unique=True) vars = db.Column(postgresql.JSONB) type = db.Column( db.Enum(AnsibleGroupType, name="ansible_group_type"), default=AnsibleGroupType.STATIC, nullable=False, ) children = db.relationship( "AnsibleGroup", secondary=ansible_groups_parent_child_table, primaryjoin=id == ansible_groups_parent_child_table.c.parent_group_id, secondaryjoin=id == ansible_groups_parent_child_table.c.child_group_id, backref=db.backref("parents"), ) def __str__(self): return str(self.name) @property def is_dynamic(self): return self.type != AnsibleGroupType.STATIC @property def 
hosts(self): if self.type == AnsibleGroupType.STATIC: return self._hosts if self.type == AnsibleGroupType.NETWORK_SCOPE: return ( Host.query.join(Host.interfaces) .join(Interface.network) .join(Network.scope) .filter(NetworkScope.name == self.name, Interface.name == Host.name) .order_by(Host.name) .all() ) if self.type == AnsibleGroupType.NETWORK: return ( Host.query.join(Host.interfaces) .join(Interface.network) .filter(Network.vlan_name == self.name, Interface.name == Host.name) .order_by(Host.name) .all() ) if self.type == AnsibleGroupType.DEVICE_TYPE: return ( Host.query.join(Host.device_type) .filter(DeviceType.name == self.name) .order_by(Host.name) .all() ) if self.type == AnsibleGroupType.IOC: return Host.query.filter(Host.is_ioc.is_(True)).order_by(Host.name).all() @hosts.setter def hosts(self, value): # For dynamic group type, _hosts can only be set to [] if self.is_dynamic and value: raise AttributeError("can't set dynamic hosts") self._hosts = value def to_dict(self, recursive=False): d = super().to_dict() d.update( { "name": self.name, "vars": self.vars, "type": self.type.name, "hosts": [host.fqdn for host in self.hosts], "children": [str(child) for child in self.children], } ) return d class Host(CreatedMixin, SearchableMixin, db.Model): __versioned__ = {} __mapping__ = { "created_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"}, "updated_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"}, "user": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "name": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "fqdn": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "is_ioc": {"type": "boolean"}, "device_type": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "model": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "description": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "items": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "interfaces": { "properties": { "id": {"enabled": False}, "created_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"}, "updated_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"}, "user": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "is_main": {"type": "boolean"}, "network": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "ip": {"type": "ip"}, "netmask": {"enabled": False}, "name": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "mac": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "host": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "cnames": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "domain": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "device_type": { "type": "text", "fields": {"keyword": {"type": "keyword"}}, }, "model": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, } }, "ansible_vars": {"enabled": False}, "ansible_groups": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, } # id shall be defined here to be used by SQLAlchemy-Continuum id = db.Column(db.Integer, primary_key=True) name = db.Column(db.Text, nullable=False, unique=True) description = db.Column(db.Text) device_type_id = db.Column( db.Integer, db.ForeignKey("device_type.id"), nullable=False ) is_ioc = db.Column(db.Boolean, nullable=False, default=False) ansible_vars = db.Column(postgresql.JSONB) # 1. Set cascade to all (to add delete) and delete-orphan to delete all interfaces # when deleting a host # 2. 
Return interfaces sorted by name so that the main one (the one starting with # the same name as the host) is always the first one. # As an interface name always has to start with the name of the host, the one # matching the host name will always come first. interfaces = db.relationship( "Interface", backref=db.backref("host", lazy="joined"), cascade="all, delete-orphan", lazy="joined", order_by="Interface.name", ) items = db.relationship( "Item", backref=db.backref("host", lazy="joined"), lazy="joined" ) ansible_groups = db.relationship( "AnsibleGroup", secondary=ansible_groups_hosts_table, lazy="joined", backref=db.backref("_hosts"), ) def __init__(self, **kwargs): # Automatically convert device_type as an instance of its class if passed as a string if "device_type" in kwargs: kwargs["device_type"] = utils.convert_to_model( kwargs["device_type"], DeviceType ) # Automatically convert items to a list of instances if passed as a list of ics_id if "items" in kwargs: kwargs["items"] = [ utils.convert_to_model(item, Item, filter="ics_id") for item in kwargs["items"] ] # Automatically convert ansible groups to a list of instances if passed as a list of strings if "ansible_groups" in kwargs: kwargs["ansible_groups"] = [ utils.convert_to_model(group, AnsibleGroup) for group in kwargs["ansible_groups"] ] super().__init__(**kwargs) @property def model(self): """Return the model of the first linked item""" try: return utils.format_field(self.items[0].model) except IndexError: return None @property def main_interface(self): """Return the host main interface The main interface is the one that has the same name as the host or the first one found """ # As interfaces are sorted, the first one is always the main one try: return self.interfaces[0] except IndexError: return None @property def main_network(self): """Return the host main interface network""" try: return self.main_interface.network except AttributeError: return None @property def fqdn(self): """Return the host fully qualified domain name The domain is based on the main interface """ if self.main_interface: return f"{self.name}.{self.main_interface.network.domain}" else: return self.name def __str__(self): return str(self.name) @validates("name") def validate_name(self, key, string): """Ensure the name matches the required format""" if string is None: return None # Force the string to lowercase lower_string = string.lower() if HOST_NAME_RE.fullmatch(lower_string) is None: raise ValidationError("Host name shall match [a-z0-9\-]{2,20}") existing_cname = Cname.query.filter_by(name=lower_string).first() if existing_cname: raise ValidationError(f"Host name matches an existing cname") existing_interface = Interface.query.filter( Interface.name == lower_string, Interface.host_id != self.id ).first() if existing_interface: raise ValidationError(f"Host name matches an existing interface") return lower_string def stack_members(self): """Return all items part of the stack sorted by stack member number""" members = [item for item in self.items if item.stack_member is not None] return sorted(members, key=lambda x: x.stack_member) def stack_members_numbers(self): """Return the list of stack member numbers""" return [item.stack_member for item in self.stack_members()] def free_stack_members(self): """Return the list of free stack member numbers""" return [nb for nb in range(0, 10) if nb not in self.stack_members_numbers()] def to_dict(self, recursive=False): d = super().to_dict() d.update( { "name": self.name, "fqdn": self.fqdn, "is_ioc": self.is_ioc, 
"device_type": str(self.device_type), "model": self.model, "description": self.description, "items": [str(item) for item in self.items], "interfaces": [str(interface) for interface in self.interfaces], "ansible_vars": self.ansible_vars, "ansible_groups": [str(group) for group in self.ansible_groups], } ) if recursive: # Replace the list of interface names by the full representation # so that we can index everything in elasticsearch d["interfaces"] = [interface.to_dict() for interface in self.interfaces] return d class Interface(CreatedMixin, db.Model): network_id = db.Column(db.Integer, db.ForeignKey("network.id"), nullable=False) ip = db.Column(postgresql.INET, nullable=False, unique=True) name = db.Column(db.Text, nullable=False, unique=True) mac = db.Column(postgresql.MACADDR, nullable=True, unique=True) host_id = db.Column(db.Integer, db.ForeignKey("host.id"), nullable=False) # Add delete and delete-orphan options to automatically delete cnames when: # - deleting an interface # - de-associating a cname (removing it from the interface.cnames list) cnames = db.relationship( "Cname", backref=db.backref("interface", lazy="joined"), cascade="all, delete, delete-orphan", lazy="joined", ) def __init__(self, **kwargs): # Always set self.host and not self.host_id to call validate_name host_id = kwargs.pop("host_id", None) if host_id is not None: host = Host.query.get(host_id) elif "host" in kwargs: # Automatically convert host to an instance of Host if it was passed # as a string host = utils.convert_to_model(kwargs.pop("host"), Host, "name") else: host = None # Always set self.network and not self.network_id to call validate_interfaces network_id = kwargs.pop("network_id", None) if network_id is not None: kwargs["network"] = Network.query.get(network_id) elif "network" in kwargs: # Automatically convert network to an instance of Network if it was passed # as a string kwargs["network"] = utils.convert_to_model( kwargs["network"], Network, "vlan_name" ) # WARNING! Setting self.network will call validate_interfaces in the Network class # For the validation to work, self.ip must be set before! 
# Ensure that ip is passed before network try: ip = kwargs.pop("ip") except KeyError: super().__init__(host=host, **kwargs) else: super().__init__(host=host, ip=ip, **kwargs) @validates("name") def validate_name(self, key, string): """Ensure the name matches the required format""" if string is None: return None # Force the string to lowercase lower_string = string.lower() if HOST_NAME_RE.fullmatch(lower_string) is None: raise ValidationError("Interface name shall match [a-z0-9\-]{2,20}") if self.host and not lower_string.startswith(self.host.name): raise ValidationError( f"Interface name shall start with the host name '{self.host}'" ) existing_cname = Cname.query.filter_by(name=lower_string).first() if existing_cname: raise ValidationError(f"Interface name matches an existing cname") existing_host = Host.query.filter( Host.name == lower_string, Host.id != self.host.id ).first() if existing_host: raise ValidationError(f"Interface name matches an existing host") return lower_string @validates("mac") def validate_mac(self, key, string): """Ensure the mac is a valid MAC address""" if not string: return None if MAC_ADDRESS_RE.fullmatch(string) is None: raise ValidationError(f"'{string}' does not appear to be a MAC address") return string @validates("cnames") def validate_cnames(self, key, cname): """Ensure the cname is unique by domain""" existing_cnames = Cname.query.filter_by(name=cname.name).all() for existing_cname in existing_cnames: if existing_cname.domain == str(self.network.domain): raise ValidationError( f"Duplicate cname on the {self.network.domain} domain" ) return cname @property def address(self): return ipaddress.ip_address(self.ip) @property def is_ioc(self): return self.is_main and self.host.is_ioc @property def is_main(self): return self.name == self.host.main_interface.name def __str__(self): return str(self.name) def __repr__(self): return f"Interface(id={self.id}, network_id={self.network_id}, ip={self.ip}, name={self.name}, mac={self.mac})" def to_dict(self, recursive=False): d = super().to_dict() d.update( { "is_main": self.is_main, "network": str(self.network), "ip": self.ip, "netmask": str(self.network.netmask), "name": self.name, "mac": utils.format_field(self.mac), "host": utils.format_field(self.host), "cnames": [str(cname) for cname in self.cnames], "domain": str(self.network.domain), } ) if self.host: d["device_type"] = str(self.host.device_type) d["model"] = utils.format_field(self.host.model) else: d["device_type"] = None d["model"] = None return d class Mac(db.Model): id = db.Column(db.Integer, primary_key=True) address = db.Column(postgresql.MACADDR, nullable=False, unique=True) item_id = db.Column(db.Integer, db.ForeignKey("item.id")) def __str__(self): return str(self.address) @validates("address") def validate_address(self, key, string): """Ensure the address is a valid MAC address""" if string is None: return None if MAC_ADDRESS_RE.fullmatch(string) is None: raise ValidationError(f"'{string}' does not appear to be a MAC address") return string def to_dict(self, recursive=False): return { "id": self.id, "address": self.address, "item": utils.format_field(self.item), } class Cname(CreatedMixin, db.Model): name = db.Column(db.Text, nullable=False) interface_id = db.Column(db.Integer, db.ForeignKey("interface.id"), nullable=False) def __init__(self, **kwargs): # Always set self.interface and not self.interface_id to call validate_cnames interface_id = kwargs.pop("interface_id", None) if interface_id is not None: kwargs["interface"] = 
Interface.query.get(interface_id) super().__init__(**kwargs) def __str__(self): return str(self.name) @property def domain(self): """Return the cname domain name""" return str(self.interface.network.domain) @property def fqdn(self): """Return the cname fully qualified domain name""" return f"{self.name}.{self.domain}" @validates("name") def validate_name(self, key, string): """Ensure the name matches the required format""" if string is None: return None # Force the string to lowercase lower_string = string.lower() if HOST_NAME_RE.fullmatch(lower_string) is None: raise ValidationError("cname shall match [a-z0-9\-]{2,20}") existing_interface = Interface.query.filter_by(name=lower_string).first() if existing_interface: raise ValidationError(f"cname matches an existing interface") existing_host = Host.query.filter_by(name=lower_string).first() if existing_host: raise ValidationError(f"cname matches an existing host") return lower_string def to_dict(self, recursive=False): d = super().to_dict() d.update({"name": self.name, "interface": str(self.interface)}) return d class Domain(CreatedMixin, db.Model): name = db.Column(db.Text, nullable=False, unique=True) scopes = db.relationship( "NetworkScope", backref=db.backref("domain", lazy="joined"), lazy=True ) networks = db.relationship( "Network", backref=db.backref("domain", lazy="joined"), lazy=True ) def __str__(self): return str(self.name) def to_dict(self, recursive=False): d = super().to_dict() d.update( { "name": self.name, "scopes": [str(scope) for scope in self.scopes], "networks": [str(network) for network in self.networks], } ) return d class NetworkScope(CreatedMixin, db.Model): __tablename__ = "network_scope" name = db.Column(CIText, nullable=False, unique=True) first_vlan = db.Column(db.Integer, nullable=False, unique=True) last_vlan = db.Column(db.Integer, nullable=False, unique=True) supernet = db.Column(postgresql.CIDR, nullable=False, unique=True) domain_id = db.Column(db.Integer, db.ForeignKey("domain.id"), nullable=False) description = db.Column(db.Text) networks = db.relationship( "Network", backref=db.backref("scope", lazy="joined"), lazy=True ) __table_args__ = ( sa.CheckConstraint( "first_vlan < last_vlan", name="first_vlan_less_than_last_vlan" ), ) def __str__(self): return str(self.name) @property def supernet_ip(self): return ipaddress.ip_network(self.supernet) def prefix_range(self): """Return the list of subnet prefix that can be used for this network scope""" return list(range(self.supernet_ip.prefixlen + 1, 31)) def vlan_range(self): """Return the list of vlan ids that can be assigned for this network scope The range is defined by the first and last vlan """ return range(self.first_vlan, self.last_vlan + 1) def used_vlans(self): """Return the list of vlan ids in use The list is sorted """ return sorted(network.vlan_id for network in self.networks) def available_vlans(self): """Return the list of vlan ids available""" return [vlan for vlan in self.vlan_range() if vlan not in self.used_vlans()] def used_subnets(self): """Return the list of subnets in use The list is sorted """ return sorted(network.network_ip for network in self.networks) def available_subnets(self, prefix): """Return the list of available subnets with the given prefix""" return [ str(subnet) for subnet in self.supernet_ip.subnets(new_prefix=prefix) if subnet not in self.used_subnets() ] def to_dict(self, recursive=False): d = super().to_dict() d.update( { "name": self.name, "first_vlan": self.first_vlan, "last_vlan": self.last_vlan, "supernet": 
self.supernet, "description": self.description, "domain": str(self.domain), "networks": [str(network) for network in self.networks], } ) return d # Define RQ JobStatus as a Python enum # We can't use the one defined in rq/job.py as it's # not a real enum (it's a custom one) and is not # compatible with sqlalchemy class JobStatus(Enum): QUEUED = "queued" FINISHED = "finished" FAILED = "failed" STARTED = "started" DEFERRED = "deferred" class Task(db.Model): # Use job id generated by RQ id = db.Column(postgresql.UUID, primary_key=True) created_at = db.Column(db.DateTime, default=utcnow()) ended_at = db.Column(db.DateTime) name = db.Column(db.Text, nullable=False, index=True) command = db.Column(db.Text) status = db.Column(db.Enum(JobStatus, name="job_status")) awx_resource = db.Column(db.Text) awx_job_id = db.Column(db.Integer) exception = db.Column(db.Text) user_id = db.Column( db.Integer, db.ForeignKey("user_account.id"), nullable=False, default=utils.fetch_current_user_id, ) @property def awx_job_url(self): if self.awx_job_id is None: return None if self.awx_resource == "job": route = "jobs/playbook" elif self.awx_resource == "workflow_job": route = "workflows" elif self.awx_resource == "inventory_source": route = "jobs/inventory" else: current_app.logger.warning(f"Unknown AWX resource: {self.awx_resource}") return None return urllib.parse.urljoin( current_app.config["AWX_URL"], f"/#/{route}/{self.awx_job_id}" ) def __str__(self): return str(self.id) def to_dict(self, recursive=False): return { "id": self.id, "name": self.name, "created_at": utils.format_field(self.created_at), "ended_at": utils.format_field(self.ended_at), "status": self.status.name, "awx_resource": self.awx_resource, "awx_job_id": self.awx_job_id, "awx_job_url": self.awx_job_url, "command": self.command, "exception": self.exception, "user": str(self.user), } def trigger_core_services_update(session): """Trigger core services update on any Interface modification. Called by before flush hook """ # In session.dirty, we need to check session.is_modified(instance) because when updating a Host, # the interface is added to the session even if not modified. # In session.deleted, session.is_modified(instance) is usually False (we shouldn't check it). # In session.new, it will always be True and we don't need to check it. for kind in ("new", "dirty", "deleted"): for instance in getattr(session, kind): if isinstance(instance, Interface) and ( (kind == "dirty" and session.is_modified(instance)) or (kind in ("new", "deleted")) ): utils.trigger_core_services_update() return True return False def trigger_inventory_update(session): """Trigger an inventory update in AWX Update on any AnsibleGroup/Cname/Domain/Host/Network/NetworkScope modification, but not on Interface as it's triggered with core services update. Called by before flush hook """ # In session.dirty, we need to check session.is_modified(instance) because the instance # could have been added to the session without being modified # In session.deleted, session.is_modified(instance) is usually False (we shouldn't check it). # In session.new, it will always be True and we don't need to check it. 
for kind in ("new", "dirty", "deleted"): for instance in getattr(session, kind): if isinstance( instance, (AnsibleGroup, Cname, Domain, Host, Network, NetworkScope) ) and ( (kind == "dirty" and session.is_modified(instance)) or (kind in ("new", "deleted")) ): utils.trigger_inventory_update() return True return False @sa.event.listens_for(db.session, "before_flush") def before_flush(session, flush_context, instances): """Before flush hook Used to trigger core services and inventory update. See http://docs.sqlalchemy.org/en/latest/orm/session_events.html#before-flush """ if trigger_core_services_update(session): # This will also trigger an inventory update return trigger_inventory_update(session) # call configure_mappers after defining all the models # required by sqlalchemy_continuum sa.orm.configure_mappers() ItemVersion = version_class(Item) # Set SQLAlchemy event listeners db.event.listen(db.session, "before_flush", SearchableMixin.before_flush) db.event.listen( db.session, "after_flush_postexec", SearchableMixin.after_flush_postexec ) db.event.listen(db.session, "after_commit", SearchableMixin.after_commit)