Skip to content
Snippets Groups Projects
models.py 77.7 KiB
Newer Older
        """Ensure the vlan_id is in the scope range"""
        if value is None or self.scope is None:
            # If scope is None, we can't do any validation
            # This will occur when vlan_id is passed before scope
            # We could ensure it's not the case but main use case
            # is when editing network. This won't happen then.
            return value
        if int(value) not in self.scope.vlan_range():
            raise ValidationError(
                f"Vlan id shall be in the range [{self.scope.first_vlan} - {self.scope.last_vlan}]"
            )
        return value

    def to_dict(self, recursive=False):
        """Return the network as a dictionary.

        The dictionary produced by the parent ``to_dict`` is extended with the
        network specific fields. ``recursive`` is not used by this method; it
        keeps the same signature as the other models' ``to_dict``.
        """
        result = super().to_dict()
        network_fields = {
            "vlan_name": self.vlan_name,
            "vlan_id": self.vlan_id,
            "address": self.address,
            "netmask": str(self.netmask),
            "broadcast": str(self.broadcast),
            "first_ip": self.first_ip,
            "last_ip": self.last_ip,
            "gateway": self.gateway,
            "description": self.description,
            "admin_only": self.admin_only,
            "sensitive": self.sensitive,
            "scope": utils.format_field(self.scope),
            "domain": str(self.domain),
            "interfaces": [str(iface) for iface in self.interfaces],
        }
        result.update(network_fields)
        return result


class DeviceType(db.Model):
    """Device type of a host, identified by a unique case-insensitive name."""

    __tablename__ = "device_type"
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)

    hosts = db.relationship(
        "Host", backref=db.backref("device_type", lazy="joined"), lazy=True
    )

    @validates("name")
    def validate_name(self, key, string):
        """Ensure the name field matches the required format

        The value is returned unchanged: SQLAlchemy stores whatever a
        ``@validates`` function returns, so falling through without a return
        would silently set the name to None.
        Raises ValidationError if the name doesn't match DEVICE_TYPE_RE.
        """
        if string is not None and DEVICE_TYPE_RE.fullmatch(string) is None:
            raise ValidationError(f"'{string}' is an invalid device type name")
        return string

    def __str__(self):
        return self.name

    def to_dict(self, recursive=False):
        """Return the device type as a dictionary (hosts given by name)"""
        return {
            "id": self.id,
            "name": self.name,
            "hosts": [str(host) for host in self.hosts],
        }


# Association table for the self-referential Many-to-Many relationship between
# Ansible parent and child groups: both columns reference ansible_group.id and
# together form the primary key.
ansible_groups_parent_child_table = db.Table(
    "ansible_groups_parent_child",
    db.Column(
        "parent_group_id", db.Integer, db.ForeignKey("ansible_group.id"), primary_key=True
    ),
    db.Column(
        "child_group_id", db.Integer, db.ForeignKey("ansible_group.id"), primary_key=True
    ),
)


# Association table for the Many-to-Many relationship between Ansible groups
# and hosts: the (ansible_group_id, host_id) pair is the primary key.
ansible_groups_hosts_table = db.Table(
    "ansible_groups_hosts",
    db.Column(
        "ansible_group_id", db.Integer, db.ForeignKey("ansible_group.id"), primary_key=True
    ),
    db.Column("host_id", db.Integer, db.ForeignKey("host.id"), primary_key=True),
)


class AnsibleGroupType(Enum):
    """Type of an Ansible group.

    STATIC groups have an explicit list of hosts. All other types are dynamic:
    their hosts are computed from the group name.
    """

    STATIC = "STATIC"
    NETWORK_SCOPE = "NETWORK_SCOPE"
    NETWORK = "NETWORK"
    DEVICE_TYPE = "DEVICE_TYPE"
    # Group of all hosts flagged with Host.is_ioc (referenced by AnsibleGroup.hosts)
    IOC = "IOC"
    HOSTNAME = "HOSTNAME"

    def __str__(self):
        return self.name

    @classmethod
    def choices(cls):
        """Return (member, name) pairs for every member, in definition order"""
        return [(item, item.name) for item in cls]

    @classmethod
    def coerce(cls, value):
        """Return value as a member, looking it up by name if it isn't one

        Raises KeyError if value is neither a member nor a member name.
        """
        return value if isinstance(value, cls) else cls[value]


class AnsibleGroup(CreatedMixin, SearchableMixin, db.Model):
    """Ansible group.

    A STATIC group stores its hosts explicitly in ``_hosts``. Every other type
    is dynamic: its hosts are queried from the group name (see ``hosts``).
    NETWORK_SCOPE and NETWORK groups are also implicitly linked as
    parent/child through the network scope (see ``children`` and ``parents``).
    """

    __versioned__ = {}
    __tablename__ = "ansible_group"
    __mapping__ = {
        "created_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"},
        "updated_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"},
        "user": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "name": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "vars": {"type": "flattened"},
        "type": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "hosts": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "children": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
    }
    # Define id here so that it can be used in the primary and secondary join
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)
    vars = db.Column(postgresql.JSONB)
    type = db.Column(
        db.Enum(AnsibleGroupType, name="ansible_group_type"),
        default=AnsibleGroupType.STATIC,
        nullable=False,
    )
    # Explicit parent -> child links; the "_parents" backref is the reverse side.
    # The public children/parents properties also include the implicit
    # NETWORK_SCOPE <-> NETWORK links.
    _children = db.relationship(
        "AnsibleGroup",
        secondary=ansible_groups_parent_child_table,
        primaryjoin=id == ansible_groups_parent_child_table.c.parent_group_id,
        secondaryjoin=id == ansible_groups_parent_child_table.c.child_group_id,
        backref=db.backref("_parents"),
    )

    def __str__(self):
        return str(self.name)

    @validates("_children")
    def validate_children(self, key, child):
        """Ensure the child is not in the group parents to avoid circular references

        Registered as the ``_children`` validator so that it runs for every
        group added to the relationship (the ``children`` setter assigns
        ``_children``). Without the decorator this check was never executed.
        Raises ValidationError if the child is the group itself, is the special
        "all" group or is already an ancestor of the group.
        """
        if child == self:
            raise ValidationError(f"Group '{self.name}' can't be a child of itself.")
        # "all" is special for Ansible. Any group is automatically a child of "all".
        if child.name == "all":
            raise ValidationError(
                f"Adding group 'all' as child to '{self.name}' creates a recursive dependency loop."
            )

        def check_parents(group):
            """Recursively check all parents"""
            if child in group.parents:
                raise ValidationError(
                    f"Adding group '{child}' as child to '{self.name}' creates a recursive dependency loop."
                )
            for parent in group.parents:
                check_parents(parent)

        check_parents(self)
        return child

    @property
    def is_dynamic(self):
        """True for every group type except STATIC"""
        return self.type != AnsibleGroupType.STATIC

    @property
    def hosts(self):
        """Return the hosts of the group

        STATIC groups return their explicit ``_hosts`` list. Dynamic groups
        query the hosts matching the group name, sorted by host name.
        An empty list is returned for a type handled by none of the branches
        (presumably only a ``type`` still None before the column default is
        applied), so that callers can always iterate the result.
        """
        if self.type == AnsibleGroupType.STATIC:
            return self._hosts
        if self.type == AnsibleGroupType.NETWORK_SCOPE:
            # Hosts whose main interface (named like the host) is in the scope
            return (
                Host.query.join(Host.interfaces)
                .join(Interface.network)
                .join(Network.scope)
                .filter(NetworkScope.name == self.name, Interface.name == Host.name)
                .order_by(Host.name)
                .all()
            )
        if self.type == AnsibleGroupType.NETWORK:
            # Hosts whose main interface (named like the host) is on the network
            return (
                Host.query.join(Host.interfaces)
                .join(Interface.network)
                .filter(Network.vlan_name == self.name, Interface.name == Host.name)
                .order_by(Host.name)
                .all()
            )
        if self.type == AnsibleGroupType.DEVICE_TYPE:
            return (
                Host.query.join(Host.device_type)
                .filter(DeviceType.name == self.name)
                .order_by(Host.name)
                .all()
            )
        # HOSTNAME is tested before IOC: the types are mutually exclusive, and
        # this way hostname groups never depend on the AnsibleGroupType.IOC lookup.
        if self.type == AnsibleGroupType.HOSTNAME:
            return (
                Host.query.filter(Host.name.startswith(self.name))
                .order_by(Host.name)
                .all()
            )
        if self.type == AnsibleGroupType.IOC:
            return Host.query.filter(Host.is_ioc.is_(True)).order_by(Host.name).all()
        return []

    @hosts.setter
    def hosts(self, value):
        # For dynamic group type, _hosts can only be set to []
        if self.is_dynamic and value:
            raise AttributeError("can't set dynamic hosts")
        self._hosts = value

    @property
    def children(self):
        """Return the child groups sorted by name

        For a NETWORK_SCOPE group, the NETWORK groups of the networks part of
        the scope are added automatically.
        """
        if self.type == AnsibleGroupType.NETWORK_SCOPE:
            # Return all existing network groups part of the scope
            network_children = (
                AnsibleGroup.query.filter(AnsibleGroup.type == AnsibleGroupType.NETWORK)
                .join(Network, AnsibleGroup.name == Network.vlan_name)
                .join(NetworkScope)
                .filter(NetworkScope.name == self.name)
                .all()
            )
            return sorted(self._children + network_children, key=attrgetter("name"))
        return sorted(self._children, key=attrgetter("name"))

    @children.setter
    def children(self, value):
        if self.type == AnsibleGroupType.NETWORK_SCOPE:
            # Forbid setting a NETWORK group as child
            # Groups linked to networks part of the scope are added automatically
            # Also forbid NETWORK_SCOPE group as child
            for group in value:
                if group.type in (
                    AnsibleGroupType.NETWORK,
                    AnsibleGroupType.NETWORK_SCOPE,
                ):
                    raise ValidationError(
                        f"can't set {str(group.type).lower()} group '{group}' as a network scope child"
                    )
        self._children = value

    @property
    def parents(self):
        """Return the parent groups sorted by name

        For a NETWORK group, the NETWORK_SCOPE group named after the scope of
        the network is added automatically. This mirrors ``children``, where a
        NETWORK_SCOPE group only adds groups of type NETWORK.
        """
        if self.type == AnsibleGroupType.NETWORK:
            # Add the group corresponding to the network scope if it exists.
            # Only a NETWORK_SCOPE group qualifies: an unrelated group that
            # merely shares the scope name doesn't list this group as child.
            network = Network.query.filter_by(vlan_name=self.name).first()
            if network is not None:
                scope_group = AnsibleGroup.query.filter_by(
                    name=network.scope.name, type=AnsibleGroupType.NETWORK_SCOPE
                ).first()
                if scope_group is not None:
                    return sorted(self._parents + [scope_group], key=attrgetter("name"))
        return sorted(self._parents, key=attrgetter("name"))

    @parents.setter
    def parents(self, value):
        if self.type == AnsibleGroupType.NETWORK:
            # Forbid setting a NETWORK_SCOPE group as parent
            # The group linked to the scope of the network is added automatically
            # Also forbid setting a NETWORK group as it doesn't make sense
            for group in value:
                if group.type in (
                    AnsibleGroupType.NETWORK,
                    AnsibleGroupType.NETWORK_SCOPE,
                ):
                    raise ValidationError(
                        f"can't set {str(group.type).lower()} group '{group}' as a network parent"
                    )
        self._parents = value

    def to_dict(self, recursive=False):
        """Return the group as a dictionary

        Includes every field listed in ``__mapping__`` (``vars`` included) so
        that the indexed document matches the declared mapping.
        """
        d = super().to_dict()
        d.update(
            {
                "name": self.name,
                "vars": self.vars,
                "type": self.type.name,
                "hosts": [host.fqdn for host in self.hosts],
                "children": [str(child) for child in self.children],
            }
        )
        return d


class Host(CreatedMixin, SearchableMixin, db.Model):
    """Host with its interfaces, linked items and Ansible groups."""

    __versioned__ = {}
    __mapping__ = {
        "created_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"},
        "updated_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"},
        "user": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "name": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "fqdn": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "is_ioc": {"type": "boolean"},
        "device_type": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "model": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "description": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "items": {
            "properties": {
                "ics_id": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
                "serial_number": {
                    "type": "text",
                    "fields": {"keyword": {"type": "keyword"}},
                },
                "stack_member": {"type": "byte"},
            }
        },
        "interfaces": {
            "properties": {
                "id": {"enabled": False},
                "created_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"},
                "updated_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"},
                "user": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
                "is_main": {"type": "boolean"},
                "network": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
                "ip": {"type": "ip"},
                "netmask": {"enabled": False},
                "name": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
                "description": {
                    "type": "text",
                    "fields": {"keyword": {"type": "keyword"}},
                },
                "mac": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
                "host": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
                "cnames": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
                "domain": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
                "device_type": {
                    "type": "text",
                    "fields": {"keyword": {"type": "keyword"}},
                },
                "model": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
            }
        },
        "ansible_vars": {"type": "flattened"},
        "ansible_groups": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "scope": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "sensitive": {"type": "boolean"},
    }

    # id shall be defined here to be used by SQLAlchemy-Continuum
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.Text, nullable=False, unique=True)
    description = db.Column(db.Text)
    device_type_id = db.Column(
        db.Integer, db.ForeignKey("device_type.id"), nullable=False
    )
    is_ioc = db.Column(db.Boolean, nullable=False, default=False)
    ansible_vars = db.Column(postgresql.JSONB)
    # 1. Set cascade to all (to add delete) and delete-orphan to delete all interfaces
    # when deleting a host
    # 2. Return interfaces sorted by name so that the main one (the one starting with
    # the same name as the host) is always the first one.
    # As an interface name always has to start with the name of the host, the one
    # matching the host name will always come first.
    interfaces = db.relationship(
        "Interface",
        backref=db.backref("host", lazy="joined"),
        cascade="all, delete-orphan",
        lazy="joined",
        order_by="Interface.name",
    )
    items = db.relationship(
        "Item", backref=db.backref("host", lazy="joined"), lazy="joined"
    )
    ansible_groups = db.relationship(
        "AnsibleGroup",
        secondary=ansible_groups_hosts_table,
        lazy="joined",
        backref=db.backref("_hosts"),
    )

    def __init__(self, **kwargs):
        # Automatically convert device_type as an instance of its class if passed as a string
        if "device_type" in kwargs:
            kwargs["device_type"] = utils.convert_to_model(
                kwargs["device_type"], DeviceType
            )
        # Automatically convert items to a list of instances if passed as a list of ics_id
        if "items" in kwargs:
            kwargs["items"] = [
                utils.convert_to_model(item, Item, filter_by="ics_id")
                for item in kwargs["items"]
            ]
        # Automatically convert ansible groups to a list of instances if passed as a list of strings
        if "ansible_groups" in kwargs:
            kwargs["ansible_groups"] = [
                utils.convert_to_model(group, AnsibleGroup)
                for group in kwargs["ansible_groups"]
            ]
        super().__init__(**kwargs)

    @property
    def model(self):
        """Return the model of the first linked item"""
        try:
            return utils.format_field(self.items[0].model)
        except IndexError:
            return None

    @property
    def main_interface(self):
        """Return the host main interface

        The main interface is the one that has the same name as the host
        or the first one found
        """
        # As interfaces are sorted, the first one is always the main one
        try:
            return self.interfaces[0]
        except IndexError:
            return None

    @property
    def main_network(self):
        """Return the host main interface network"""
        try:
            return self.main_interface.network
        except AttributeError:
            return None

    @property
    def scope(self):
        """Return the host main interface network scope"""
        try:
            return self.main_network.scope
        except AttributeError:
            return None

    @property
    def sensitive(self):
        """Return True if the host is on a sensitive network"""
        try:
            return self.main_network.sensitive
        except AttributeError:
            return False

    @property
    def fqdn(self):
        """Return the host fully qualified domain name

        The domain is based on the main interface
        """
        if self.main_interface:
            return f"{self.name}.{self.main_interface.network.domain}"
        else:
            return self.name

    def __str__(self):
        return str(self.name)

    @validates("name")
    def validate_name(self, key, string):
        """Ensure the name matches the required format

        The name is stored lowercase. Raises ValidationError if it doesn't
        match HOST_NAME_RE, matches an existing cname, or matches an interface
        belonging to another host.
        """
        if string is None:
            return None
        # Force the string to lowercase
        lower_string = string.lower()
        if HOST_NAME_RE.fullmatch(lower_string) is None:
            raise ValidationError(f"Host name shall match {HOST_NAME_RE.pattern}")
        existing_cname = Cname.query.filter_by(name=lower_string).first()
        if existing_cname:
            raise ValidationError("Host name matches an existing cname")
        existing_interface = Interface.query.filter(
            Interface.name == lower_string, Interface.host_id != self.id
        ).first()
        if existing_interface:
            raise ValidationError("Host name matches an existing interface")
        return lower_string

    def stack_members(self):
        """Return all items part of the stack sorted by stack member number"""
        members = [item for item in self.items if item.stack_member is not None]
        return sorted(members, key=lambda x: x.stack_member)

    def stack_members_numbers(self):
        """Return the list of stack member numbers"""
        return [item.stack_member for item in self.stack_members()]

    def free_stack_members(self):
        """Return the list of free stack member numbers (among 0 to 9)"""
        # Compute the used numbers once: calling stack_members_numbers() in the
        # comprehension condition re-filtered and re-sorted the items for
        # every candidate number.
        used_numbers = set(self.stack_members_numbers())
        return [nb for nb in range(0, 10) if nb not in used_numbers]

    def to_dict(self, recursive=False):
        """Return the host as a dictionary

        With recursive=True, interfaces and items are given by their full
        representation (used for indexing) instead of by name.
        """
        # None can't be compared to not None values
        # This function replaces None by Inf so it is set at the end of the list
        # items are sorted by stack_member and then ics_id
        def none_to_inf(nb):
            return float("inf") if nb is None else int(nb)

        d = super().to_dict()
        d.update(
            {
                "name": self.name,
                "fqdn": self.fqdn,
                "is_ioc": self.is_ioc,
                "device_type": str(self.device_type),
                "model": self.model,
                "description": self.description,
                "items": [
                    str(item)
                    for item in sorted(
                        self.items,
                        key=lambda x: (none_to_inf(x.stack_member), x.ics_id),
                    )
                ],
                "interfaces": [str(interface) for interface in self.interfaces],
                "ansible_vars": self.ansible_vars,
                "ansible_groups": [str(group) for group in self.ansible_groups],
                "scope": utils.format_field(self.scope),
                "sensitive": self.sensitive,
            }
        )
        if recursive:
            # Replace the list of interface names by the full representation
            # so that we can index everything in elasticsearch
            d["interfaces"] = [interface.to_dict() for interface in self.interfaces]
            # Add extra info in items
            d["items"] = sorted(
                [
                    {
                        "ics_id": item.ics_id,
                        "serial_number": item.serial_number,
                        "stack_member": item.stack_member,
                    }
                    for item in self.items
                ],
                key=lambda x: (none_to_inf(x["stack_member"]), x["ics_id"]),
            )
        return d


class Interface(CreatedMixin, db.Model):
    """Network interface of a host, with a unique IP, name and MAC address."""

    network_id = db.Column(db.Integer, db.ForeignKey("network.id"), nullable=False)
    ip = db.Column(postgresql.INET, nullable=False, unique=True)
    name = db.Column(db.Text, nullable=False, unique=True)
    description = db.Column(db.Text)
    mac = db.Column(postgresql.MACADDR, nullable=True, unique=True)
    host_id = db.Column(db.Integer, db.ForeignKey("host.id"), nullable=False)
    # Add delete and delete-orphan options to automatically delete cnames when:
    # - deleting an interface
    # - de-associating a cname (removing it from the interface.cnames list)
    cnames = db.relationship(
        "Cname",
        backref=db.backref("interface", lazy="joined"),
        cascade="all, delete, delete-orphan",
        lazy="joined",
    )

    def __init__(self, **kwargs):
        # Always set self.host and not self.host_id to call validate_name
        host_id = kwargs.pop("host_id", None)
        if host_id is not None:
            host = Host.query.get(host_id)
        elif "host" in kwargs:
            # Automatically convert host to an instance of Host if it was passed
            # as a string
            host = utils.convert_to_model(kwargs.pop("host"), Host, "name")
        else:
            host = None
        # Always set self.network and not self.network_id to call validate_interfaces
        network_id = kwargs.pop("network_id", None)
        if network_id is not None:
            kwargs["network"] = Network.query.get(network_id)
        elif "network" in kwargs:
            # Automatically convert network to an instance of Network if it was passed
            # as a string
            kwargs["network"] = utils.convert_to_model(
                kwargs["network"], Network, "vlan_name"
            )
        # WARNING! Setting self.network will call validate_interfaces in the Network class
        # For the validation to work, self.ip must be set before!
        # Ensure that ip is passed before network
        try:
            ip = kwargs.pop("ip")
        except KeyError:
            # Assign first available IP
            ip = str(kwargs["network"].available_ips()[0])
        super().__init__(host=host, ip=ip, **kwargs)

    @validates("name")
    def validate_name(self, key, string):
        """Ensure the name matches the required format

        The name is stored lowercase. Raises ValidationError if it doesn't
        match INTERFACE_NAME_RE, doesn't start with the host name (when the
        host is already set), matches an existing cname, or matches the name
        of another host.
        """
        if string is None:
            return None
        # Force the string to lowercase
        lower_string = string.lower()
        if INTERFACE_NAME_RE.fullmatch(lower_string) is None:
            raise ValidationError(
                f"Interface name shall match {INTERFACE_NAME_RE.pattern}"
            )
        if self.host and not lower_string.startswith(self.host.name):
            raise ValidationError(
                f"Interface name shall start with the host name '{self.host}'"
            )
        existing_cname = Cname.query.filter_by(name=lower_string).first()
        if existing_cname:
            raise ValidationError("Interface name matches an existing cname")
        # The host may not be set yet (see the check above): self.host.id would
        # then raise AttributeError. Comparing with None (IS NOT NULL) makes
        # every existing host a potential conflict.
        host_id = self.host.id if self.host else None
        existing_host = Host.query.filter(
            Host.name == lower_string, Host.id != host_id
        ).first()
        if existing_host:
            raise ValidationError("Interface name matches an existing host")
        return lower_string

    @validates("mac")
    def validate_mac(self, key, string):
        """Ensure the mac is a valid MAC address"""
        if not string:
            return None
        if MAC_ADDRESS_RE.fullmatch(string) is None:
            raise ValidationError(f"'{string}' does not appear to be a MAC address")
        return string

    @validates("cnames")
    def validate_cnames(self, key, cname):
        """Ensure the cname is unique by domain"""
        existing_cnames = Cname.query.filter_by(name=cname.name).all()
        for existing_cname in existing_cnames:
            if existing_cname.domain == str(self.network.domain):
                raise ValidationError(
                    f"Duplicate cname on the {self.network.domain} domain"
                )
        return cname

    @property
    def address(self):
        return ipaddress.ip_address(self.ip)

    @property
    def is_ioc(self):
        return self.is_main and self.host.is_ioc

    @property
    def is_main(self):
        return self.name == self.host.main_interface.name

    def __str__(self):
        return str(self.name)

    def __repr__(self):
        return f"Interface(id={self.id}, network_id={self.network_id}, ip={self.ip}, name={self.name}, mac={self.mac})"

    def to_dict(self, recursive=False):
        """Return the interface as a dictionary

        device_type and model come from the host. Both keys are always
        present and set to None when no host is attached.
        """
        d = super().to_dict()
        d.update(
            {
                "is_main": self.is_main,
                "network": str(self.network),
                "ip": self.ip,
                "netmask": str(self.network.netmask),
                "name": self.name,
                "description": self.description,
                "mac": utils.format_field(self.mac),
                "host": utils.format_field(self.host),
                "cnames": [str(cname) for cname in self.cnames],
                "domain": str(self.network.domain),
            }
        )
        if self.host:
            d["device_type"] = str(self.host.device_type)
            d["model"] = utils.format_field(self.host.model)
        else:
            d["device_type"] = None
            d["model"] = None
        return d


class Mac(db.Model):
    """MAC address, optionally attached to an item."""

    id = db.Column(db.Integer, primary_key=True)
    address = db.Column(postgresql.MACADDR, nullable=False, unique=True)
    item_id = db.Column(db.Integer, db.ForeignKey("item.id"))

    def __str__(self):
        return str(self.address)

    @validates("address")
    def validate_address(self, key, string):
        """Ensure the address is a valid MAC address

        Raises ValidationError if the string doesn't match MAC_ADDRESS_RE.
        """
        if string is None:
            return None
        if MAC_ADDRESS_RE.fullmatch(string) is None:
            raise ValidationError(f"'{string}' does not appear to be a MAC address")
        return string

    def to_dict(self, recursive=False):
        """Return the MAC address as a dictionary"""
        return {
            "id": self.id,
            "address": self.address,
            "item": utils.format_field(self.item),
        }


class Cname(CreatedMixin, db.Model):
    """Alternative name pointing to an interface.

    A cname has no domain of its own: it uses the domain of the network of
    its interface.
    """

    name = db.Column(db.Text, nullable=False)
    interface_id = db.Column(db.Integer, db.ForeignKey("interface.id"), nullable=False)

    def __init__(self, **kwargs):
        # Assign the Interface object rather than its id so that the
        # validate_cnames validator of Interface is triggered
        interface_pk = kwargs.pop("interface_id", None)
        if interface_pk is not None:
            kwargs["interface"] = Interface.query.get(interface_pk)
        super().__init__(**kwargs)

    def __str__(self):
        return str(self.name)

    @property
    def domain(self):
        """Name of the domain of the interface network, as a string"""
        network = self.interface.network
        return str(network.domain)

    @property
    def fqdn(self):
        """Cname joined to its domain with a dot"""
        return "{}.{}".format(self.name, self.domain)

    @validates("name")
    def validate_name(self, key, string):
        """Lowercase the cname and check that it can be used

        Raises ValidationError when the name doesn't match HOST_NAME_RE or is
        already taken by an interface or a host.
        """
        if string is None:
            return None
        candidate = string.lower()
        if not HOST_NAME_RE.fullmatch(candidate):
            raise ValidationError(f"cname shall match {HOST_NAME_RE.pattern}")
        if Interface.query.filter_by(name=candidate).first():
            raise ValidationError("cname matches an existing interface")
        if Host.query.filter_by(name=candidate).first():
            raise ValidationError("cname matches an existing host")
        return candidate

    def to_dict(self, recursive=False):
        data = super().to_dict()
        data["name"] = self.name
        data["interface"] = str(self.interface)
        return data


class Domain(CreatedMixin, db.Model):
    """Domain referenced by network scopes and networks."""

    name = db.Column(db.Text, nullable=False, unique=True)

    scopes = db.relationship(
        "NetworkScope", backref=db.backref("domain", lazy="joined"), lazy=True
    )
    networks = db.relationship(
        "Network", backref=db.backref("domain", lazy="joined"), lazy=True
    )

    def __str__(self):
        return str(self.name)

    def to_dict(self, recursive=False):
        """Return the domain as a dictionary (scopes and networks by name)"""
        d = super().to_dict()
        d.update(
            {
                "name": self.name,
                "scopes": [str(scope) for scope in self.scopes],
                "networks": [str(network) for network in self.networks],
            }
        )
        # Previously missing: the method built d but implicitly returned None
        return d


class NetworkScope(CreatedMixin, db.Model):
    """Network scope: a supernet, with an optional vlan range, in a domain

    The networks attached to a scope must be subnets of its supernet and
    must not overlap each other (enforced by the validators below).
    """

    __tablename__ = "network_scope"
    name = db.Column(CIText, nullable=False, unique=True)
    # Optional vlan range [first_vlan, last_vlan] usable by the scope networks
    first_vlan = db.Column(db.Integer, nullable=True, unique=True)
    last_vlan = db.Column(db.Integer, nullable=True, unique=True)
    supernet = db.Column(postgresql.CIDR, nullable=False, unique=True)
    domain_id = db.Column(db.Integer, db.ForeignKey("domain.id"), nullable=False)
    description = db.Column(db.Text)
    networks = db.relationship(
        "Network", backref=db.backref("scope", lazy="joined"), lazy=True
    )
    __table_args__ = (
        sa.CheckConstraint(
            "first_vlan < last_vlan", name="first_vlan_less_than_last_vlan"
        ),
    )

    def __str__(self):
        return str(self.name)

    @validates("supernet")
    def validate_supernet(self, key, supernet):
        """Ensure the supernet doesn't overlap existing supernets

        Also ensure it's a supernet of all existing networks (when editing)

        :returns: the supernet unchanged
        :raises ValidationError: if it overlaps another scope supernet or
            doesn't contain one of the scope networks
        """
        supernet_address = ipaddress.ip_network(supernet)
        # Exclude the scope being edited. For a new scope self.id is None and
        # SQLAlchemy renders "id IS NOT NULL", so every scope is checked.
        existing_scopes = NetworkScope.query.filter(NetworkScope.id != self.id).all()
        for existing_scope in existing_scopes:
            if supernet_address.overlaps(existing_scope.supernet_ip):
                raise ValidationError(
                    f"{supernet} overlaps {existing_scope} ({existing_scope.supernet_ip})"
                )
        for network in self.networks:
            if not network.network_ip.subnet_of(supernet_address):
                raise ValidationError(
                    f"{network.network_ip} is not a subnet of {supernet}"
                )
        return supernet

    @validates("networks")
    def validate_networks(self, key, network):
        """Ensure the network is included in the supernet and doesn't overlap
        existing networks

        :returns: the network unchanged
        :raises ValidationError: if the network is outside the supernet or
            overlaps another network of the scope
        """
        if not network.network_ip.subnet_of(self.supernet_ip):
            raise ValidationError(
                f"{network.network_ip} is not a subnet of {self.supernet_ip}"
            )
        existing_networks = Network.query.filter_by(scope=self).all()
        for existing_network in existing_networks:
            if existing_network.id == network.id:
                # Same network added again during edit via admin interface
                continue
            if network.network_ip.overlaps(existing_network.network_ip):
                raise ValidationError(
                    f"{network.network_ip} overlaps {existing_network} ({existing_network.network_ip})"
                )
        return network

    @validates("first_vlan")
    def validate_first_vlan(self, key, value):
        """Ensure the first vlan is lower than any network vlan id

        Networks without a vlan id are ignored.
        """
        if value is None:
            return value
        for network in self.networks:
            # Comparing an int with None would raise a TypeError
            if network.vlan_id is None:
                continue
            if int(value) > network.vlan_id:
                raise ValidationError(
                    f"First vlan shall be lower than {network.vlan_name} vlan: {network.vlan_id}"
                )
        return value

    @validates("last_vlan")
    def validate_last_vlan(self, key, value):
        """Ensure the last vlan is greater than any network vlan id

        Networks without a vlan id are ignored.
        """
        if value is None:
            return value
        for network in self.networks:
            # Comparing an int with None would raise a TypeError
            if network.vlan_id is None:
                continue
            if int(value) < network.vlan_id:
                raise ValidationError(
                    f"Last vlan shall be greater than {network.vlan_name} vlan: {network.vlan_id}"
                )
        return value

    @property
    def supernet_ip(self):
        """Return the supernet as an ipaddress network object"""
        return ipaddress.ip_network(self.supernet)

    def prefix_range(self):
        """Return the list of subnet prefix that can be used for this network scope

        Prefixes go from the supernet prefix + 1 up to /30 (the /31 and /32
        prefixes are excluded).
        """
        return list(range(self.supernet_ip.prefixlen + 1, 31))

    def vlan_range(self):
        """Return the list of vlan ids that can be assigned for this network scope

        The range is defined by the first and last vlan (both included).
        An empty list is returned if either bound is not set.
        """
        if self.first_vlan is None or self.last_vlan is None:
            return []
        return range(self.first_vlan, self.last_vlan + 1)

    def used_vlans(self):
        """Return the list of vlan ids in use

        The list is sorted. Networks without a vlan id are ignored.
        """
        if self.first_vlan is None or self.last_vlan is None:
            return []
        # Skipping None keeps sorted() from failing on mixed int/None values
        return sorted(
            network.vlan_id
            for network in self.networks
            if network.vlan_id is not None
        )

    def available_vlans(self):
        """Return the list of vlan ids available"""
        # Compute the used vlans once (not once per candidate vlan)
        used = set(self.used_vlans())
        return [vlan for vlan in self.vlan_range() if vlan not in used]

    def used_subnets(self):
        """Return the list of subnets in use

        The list is sorted
        """
        return sorted(network.network_ip for network in self.networks)

    def available_subnets(self, prefix):
        """Return the list of available subnets with the given prefix

        Overlapping subnets with existing networks are filtered

        :param prefix: prefix length of the subnets to generate
        :returns: list of subnets (as strings) not overlapping any network
        """
        used = self.used_subnets()
        return [
            str(subnet)
            for subnet in self.supernet_ip.subnets(new_prefix=prefix)
            if not utils.overlaps(subnet, used)
        ]

    def to_dict(self, recursive=False):
        """Return a dict representation of the network scope

        :param recursive: not used by this method
        """
        d = super().to_dict()
        d.update(
            {
                "name": self.name,
                "first_vlan": self.first_vlan,
                "last_vlan": self.last_vlan,
                "supernet": self.supernet,
                "description": self.description,
                "domain": str(self.domain),
                "networks": [str(network) for network in self.networks],
            }
        )
        return d


# Define RQ JobStatus as a Python enum.
# We can't use the one defined in rq/job.py: it's not a real enum
# (it's a custom class) and it's not compatible with sqlalchemy.
class JobStatus(Enum):
    """Status of an RQ job, as stored in the database

    The member values are the lowercase RQ status strings.
    """

    QUEUED = "queued"
    FINISHED = "finished"
    FAILED = "failed"
    STARTED = "started"
    DEFERRED = "deferred"


class Task(db.Model):
    # Use job id generated by RQ
    id = db.Column(postgresql.UUID, primary_key=True)
    # NOTE(review): utcnow() is called once at class definition; this is only
    # correct if it returns a SQL expression evaluated at insert time — verify.
    created_at = db.Column(db.DateTime, default=utcnow())
    ended_at = db.Column(db.DateTime)
    name = db.Column(db.Text, nullable=False, index=True)
    command = db.Column(db.Text)
    status = db.Column(db.Enum(JobStatus, name="job_status"))
    # AWX resource type; awx_job_url knows "job", "workflow_job" and
    # "inventory_source"
    awx_resource = db.Column(db.Text)
    awx_job_id = db.Column(db.Integer)
    exception = db.Column(db.Text)
    user_id = db.Column(
        db.Integer,
        db.ForeignKey("user_account.id"),
        nullable=False,
        default=utils.fetch_current_user_id,
    )
    # Task this one depends on (self-referencing foreign key)
    depends_on_id = db.Column(postgresql.UUID, db.ForeignKey("task.id"))

    # Tasks whose depends_on is this task; remote_side=[id] refers to the
    # "id" column declared above and makes "depends_on" the many-to-one side
    reverse_dependencies = db.relationship(
        "Task", backref=db.backref("depends_on", remote_side=[id])
    )

    @property
    def awx_job_url(self):
        """Return the AWX web UI URL of the job linked to this task

        The UI route depends on the AWX resource type (``awx_resource``).

        :returns: the URL built from the ``AWX_URL`` config value, or None if
            no AWX job is linked or the resource type is unknown
        """
        if self.awx_job_id is None:
            return None
        # AWX resource type -> AWX web UI route
        routes = {
            "job": "jobs/playbook",
            "workflow_job": "workflows",
            "inventory_source": "jobs/inventory",
        }
        route = routes.get(self.awx_resource)
        if route is None:
            current_app.logger.warning(f"Unknown AWX resource: {self.awx_resource}")
            return None
        # The relative part starts with "/", so urljoin keeps only the scheme
        # and host of AWX_URL (any path in AWX_URL is replaced)
        return urllib.parse.urljoin(
            current_app.config["AWX_URL"], f"/#/{route}/{self.awx_job_id}"
        )

    def update_reverse_dependencies(self):
        """Set all reverse dependencies (direct and transitive) to FAILED

        When a RQ job is set to FAILED, its reverse dependencies will stay to DEFERRED.
        This method allows to easily update the corresponding tasks status.

        The tasks are modified but the session is not committed.
        """
        # Depth-first, pre-order walk using an explicit stack. Tasks are
        # visited in the same order as a recursive traversal, long dependency
        # chains can't hit the recursion limit, and `seen` stops the walk on
        # an unexpected dependency cycle.
        seen = {self.id}
        stack = list(reversed(self.reverse_dependencies))
        while stack:
            dependency = stack.pop()
            if dependency.id in seen:
                continue
            seen.add(dependency.id)
            current_app.logger.info(
                f"Setting {dependency.id} ({dependency.name}) to FAILED due to failed dependency"
            )
            dependency.status = JobStatus.FAILED
            # Reversed so the first child is popped (processed) first
            stack.extend(reversed(dependency.reverse_dependencies))

    def __str__(self):
        """Represent the task by its (RQ job) id"""
        return f"{self.id}"

    def to_dict(self, recursive=False):
Benjamin Bertrand's avatar
Benjamin Bertrand committed
        return {
Benjamin Bertrand's avatar
Benjamin Bertrand committed
            "id": self.id,
            "name": self.name,