# -*- coding: utf-8 -*-
"""
app.models
~~~~~~~~~~

This module implements the models used in the app.

:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.

"""
import ipaddress
import string
import qrcode
import itertools
import urllib.parse
import elasticsearch
import sqlalchemy as sa
from enum import Enum
from operator import attrgetter
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import validates
from sqlalchemy_continuum import make_versioned, version_class
from citext import CIText
from flask import current_app
from flask_login import UserMixin, current_user
from wtforms import ValidationError
from rq import Queue
from .extensions import db, login_manager, ldap_manager, cache
from .plugins import FlaskUserPlugin
from .validators import (
    ICS_ID_RE,
    HOST_NAME_RE,
    INTERFACE_NAME_RE,
    VLAN_NAME_RE,
    MAC_ADDRESS_RE,
    DEVICE_TYPE_RE,
)
from . import utils, search


# Enable SQLAlchemy-Continuum versioning with the project's FlaskUserPlugin.
# This call sits before the model declarations below; per the
# SQLAlchemy-Continuum docs, make_versioned() has to run before versioned
# models (classes declaring __versioned__, such as Item) are defined.
make_versioned(plugins=[FlaskUserPlugin()])


# See http://docs.sqlalchemy.org/en/latest/core/compiler.html#utc-timestamp-function
class utcnow(sa.sql.expression.FunctionElement):
    """SQL expression evaluating to the current timestamp in UTC

    Used as a server-evaluated column default (e.g. created_at / updated_at).
    Only a PostgreSQL compilation is registered (see pg_utcnow).
    """

    type = sa.types.DateTime()
    # This construct carries no state that changes the rendered SQL, so it can
    # safely take part in SQLAlchemy's statement caching (1.4+). Without this
    # flag SQLAlchemy warns and disables caching for statements using it.
    # Older versions simply ignore the attribute.
    inherit_cache = True


@sa.ext.compiler.compiles(utcnow, "postgresql")
def pg_utcnow(element, compiler, **kw):
    """Render utcnow() for PostgreSQL as the current timestamp in UTC"""
    return "TIMEZONE('utc', CURRENT_TIMESTAMP)"


def temporary_ics_ids():
    """Lazily yield every temporary ICS id, in allocation order

    Ids are the configured TEMPORARY_ICS_ID prefix followed by an uppercase
    letter (A-Z) and a zero-padded three-digit number (000-999).
    """
    letter_number_pairs = itertools.product(string.ascii_uppercase, range(1000))
    return (
        f'{current_app.config["TEMPORARY_ICS_ID"]}{letter}{number:0=3d}'
        for letter, number in letter_number_pairs
    )


def used_temporary_ics_ids():
    """Return a set with the temporary ICS ids used

    Only the ics_id column is selected. Loading full Item entities would also
    pull in every relationship Item eagerly joins, which is unneeded here.
    """
    rows = Item.query.with_entities(Item.ics_id).filter(
        Item.ics_id.startswith(current_app.config["TEMPORARY_ICS_ID"])
    )
    # Each row is a one-element tuple (ics_id,)
    return {ics_id for (ics_id,) in rows}


def get_temporary_ics_id():
    """Return the first temporary ICS id not already assigned to an item

    :raises ValueError: when every temporary ICS id is already in use
    """
    taken = used_temporary_ics_ids()
    free_ids = (candidate for candidate in temporary_ics_ids() if candidate not in taken)
    first_free = next(free_ids, None)
    if first_free is None:
        raise ValueError("No temporary ICS id available")
    return first_free


@login_manager.user_loader
def load_user(user_id):
    """User loader callback for flask-login

    :param str user_id: unicode ID of a user
    :returns: corresponding user object or None; None is also returned when
        user_id is not a valid integer, because flask-login expects the loader
        to return None rather than raise for an invalid ID
    """
    try:
        user_pk = int(user_id)
    except (TypeError, ValueError):
        return None
    return User.query.get(user_pk)


@ldap_manager.save_user
def save_user(dn, username, data, memberships):
    """User saver for flask-ldap3-login

    Called each time a LDAPLoginForm() validates successfully. The user is
    created on first login (display name from "cn", falling back to the
    username, and email from "mail"). Its LDAP groups are refreshed on
    every login. The session is committed and the user returned.
    """
    user = User.query.filter_by(username=username).first()
    if user is None:
        display_name = utils.attribute_to_string(data["cn"]) or username
        email = utils.attribute_to_string(data["mail"])
        user = User(username=username, display_name=display_name, email=email)
    # Refresh the groups on every login so they stay in sync with LDAP
    group_names = (
        utils.attribute_to_string(membership["cn"]) for membership in memberships
    )
    user.groups = sorted(group_names)
    db.session.add(user)
    db.session.commit()
    return user


# Tables required for Many-to-Many relationships between users and favorites attributes
def _favorite_table(name, target_column, target_fk):
    """Build an association table linking user_account rows to favorite rows

    :param str name: name of the association table
    :param str target_column: name of the column referencing the favorite row
    :param str target_fk: foreign key target of that column ("<table>.id")
    :returns: the db.Table; (user_id, target_column) is its composite primary key
    """
    return db.Table(
        name,
        db.Column(
            "user_id", db.Integer, db.ForeignKey("user_account.id"), primary_key=True
        ),
        db.Column(target_column, db.Integer, db.ForeignKey(target_fk), primary_key=True),
    )


favorite_manufacturers_table = _favorite_table(
    "favorite_manufacturers", "manufacturer_id", "manufacturer.id"
)
favorite_models_table = _favorite_table("favorite_models", "model_id", "model.id")
favorite_locations_table = _favorite_table(
    "favorite_locations", "location_id", "location.id"
)
favorite_statuses_table = _favorite_table("favorite_statuses", "status_id", "status.id")
favorite_actions_table = _favorite_table("favorite_actions", "action_id", "action.id")


class User(db.Model, UserMixin):
    """Application user, created/updated from LDAP on login

    Permissions are derived from the user's LDAP groups through the
    CSENTRY_LDAP_GROUPS and CSENTRY_NETWORK_SCOPES_LDAP_GROUPS config mappings.
    """

    # "user" is a reserved word in postgresql
    # so let's use another name
    __tablename__ = "user_account"

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.Text, nullable=False, unique=True)
    display_name = db.Column(db.Text, nullable=False)
    email = db.Column(db.Text)
    groups = db.Column(postgresql.ARRAY(db.Text), default=[])
    tokens = db.relationship("Token", backref="user")
    tasks = db.relationship("Task", backref="user")
    # The favorites won't be accessed very often so we load them
    # only when necessary (lazy=True)
    favorite_manufacturers = db.relationship(
        "Manufacturer",
        secondary=favorite_manufacturers_table,
        lazy=True,
        backref=db.backref("favorite_users", lazy=True),
    )
    favorite_models = db.relationship(
        "Model",
        secondary=favorite_models_table,
        lazy=True,
        backref=db.backref("favorite_users", lazy=True),
    )
    favorite_locations = db.relationship(
        "Location",
        secondary=favorite_locations_table,
        lazy=True,
        backref=db.backref("favorite_users", lazy=True),
    )
    favorite_statuses = db.relationship(
        "Status",
        secondary=favorite_statuses_table,
        lazy=True,
        backref=db.backref("favorite_users", lazy=True),
    )
    favorite_actions = db.relationship(
        "Action",
        secondary=favorite_actions_table,
        lazy=True,
        backref=db.backref("favorite_users", lazy=True),
    )

    def get_id(self):
        """Return the user id as unicode

        Required by flask-login
        """
        return str(self.id)

    def _ldap_groups(self):
        """Return the user's LDAP groups as a set (empty when groups is None)"""
        # groups is nullable, and its column default is only applied at flush
        # time, so it can be None (e.g. on a User built without groups).
        return set(self.groups or [])

    @property
    def csentry_groups(self):
        """Return the list of CSEntry groups the user belong to

        Groups are assigned based on the CSENTRY_LDAP_GROUPS mapping with LDAP groups.
        The "network" group is added when the user belongs to any LDAP group
        listed in CSENTRY_NETWORK_SCOPES_LDAP_GROUPS.
        The result is computed once and cached on the instance.
        """
        if not hasattr(self, "_csentry_groups"):
            user_ldap_groups = self._ldap_groups()
            # Build in a local list so a failure can't leave a partial cache
            csentry_groups = []
            for csentry_group, ldap_groups in current_app.config[
                "CSENTRY_LDAP_GROUPS"
            ].items():
                if user_ldap_groups & set(ldap_groups):
                    csentry_groups.append(csentry_group)
            # Add the network group based on CSENTRY_NETWORK_SCOPES_LDAP_GROUPS
            network_ldap_groups = set(
                itertools.chain(
                    *current_app.config["CSENTRY_NETWORK_SCOPES_LDAP_GROUPS"].values()
                )
            )
            if user_ldap_groups & network_ldap_groups:
                csentry_groups.append("network")
            self._csentry_groups = csentry_groups
        return self._csentry_groups

    @property
    def csentry_network_scopes(self):
        """Return the list of CSEntry network scopes the user has access to

        Network scopes are assigned based on the CSENTRY_NETWORK_SCOPES_LDAP_GROUPS mapping with LDAP groups.
        The result is computed once and cached on the instance.
        """
        if not hasattr(self, "_csentry_network_scopes"):
            user_ldap_groups = self._ldap_groups()
            self._csentry_network_scopes = [
                network_scope
                for network_scope, ldap_groups in current_app.config[
                    "CSENTRY_NETWORK_SCOPES_LDAP_GROUPS"
                ].items()
                if user_ldap_groups & set(ldap_groups)
            ]
        return self._csentry_network_scopes

    @property
    def is_admin(self):
        return "admin" in self.csentry_groups

    @property
    def is_auditor(self):
        return "auditor" in self.csentry_groups

    def is_member_of_one_group(self, groups):
        """Return True if the user is at least member of one of the given CSEntry groups"""
        return bool(set(groups) & set(self.csentry_groups))

    def has_access_to_network(self, network):
        """Return True if the user has access to the network

        - admin users have access to all networks
        - normal users must have access to the network scope
        - normal users don't have access to admin_only networks (whatever the network scope)
        - LOGIN_DISABLED can be set to True to turn off authentication check when testing.
          In this case, this function always returns True.
        - a None network always returns True
        """
        if current_app.config.get("LOGIN_DISABLED") or self.is_admin or network is None:
            return True
        if network.admin_only:
            # True is already returned for admin users
            return False
        return str(network.scope) in self.csentry_network_scopes

    def can_view_network(self, network):
        """Return True if the user can view the network

        - admin and auditor users can view all networks
        - non sensitive networks can be viewed by anyone
        - normal users must have access to the network scope to view sensitive networks
        - LOGIN_DISABLED can be set to True to turn off authentication check when testing.
          In this case, this function always returns True.
        """
        if (
            current_app.config.get("LOGIN_DISABLED")
            or self.is_admin
            or self.is_auditor
            or not network.sensitive
        ):
            return True
        return str(network.scope) in self.csentry_network_scopes

    @property
    def sensitive_filter(self):
        """Return the elasticsearch query to use to filter sensitive hosts

        Non sensitive documents always match; sensitive ones only match when
        their scope is one of the user's network scopes.
        """
        query_filter = "sensitive:false"
        if self.csentry_network_scopes:
            scopes_filter = " OR ".join(
                [f"scope:{scope}" for scope in self.csentry_network_scopes]
            )
            query_filter = f"{query_filter} OR (sensitive:true AND ({scopes_filter}))"
        return query_filter

    def can_view_host(self, host):
        """Return True if the user can view the host

        - admin and auditor users can view all hosts
        - non sensitive hosts can be viewed by anyone
        - normal users must have access to the network scope to view sensitive hosts
        - LOGIN_DISABLED can be set to True to turn off authentication check when testing.
          In this case, this function always returns True.
        """
        if (
            current_app.config.get("LOGIN_DISABLED")
            or self.is_admin
            or self.is_auditor
            or not host.sensitive
        ):
            return True
        return str(host.scope) in self.csentry_network_scopes

    def can_create_vm(self, host):
        """Return True if the user can create the VM

        - host.device_type shall be VirtualMachine
        - admin users can create anything
        - normal users must have access to the network to create VIOC
        - normal users can only create a VM if the host is in one of the allowed network scopes
        - LOGIN_DISABLED can be set to True to turn off authentication check when testing.
          In this case, this function always returns True.
        """
        if str(host.device_type) != "VirtualMachine":
            return False
        if current_app.config.get("LOGIN_DISABLED") or self.is_admin:
            return True
        if not self.has_access_to_network(host.main_network):
            # True is already returned for admin users
            return False
        if host.is_ioc:
            # VIOC can be created by anyone having access to the network
            return True
        # VM can only be created if the network scope is allowed
        return (
            str(host.scope) in current_app.config["ALLOWED_VM_CREATION_NETWORK_SCOPES"]
        )

    def can_set_boot_profile(self, host):
        """Return True if the user can set the network boot profile

        - host.device_type shall be in ALLOWED_SET_BOOT_PROFILE_DEVICE_TYPES
        - admin users can always set the profile
        - normal users must have access to the network
        - normal users can only set the boot profile if the host is in one of the allowed network scopes
        - LOGIN_DISABLED can be set to True to turn off authentication check when testing.
          In this case, this function always returns True.
        """
        if (
            str(host.device_type)
            not in current_app.config["ALLOWED_SET_BOOT_PROFILE_DEVICE_TYPES"]
        ):
            return False
        if current_app.config.get("LOGIN_DISABLED") or self.is_admin:
            return True
        if not self.has_access_to_network(host.main_network):
            # True is already returned for admin users
            return False
        # Boot profile can only be set if the network scope is allowed
        return (
            str(host.scope)
            in current_app.config["ALLOWED_SET_BOOT_PROFILE_NETWORK_SCOPES"]
        )

    def can_delete_host(self, host):
        """Return True if the user can delete the host

        - admin users can delete any host
        - normal users must be creator of the host
        - LOGIN_DISABLED can be set to True to turn off authentication check when testing.
          In this case, this function always returns True.
        """
        if current_app.config.get("LOGIN_DISABLED") or self.is_admin:
            return True
        return self.id == host.user.id

    def favorite_attributes(self):
        """Return all user's favorite attributes"""
        favorites_list = [
            self.favorite_manufacturers,
            self.favorite_models,
            self.favorite_locations,
            self.favorite_statuses,
            self.favorite_actions,
        ]
        return [favorite for favorites in favorites_list for favorite in favorites]

    def launch_task(self, name, func, queue_name="normal", **kwargs):
        """Launch a task in the background using RQ

        The job runs app.tasks.<func> with kwargs. The task is added to the
        session but not committed.
        """
        q = Queue(queue_name, default_timeout=current_app.config["RQ_DEFAULT_TIMEOUT"])
        job = q.enqueue(f"app.tasks.{func}", **kwargs)
        # The status will be set to QUEUED or DEFERRED
        task = Task(
            id=job.id,
            name=name,
            awx_resource=kwargs.get("resource", None),
            command=job.get_call_string(),
            status=JobStatus(job.get_status()),
            user=self,
            depends_on_id=kwargs.get("depends_on", None),
        )
        db.session.add(task)
        return task

    def get_tasks(self, all=False):
        """Return all tasks created by the current user

        If the user is admin or auditor and all is set to True, will return all tasks
        """
        if all and (self.is_admin or self.is_auditor):
            return Task.query.order_by(Task.created_at).all()
        return Task.query.filter_by(user=self).order_by(Task.created_at).all()

    def get_tasks_in_progress(self, name):
        """Return all the <name> tasks not finished or failed (from any user)"""
        return (
            Task.query.filter_by(name=name)
            .filter(~Task.status.in_([JobStatus.FINISHED, JobStatus.FAILED]))
            .all()
        )

    def get_task_started(self, name):
        """Return the <name> task currently running or None"""
        return Task.query.filter_by(name=name, status=JobStatus.STARTED).first()

    def is_task_waiting(self, name):
        """Return True if a <name> task is waiting

        Waiting means:
            - queued
            - deferred

        A deferred task will be set to failed if the task it depends on fails.
        """
        count = (
            Task.query.filter_by(name=name)
            .filter(Task.status.in_([JobStatus.DEFERRED, JobStatus.QUEUED]))
            .count()
        )
        return count > 0

    def get_task_waiting(self, name):
        """Return the latest <name> task currently waiting or None

        Waiting means:
            - queued
            - deferred

        A deferred task will be set to failed if the task it depends on fails.
        """
        return (
            Task.query.filter_by(name=name)
            .filter(Task.status.in_([JobStatus.DEFERRED, JobStatus.QUEUED]))
            .order_by(Task.created_at.desc())
            .first()
        )

    def __str__(self):
        return self.username

    def to_dict(self, recursive=False):
        """Return the user as a dict (recursive is accepted but unused)"""
        return {
            "id": self.id,
            "username": self.username,
            "display_name": self.display_name,
            "email": self.email,
            "groups": self.csentry_groups,
        }


class SearchableMixin(object):
    """Add search capability to a class

    Classes using this mixin are expected to provide ``__tablename__``, an
    ``id`` column and a ``to_dict(recursive=...)`` method. ``__mapping__`` is
    optional (see create_index). The index name is the table name followed by
    the ELASTICSEARCH_INDEX_SUFFIX config value.

    before_flush / after_flush_postexec / after_commit keep the index in sync
    with the database. They are presumably registered as SQLAlchemy session
    event listeners elsewhere (not visible here).
    """

    @classmethod
    def search(cls, query, page=1, per_page=20, sort=None, filter_sensitive=False):
        """Search the class index

        :param str query: elasticsearch query string ("*" matches everything)
        :param int page: page number forwarded to search.query_index
        :param int per_page: number of results per page
        :param sort: sort specification forwarded to search.query_index
        :param bool filter_sensitive: restrict results to what current_user can
            see (admin and auditor users are never restricted)
        :returns: tuple (query returning the matching instances in the order
            given by elasticsearch, total number of hits). On an invalid query
            or no hits, an empty query and 0 are returned.
        """
        if filter_sensitive and not (current_user.is_admin or current_user.is_auditor):
            if query == "*":
                query = current_user.sensitive_filter
            else:
                query = f"({query}) AND ({current_user.sensitive_filter})"
        try:
            ids, total = search.query_index(
                cls.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"],
                query,
                page,
                per_page,
                sort,
            )
        except elasticsearch.ElasticsearchException as e:
            # Invalid query
            current_app.logger.warning(e)
            return cls.query.filter_by(id=0), 0
        if total == 0:
            return cls.query.filter_by(id=0), 0
        if not ids:
            # There are hits, but none on the requested page (page past the
            # end). A CASE without any WHEN clause is invalid SQL, so return an
            # empty query while still reporting the total number of hits.
            return cls.query.filter_by(id=0), total
        # Keep the elasticsearch ranking: order rows by their position in ids
        when = [(value, i) for i, value in enumerate(ids)]
        return (
            cls.query.filter(cls.id.in_(ids)).order_by(db.case(when, value=cls.id)),
            total,
        )

    @classmethod
    def before_flush(cls, session, flush_context, instances):
        """Save the new/modified/deleted objects"""
        # The session.new / dirty / deleted lists are empty in the after_flush_postexec event.
        # We need to record them here
        # NOTE(review): _changes is reset on every flush, so if a transaction
        # flushes pending changes more than once before commit, only the last
        # flush's objects reach the index. Confirm this is acceptable.
        session._changes = {"add_obj": [], "delete": []}
        for obj in itertools.chain(session.new, session.dirty):
            if isinstance(obj, SearchableMixin):
                index = (
                    obj.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"]
                )
                current_app.logger.debug(
                    f"object to add/update in the {index} index: {obj}"
                )
                session._changes["add_obj"].append((index, obj))
        for obj in session.deleted:
            if isinstance(obj, SearchableMixin):
                index = (
                    obj.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"]
                )
                current_app.logger.debug(
                    f"object to remove from the {index} index: {obj}"
                )
                session._changes["delete"].append((index, obj.id))

    @classmethod
    def after_flush_postexec(cls, session, flush_context):
        """Retrieve the new and updated objects representation"""
        if not hasattr(session, "_changes") or session._changes is None:
            return
        # - We can't call obj.to_dict() in the before_flush event because the id
        #   hasn't been allocated yet (for new objects) and other fields haven't been updated
        #   (default values like created_at/updated_at and some relationships).
        # - We can't call obj.to_dict() in the after_commit event because it would raise:
        #   sqlalchemy.exc.InvalidRequestError:
        #   This session is in 'committed' state; no further SQL can be emitted within this transaction.
        session._changes["add"] = [
            (index, obj.to_dict(recursive=True))
            for index, obj in session._changes["add_obj"]
        ]

    @classmethod
    def after_commit(cls, session):
        """Update the elasticsearch index with the recorded changes"""
        if not hasattr(session, "_changes") or session._changes is None:
            return
        for index, body in session._changes["add"]:
            search.add_to_index(index, body)
        for index, obj_id in session._changes["delete"]:
            search.remove_from_index(index, obj_id)
        session._changes = None

    @classmethod
    def delete_index(cls, **kwargs):
        """Delete the index of the class (kwargs forwarded to search.delete_index)"""
        current_app.logger.info(f"Delete the {cls.__tablename__} index")
        search.delete_index(
            cls.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"],
            **kwargs,
        )

    @classmethod
    def create_index(cls, **kwargs):
        """Create the index of the class

        Only done when the class defines __mapping__; kwargs are forwarded to
        search.create_index.
        """
        if hasattr(cls, "__mapping__"):
            current_app.logger.info(f"Create the {cls.__tablename__} index")
            search.create_index(
                cls.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"],
                cls.__mapping__,
                **kwargs,
            )
        else:
            current_app.logger.info(
                f"No mapping defined for {cls.__tablename__}. No index created."
            )

    @classmethod
    def reindex(cls, delete=True):
        """Force to reindex all instances of the class

        :param bool delete: drop and recreate the index first
        """
        current_app.logger.info(f"Force to re-index all {cls.__tablename__} instances")
        if delete:
            # Ignore index_not_found_exception
            cls.delete_index(ignore_unavailable=True)
            cls.create_index()
        for obj in cls.query:
            search.add_to_index(
                cls.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"],
                obj.to_dict(recursive=True),
            )


class Token(db.Model):
    """Table to store valid tokens

    A (jti, user_id) pair is unique (see __table_args__).
    """

    id = db.Column(db.Integer, primary_key=True)
    # Token identifier stored as a UUID; presumably the JWT "jti" claim — confirm
    jti = db.Column(postgresql.UUID, nullable=False)
    token_type = db.Column(db.Text, nullable=False)
    # Owner of the token (User.tokens relationship uses backref "user")
    user_id = db.Column(db.Integer, db.ForeignKey("user_account.id"), nullable=False)
    issued_at = db.Column(db.DateTime, nullable=False)
    # expires can be set to None for tokens that never expire
    expires = db.Column(db.DateTime)
    description = db.Column(db.Text)

    __table_args__ = (sa.UniqueConstraint(jti, user_id),)

    def __str__(self):
        return self.jti


class QRCodeMixin:
    """Mixin for named records that can be identified by a QR code

    Provides id / name / description columns and helpers to render the QR
    code. Classes using it are expected to expose a ``favorite_users``
    relationship (read by is_user_favorite).
    """

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)
    description = db.Column(db.Text)

    def image(self):
        """Build the QR code image identifying this record

        The encoded text is "CSE:<table name>:<record name>".
        """
        parts = ("CSE", self.__tablename__, self.name)
        return qrcode.make(":".join(parts), version=1, box_size=5)

    @cache.memoize(timeout=0)
    def base64_image(self):
        """Return image() encoded as a base64 string

        Memoized through the app cache (timeout=0); the cache key relies on
        __repr__ below.
        """
        qr_image = self.image()
        return utils.image_to_base64(qr_image)

    def is_user_favorite(self):
        """Tell whether the record is among the logged-in user's favorites"""
        return current_user in self.favorite_users

    def __str__(self):
        return self.name

    def __repr__(self):
        # flask-caching's memoize uses repr() of the arguments (self included)
        # to build the cache key, so this string must uniquely identify the record.
        # See https://flask-caching.readthedocs.io/en/latest/#memoization
        class_name = self.__class__.__name__
        return f"{class_name}(id={self.id}, name={self.name})"

    def to_dict(self, recursive=False):
        """Return the record as a dict, QR code included (recursive is unused)"""
        return dict(
            id=self.id,
            name=self.name,
            description=self.description,
            qrcode=self.base64_image(),
        )


class Action(QRCodeMixin, db.Model):
    """Action attribute; all columns and helpers come from QRCodeMixin"""

    pass


class Manufacturer(QRCodeMixin, db.Model):
    """Manufacturer attribute that items can reference"""

    # Items referencing this manufacturer (other side of Item.manufacturer)
    items = db.relationship("Item", back_populates="manufacturer")


class Model(QRCodeMixin, db.Model):
    """Model attribute that items can reference"""

    # Items referencing this model (other side of Item.model)
    items = db.relationship("Item", back_populates="model")


class Location(QRCodeMixin, db.Model):
    """Location attribute that items can reference"""

    # Items referencing this location (other side of Item.location)
    items = db.relationship("Item", back_populates="location")


class Status(QRCodeMixin, db.Model):
    """Status attribute that items can reference"""

    # Items referencing this status (other side of Item.status)
    items = db.relationship("Item", back_populates="status")


class CreatedMixin:
    """Common columns for records tracking creation time, update time and creator

    created_at / updated_at default to the UTC time computed by the database;
    user_id defaults to utils.fetch_current_user_id.
    """

    id = db.Column(db.Integer, primary_key=True)
    created_at = db.Column(db.DateTime, default=utcnow())
    updated_at = db.Column(db.DateTime, default=utcnow(), onupdate=utcnow())

    # Using ForeignKey and relationship in mixin requires the @declared_attr decorator
    # See http://docs.sqlalchemy.org/en/latest/orm/extensions/declarative/mixins.html
    @declared_attr
    def user_id(cls):
        return db.Column(
            db.Integer,
            db.ForeignKey("user_account.id"),
            nullable=False,
            default=utils.fetch_current_user_id,
        )

    @declared_attr
    def user(cls):
        return db.relationship("User")

    def __init__(self, **kwargs):
        # Automatically convert created_at/updated_at strings
        # to datetime object
        for key in ("created_at", "updated_at"):
            value = kwargs.get(key)
            if isinstance(value, str):
                kwargs[key] = utils.parse_to_utc(value)
        super().__init__(**kwargs)

    def to_dict(self, recursive=False):
        """Return the common fields as a dict

        :param bool recursive: accepted for consistency with the other
            to_dict() methods (SearchableMixin calls to_dict(recursive=True));
            not used here
        """
        return {
            "id": self.id,
            "created_at": utils.format_field(self.created_at),
            "updated_at": utils.format_field(self.updated_at),
            "user": str(self.user),
        }


class Item(CreatedMixin, SearchableMixin, db.Model):
    """Inventory item identified by its ICS id

    Indexed in elasticsearch (SearchableMixin) and versioned with
    SQLAlchemy-Continuum (__versioned__).
    """

    # Columns excluded from SQLAlchemy-Continuum versioning
    __versioned__ = {
        "exclude": [
            "created_at",
            "user_id",
            "ics_id",
            "serial_number",
            "manufacturer_id",
            "model_id",
        ]
    }
    # elasticsearch mapping used by SearchableMixin.create_index
    __mapping__ = {
        "created_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"},
        "updated_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"},
        "user": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "ics_id": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "serial_number": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "quantity": {"type": "long"},
        "manufacturer": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "model": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "location": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "status": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "parent": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "children": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "macs": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "host": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "stack_member": {"type": "byte"},
        "history": {"enabled": False},
        "comments": {"type": "text"},
    }

    # WARNING! Inheriting id from CreatedMixin doesn't play well with
    # SQLAlchemy-Continuum. It has to be defined here.
    id = db.Column(db.Integer, primary_key=True)
    # A temporary ICS id is allocated when none is given
    ics_id = db.Column(
        db.Text, unique=True, nullable=False, index=True, default=get_temporary_ics_id
    )
    serial_number = db.Column(db.Text, nullable=False)
    quantity = db.Column(db.Integer, nullable=False, default=1)
    manufacturer_id = db.Column(db.Integer, db.ForeignKey("manufacturer.id"))
    model_id = db.Column(db.Integer, db.ForeignKey("model.id"))
    location_id = db.Column(db.Integer, db.ForeignKey("location.id"))
    status_id = db.Column(db.Integer, db.ForeignKey("status.id"))
    parent_id = db.Column(db.Integer, db.ForeignKey("item.id"))
    host_id = db.Column(db.Integer, db.ForeignKey("host.id"))
    # Position in a stack (0-9, unique per host - see __table_args__)
    stack_member = db.Column(db.SmallInteger)

    manufacturer = db.relationship(
        "Manufacturer", back_populates="items", lazy="joined"
    )
    model = db.relationship("Model", back_populates="items", lazy="joined")
    location = db.relationship("Location", back_populates="items", lazy="joined")
    status = db.relationship("Status", back_populates="items", lazy="joined")
    # Self-referential parent/children hierarchy
    children = db.relationship("Item", backref=db.backref("parent", remote_side=[id]))
    macs = db.relationship("Mac", backref="item", lazy="joined")
    comments = db.relationship(
        "ItemComment", backref="item", cascade="all, delete-orphan", lazy="joined"
    )

    __table_args__ = (
        sa.CheckConstraint(
            "stack_member >= 0 AND stack_member <=9", name="stack_member_range"
        ),
        sa.UniqueConstraint(host_id, stack_member, name="uq_item_host_id_stack_member"),
    )

    def __init__(self, **kwargs):
        # Automatically convert manufacturer/model/location/status to an
        # instance of their class if passed as a string
        for key, cls in [
            ("manufacturer", Manufacturer),
            ("model", Model),
            ("location", Location),
            ("status", Status),
        ]:
            if key in kwargs:
                kwargs[key] = utils.convert_to_model(kwargs[key], cls)
        super().__init__(**kwargs)

    def __str__(self):
        return str(self.ics_id)

    @validates("ics_id")
    def validate_ics_id(self, key, value):
        """Ensure the ICS id field matches the required format

        None is accepted here (the column default then allocates a temporary id).

        :raises ValidationError: if value doesn't fully match ICS_ID_RE
        """
        if value is not None and ICS_ID_RE.fullmatch(value) is None:
            raise ValidationError("ICS id shall match [A-Z]{3}[0-9]{3}")
        return value

    def to_dict(self, recursive=False):
        """Return the item as a dict (recursive is accepted but unused)

        Related records are rendered as strings; "history" is the list
        returned by history().
        """
        d = super().to_dict()
        d.update(
            {
                "ics_id": self.ics_id,
                "serial_number": self.serial_number,
                "quantity": self.quantity,
                "manufacturer": utils.format_field(self.manufacturer),
                "model": utils.format_field(self.model),
                "location": utils.format_field(self.location),
                "status": utils.format_field(self.status),
                "parent": utils.format_field(self.parent),
                "children": [str(child) for child in self.children],
                "macs": [str(mac) for mac in self.macs],
                "host": utils.format_field(self.host),
                "stack_member": utils.format_field(self.stack_member),
                "history": self.history(),
                "comments": [str(comment) for comment in self.comments],
            }
        )
        return d

    def to_row_dict(self):
        """Convert to a dict that can easily be exported to an excel row

        All values should be a string: children and macs are joined with
        spaces, comments with blank lines, history with one version per line.
        """
        # to_dict() builds a new dict on every call, so it can be edited in place
        d = self.to_dict()
        d["children"] = " ".join(d["children"])
        d["macs"] = " ".join(d["macs"])
        d["comments"] = "\n\n".join(d["comments"])
        d["history"] = "\n".join([str(version) for version in d["history"]])
        return d

    def history(self):
        """Return the item versions (SQLAlchemy-Continuum) as a list of dicts

        Each dict has the versioned updated_at, quantity, location, status and
        parent fields.
        """
        versions = []
        for version in self.versions:
            # parent is an attribute used by SQLAlchemy-Continuum
            # version.parent refers to an ItemVersion instance (and has no link with
            # the item parent_id)
            # We need to retrieve the parent "manually"
            if version.parent_id is None:
                parent = None
            else:
                parent = Item.query.get(version.parent_id)
            versions.append(
                {
                    "updated_at": utils.format_field(version.updated_at),
                    "quantity": version.quantity,
                    "location": utils.format_field(version.location),
                    "status": utils.format_field(version.status),
                    "parent": utils.format_field(parent),
                }
            )
        return versions


class ItemComment(CreatedMixin, db.Model):
    # Free-text comment attached to an item
    body = db.Column(db.Text, nullable=False)
    item_id = db.Column(db.Integer, db.ForeignKey("item.id"), nullable=False)

    def __str__(self):
        return self.body

    def to_dict(self, recursive=False):
        """Serialize the comment: base fields plus its body and owning item."""
        data = super().to_dict()
        data["body"] = self.body
        data["item"] = str(self.item)
        return data


class Network(CreatedMixin, db.Model):
    """An IP network (VLAN) belonging to a network scope and a domain.

    Assignable addresses are bounded by ``first_ip`` and ``last_ip``. The check
    constraints below require first_ip < last_ip and that first_ip, last_ip and
    gateway lie inside ``address``.
    """

    vlan_name = db.Column(CIText, nullable=False, unique=True)
    vlan_id = db.Column(db.Integer, nullable=True, unique=True)
    address = db.Column(postgresql.CIDR, nullable=False, unique=True)
    first_ip = db.Column(postgresql.INET, nullable=False, unique=True)
    last_ip = db.Column(postgresql.INET, nullable=False, unique=True)
    gateway = db.Column(postgresql.INET, nullable=False, unique=True)
    description = db.Column(db.Text)
    admin_only = db.Column(db.Boolean, nullable=False, default=False)
    sensitive = db.Column(db.Boolean, nullable=False, default=False)
    scope_id = db.Column(db.Integer, db.ForeignKey("network_scope.id"), nullable=False)
    domain_id = db.Column(db.Integer, db.ForeignKey("domain.id"), nullable=False)

    interfaces = db.relationship(
        "Interface", backref=db.backref("network", lazy="joined"), lazy=True
    )

    __table_args__ = (
        sa.CheckConstraint("first_ip < last_ip", name="first_ip_less_than_last_ip"),
        sa.CheckConstraint("first_ip << address", name="first_ip_in_network"),
        sa.CheckConstraint("last_ip << address", name="last_ip_in_network"),
        sa.CheckConstraint("gateway << address", name="gateway_in_network"),
    )

    def __init__(self, **kwargs):
        # Automatically convert scope to an instance of NetworkScope if it was passed
        # as a string
        if "scope" in kwargs:
            kwargs["scope"] = utils.convert_to_model(
                kwargs["scope"], NetworkScope, "name"
            )
            # If domain_id is not passed, we set it to the network scope value
            if "domain_id" not in kwargs:
                kwargs["domain_id"] = kwargs["scope"].domain_id
        # WARNING! Setting self.scope will call validate_networks in the NetworkScope class
        # For the validation to work, self.address must be set before!
        # Ensure that address and vlan_name are passed before scope
        vlan_name = kwargs.pop("vlan_name")
        address = kwargs.pop("address")
        super().__init__(vlan_name=vlan_name, address=address, **kwargs)

    def __str__(self):
        return str(self.vlan_name)

    @property
    def network_ip(self):
        """The network address as an ipaddress network object"""
        return ipaddress.ip_network(self.address)

    @property
    def netmask(self):
        return self.network_ip.netmask

    @property
    def broadcast(self):
        return self.network_ip.broadcast_address

    @property
    def first(self):
        """first_ip as an ipaddress address object"""
        return ipaddress.ip_address(self.first_ip)

    @property
    def last(self):
        """last_ip as an ipaddress address object"""
        return ipaddress.ip_address(self.last_ip)

    def ip_range(self):
        """Return the list of IP addresses that can be assigned for this network

        The range is the network host addresses between the first and last IP
        (both inclusive).
        """
        # Parse the bounds once instead of on every iteration
        first, last = self.first, self.last
        return [addr for addr in self.network_ip.hosts() if first <= addr <= last]

    def used_ips(self):
        """Return the list of IP addresses in use

        The list is sorted
        """
        return sorted(interface.address for interface in self.interfaces)

    def available_ips(self):
        """Return the list of IP addresses available

        The addresses in use are collected once into a set so each candidate
        is checked in O(1) instead of rebuilding and scanning the sorted list
        of used addresses for every address in the range.
        """
        used = set(self.used_ips())
        return [addr for addr in self.ip_range() if addr not in used]

    @validates("first_ip")
    def validate_first_ip(self, key, ip):
        """Ensure the first IP is in the network"""
        utils.ip_in_network(ip, self.address)
        return ip

    @validates("last_ip")
    def validate_last_ip(self, key, ip):
        """Ensure the last IP is in the network and greater than first_ip

        The comparison with first_ip is skipped when first_ip is not set yet
        (e.g. last_ip assigned before first_ip). Parsing a missing first_ip
        would otherwise raise a ValueError. The database check constraint
        still enforces the ordering.
        """
        addr, net = utils.ip_in_network(ip, self.address)
        if self.first_ip is not None and addr < self.first:
            raise ValidationError(
                f"Last IP address {ip} is less than the first address {self.first}"
            )
        return ip

    @validates("interfaces")
    def validate_interfaces(self, key, interface):
        """Ensure the interface IP is in the network range"""
        utils.validate_ip(interface.ip, self)
        return interface

    @validates("vlan_name")
    def validate_vlan_name(self, key, string):
        """Ensure the name matches the required format"""
        if string is None:
            return None
        if VLAN_NAME_RE.fullmatch(string) is None:
            raise ValidationError(r"Vlan name shall match [A-Za-z0-9\-]{3,25}")
        return string

    @validates("vlan_id")
    def validate_vlan_id(self, key, value):
        """Ensure the vlan_id is in the scope range"""
        if value is None or self.scope is None:
            # If scope is None, we can't do any validation
            # This will occur when vlan_id is passed before scope
            # We could ensure it's not the case but main use case
            # is when editing network. This won't happen then.
            return value
        if int(value) not in self.scope.vlan_range():
            raise ValidationError(
                f"Vlan id shall be in the range [{self.scope.first_vlan} - {self.scope.last_vlan}]"
            )
        return value

    def to_dict(self, recursive=False):
        d = super().to_dict()
        d.update(
            {
                "vlan_name": self.vlan_name,
                "vlan_id": self.vlan_id,
                "address": self.address,
                "netmask": str(self.netmask),
                "broadcast": str(self.broadcast),
                "first_ip": self.first_ip,
                "last_ip": self.last_ip,
                "gateway": self.gateway,
                "description": self.description,
                "admin_only": self.admin_only,
                "sensitive": self.sensitive,
                "scope": utils.format_field(self.scope),
                "domain": str(self.domain),
                "interfaces": [str(interface) for interface in self.interfaces],
            }
        )
        return d


class DeviceType(db.Model):
    # Named category of hosts; each Host references one device type
    __tablename__ = "device_type"
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)

    hosts = db.relationship(
        "Host", backref=db.backref("device_type", lazy="joined"), lazy=True
    )

    @validates("name")
    def validate_name(self, key, string):
        """Reject names that do not fully match DEVICE_TYPE_RE (None passes through)."""
        if string is None:
            return string
        if not DEVICE_TYPE_RE.fullmatch(string):
            raise ValidationError(f"'{string}' is an invalid device type name")
        return string

    def __str__(self):
        return self.name

    def to_dict(self, recursive=False):
        """Serialize the device type with the names of its hosts."""
        return {
            "id": self.id,
            "name": self.name,
            "hosts": list(map(str, self.hosts)),
        }


def _pk_fk_column(column_name, target):
    """Return an integer primary-key column that is also a foreign key to *target*."""
    return db.Column(
        column_name, db.Integer, db.ForeignKey(target), primary_key=True
    )


# Association table for the many-to-many link between parent and child Ansible groups
ansible_groups_parent_child_table = db.Table(
    "ansible_groups_parent_child",
    _pk_fk_column("parent_group_id", "ansible_group.id"),
    _pk_fk_column("child_group_id", "ansible_group.id"),
)


# Association table for the many-to-many link between Ansible groups and hosts
ansible_groups_hosts_table = db.Table(
    "ansible_groups_hosts",
    _pk_fk_column("ansible_group_id", "ansible_group.id"),
    _pk_fk_column("host_id", "host.id"),
)


class AnsibleGroupType(Enum):
    """Type of an Ansible group.

    STATIC groups have explicitly assigned hosts. The other types identify
    groups whose hosts are computed from the database.
    """

    STATIC = "STATIC"
    NETWORK_SCOPE = "NETWORK_SCOPE"
    NETWORK = "NETWORK"
    DEVICE_TYPE = "DEVICE_TYPE"
    IOC = "IOC"
    HOSTNAME = "HOSTNAME"

    def __str__(self):
        return self.name

    @classmethod
    def choices(cls):
        """Return a list of (member, member name) pairs, in definition order."""
        return [(item, item.name) for item in cls]

    @classmethod
    def coerce(cls, value):
        """Return *value* as a member, looking it up by name if needed.

        Members are returned unchanged; any other value is used as a member
        name. Raises KeyError if *value* is not a member name.
        """
        # NOTE(review): if this is used as a WTForms ``coerce`` callable, a
        # KeyError is not converted into a form validation error — confirm.
        return value if isinstance(value, cls) else cls[value]


class AnsibleGroup(CreatedMixin, SearchableMixin, db.Model):
    """An Ansible inventory group.

    STATIC groups have explicitly assigned hosts (``_hosts``). The other group
    types compute their hosts, and in some cases extra children or parents,
    from the database based on the group name.
    """

    __versioned__ = {}
    __tablename__ = "ansible_group"
    __mapping__ = {
        "created_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"},
        "updated_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"},
        "user": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "name": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "vars": {"type": "flattened"},
        "type": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "hosts": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "children": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
    }
    # Define id here so that it can be used in the primary and secondary join
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)
    vars = db.Column(postgresql.JSONB)
    type = db.Column(
        db.Enum(AnsibleGroupType, name="ansible_group_type"),
        default=AnsibleGroupType.STATIC,
        nullable=False,
    )

    # Self-referential many-to-many through ansible_groups_parent_child_table;
    # the "_parents" backref exposes the reverse side
    _children = db.relationship(
        "AnsibleGroup",
        secondary=ansible_groups_parent_child_table,
        primaryjoin=id == ansible_groups_parent_child_table.c.parent_group_id,
        secondaryjoin=id == ansible_groups_parent_child_table.c.child_group_id,
        backref=db.backref("_parents"),
    )

    def __str__(self):
        return str(self.name)

    @validates("_children")
    def validate_children(self, key, child):
        """Ensure the child is not in the group parents to avoid circular references"""
        if child == self:
            raise ValidationError(f"Group '{self.name}' can't be a child of itself.")
        # "all" is special for Ansible. Any group is automatically a child of "all".
        if child.name == "all":
            raise ValidationError(
                f"Adding group 'all' as child to '{self.name}' creates a recursive dependency loop."
            )

        def check_parents(group):
            """Recursively check all parents"""
            if child in group.parents:
                raise ValidationError(
                    f"Adding group '{child}' as child to '{self.name}' creates a recursive dependency loop."
                )
            for parent in group.parents:
                check_parents(parent)

        check_parents(self)
        return child

    @property
    def is_dynamic(self):
        """True for every group type except STATIC"""
        return self.type != AnsibleGroupType.STATIC

    @property
    def hosts(self):
        """Return the hosts of the group

        STATIC groups return their assigned hosts. Dynamic groups query the
        hosts sorted by name:
        - NETWORK_SCOPE: hosts with an interface (named like the host) on a
          network of the scope named like the group
        - NETWORK: hosts with an interface (named like the host) on the network
          whose vlan_name is the group name
        - DEVICE_TYPE: hosts of the device type named like the group
        - IOC: hosts flagged as IOC
        - HOSTNAME: hosts whose name starts with the group name
        """
        if self.type == AnsibleGroupType.STATIC:
            return self._hosts
        if self.type == AnsibleGroupType.NETWORK_SCOPE:
            return (
                Host.query.join(Host.interfaces)
                .join(Interface.network)
                .join(Network.scope)
                .filter(NetworkScope.name == self.name, Interface.name == Host.name)
                .order_by(Host.name)
                .all()
            )
        if self.type == AnsibleGroupType.NETWORK:
            return (
                Host.query.join(Host.interfaces)
                .join(Interface.network)
                .filter(Network.vlan_name == self.name, Interface.name == Host.name)
                .order_by(Host.name)
                .all()
            )
        if self.type == AnsibleGroupType.DEVICE_TYPE:
            return (
                Host.query.join(Host.device_type)
                .filter(DeviceType.name == self.name)
                .order_by(Host.name)
                .all()
            )
        if self.type == AnsibleGroupType.IOC:
            return Host.query.filter(Host.is_ioc.is_(True)).order_by(Host.name).all()
        if self.type == AnsibleGroupType.HOSTNAME:
            # autoescape=True makes "%" and "_" in the group name match literally
            # instead of acting as LIKE wildcards (e.g. "_" matching any character)
            return (
                Host.query.filter(Host.name.startswith(self.name, autoescape=True))
                .order_by(Host.name)
                .all()
            )

    @hosts.setter
    def hosts(self, value):
        # For dynamic group type, _hosts can only be set to []
        if self.is_dynamic and value:
            raise AttributeError("can't set dynamic hosts")
        self._hosts = value

    @property
    def children(self):
        """Return the child groups sorted by name

        For a NETWORK_SCOPE group, the existing NETWORK groups whose network is
        part of the scope are added to the explicitly set children.
        """
        if self.type == AnsibleGroupType.NETWORK_SCOPE:
            # Return all existing network groups part of the scope
            network_children = (
                AnsibleGroup.query.filter(AnsibleGroup.type == AnsibleGroupType.NETWORK)
                .join(Network, AnsibleGroup.name == Network.vlan_name)
                .join(NetworkScope)
                .filter(NetworkScope.name == self.name)
                .all()
            )
            return sorted(self._children + network_children, key=attrgetter("name"))
        return sorted(self._children, key=attrgetter("name"))

    @children.setter
    def children(self, value):
        if self.type == AnsibleGroupType.NETWORK_SCOPE:
            # Forbid setting a NETWORK group as child
            # Groups linked to networks part of the scope are added automatically
            # Also forbid NETWORK_SCOPE group as child
            for group in value:
                if group.type in (
                    AnsibleGroupType.NETWORK,
                    AnsibleGroupType.NETWORK_SCOPE,
                ):
                    raise ValidationError(
                        f"can't set {str(group.type).lower()} group '{group}' as a network scope child"
                    )
        self._children = value

    @property
    def parents(self):
        """Return the parent groups sorted by name

        For a NETWORK group, the group named like the network scope (if it
        exists) is added to the explicitly set parents.
        """
        if self.type == AnsibleGroupType.NETWORK:
            # Add the group corresponding to the network scope if it exists
            network = Network.query.filter_by(vlan_name=self.name).first()
            if network is not None:
                scope_group = AnsibleGroup.query.filter_by(
                    name=network.scope.name
                ).first()
                if scope_group is not None:
                    return sorted(self._parents + [scope_group], key=attrgetter("name"))
        return sorted(self._parents, key=attrgetter("name"))

    @parents.setter
    def parents(self, value):
        if self.type == AnsibleGroupType.NETWORK:
            # Forbid setting a NETWORK_SCOPE group as parent
            # The group linked to the scope of the network is added automatically
            # Also forbid setting a NETWORK group as it doesn't make sense
            for group in value:
                if group.type in (
                    AnsibleGroupType.NETWORK,
                    AnsibleGroupType.NETWORK_SCOPE,
                ):
                    raise ValidationError(
                        f"can't set {str(group.type).lower()} group '{group}' as a network parent"
                    )
        self._parents = value

    def to_dict(self, recursive=False):
        d = super().to_dict()
        d.update(
            {
                "name": self.name,
                "vars": self.vars,
                "type": self.type.name,
                "hosts": [host.fqdn for host in self.hosts],
                "children": [str(child) for child in self.children],
            }
        )
        return d


class Host(CreatedMixin, SearchableMixin, db.Model):
    """A named host with its interfaces, linked items and Ansible groups.

    ``__mapping__`` appears to be the Elasticsearch field mapping for the
    ``to_dict(recursive=True)`` representation. TODO confirm against
    SearchableMixin.
    """

    __versioned__ = {}
    __mapping__ = {
        "created_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"},
        "updated_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"},
        "user": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "name": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "fqdn": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "is_ioc": {"type": "boolean"},
        "device_type": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "model": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "description": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "items": {
            "properties": {
                "ics_id": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
                "serial_number": {
                    "type": "text",
                    "fields": {"keyword": {"type": "keyword"}},
                },
                "stack_member": {"type": "byte"},
            }
        },
        "interfaces": {
            "properties": {
                "id": {"enabled": False},
                "created_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"},
                "updated_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"},
                "user": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
                "is_main": {"type": "boolean"},
                "network": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
                "ip": {"type": "ip"},
                "netmask": {"enabled": False},
                "name": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
                "description": {
                    "type": "text",
                    "fields": {"keyword": {"type": "keyword"}},
                },
                "mac": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
                "host": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
                "cnames": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
                "domain": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
                "device_type": {
                    "type": "text",
                    "fields": {"keyword": {"type": "keyword"}},
                },
                "model": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
            }
        },
        "ansible_vars": {"type": "flattened"},
        "ansible_groups": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "scope": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "sensitive": {"type": "boolean"},
    }

    # id shall be defined here to be used by SQLAlchemy-Continuum
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.Text, nullable=False, unique=True)
    description = db.Column(db.Text)
    device_type_id = db.Column(
        db.Integer, db.ForeignKey("device_type.id"), nullable=False
    )
    is_ioc = db.Column(db.Boolean, nullable=False, default=False)
    ansible_vars = db.Column(postgresql.JSONB)

    # 1. Set cascade to all (to add delete) and delete-orphan to delete all interfaces
    # when deleting a host
    # 2. Return interfaces sorted by name so that the main one (the one starting with
    # the same name as the host) is always the first one.
    # As an interface name always has to start with the name of the host, the one
    # matching the host name will always come first.
    interfaces = db.relationship(
        "Interface",
        backref=db.backref("host", lazy="joined"),
        cascade="all, delete-orphan",
        lazy="joined",
        order_by="Interface.name",
    )
    items = db.relationship(
        "Item", backref=db.backref("host", lazy="joined"), lazy="joined"
    )
    ansible_groups = db.relationship(
        "AnsibleGroup",
        secondary=ansible_groups_hosts_table,
        lazy="joined",
        backref=db.backref("_hosts"),
    )

    def __init__(self, **kwargs):
        # Automatically convert device_type as an instance of its class if passed as a string
        if "device_type" in kwargs:
            kwargs["device_type"] = utils.convert_to_model(
                kwargs["device_type"], DeviceType
            )
        # Automatically convert items to a list of instances if passed as a list of ics_id
        if "items" in kwargs:
            kwargs["items"] = [
                utils.convert_to_model(item, Item, filter_by="ics_id")
                for item in kwargs["items"]
            ]
        # Automatically convert ansible groups to a list of instances if passed as a list of strings
        if "ansible_groups" in kwargs:
            kwargs["ansible_groups"] = [
                utils.convert_to_model(group, AnsibleGroup)
                for group in kwargs["ansible_groups"]
            ]
        super().__init__(**kwargs)

    @property
    def model(self):
        """Return the model of the first linked item (None if there is no item)"""
        try:
            return utils.format_field(self.items[0].model)
        except IndexError:
            return None

    @property
    def main_interface(self):
        """Return the host main interface

        The main interface is the one that has the same name as the host
        or the first one found
        """
        # As interfaces are sorted, the first one is always the main one
        try:
            return self.interfaces[0]
        except IndexError:
            return None

    @property
    def main_network(self):
        """Return the host main interface network"""
        try:
            return self.main_interface.network
        except AttributeError:
            return None

    @property
    def scope(self):
        """Return the host main interface network scope"""
        try:
            return self.main_network.scope
        except AttributeError:
            return None

    @property
    def sensitive(self):
        """Return True if the host is on a sensitive network"""
        try:
            return self.main_network.sensitive
        except AttributeError:
            return False

    @property
    def fqdn(self):
        """Return the host fully qualified domain name

        The domain is based on the main interface
        """
        if self.main_interface:
            return f"{self.name}.{self.main_interface.network.domain}"
        else:
            return self.name

    def __str__(self):
        return str(self.name)

    @validates("name")
    def validate_name(self, key, string):
        """Ensure the name matches the required format

        The name is lowercased. It must not clash with an existing cname or with
        an interface belonging to another host.
        """
        if string is None:
            return None
        # Force the string to lowercase
        lower_string = string.lower()
        if HOST_NAME_RE.fullmatch(lower_string) is None:
            raise ValidationError(f"Host name shall match {HOST_NAME_RE.pattern}")
        existing_cname = Cname.query.filter_by(name=lower_string).first()
        if existing_cname:
            raise ValidationError("Host name matches an existing cname")
        existing_interface = Interface.query.filter(
            Interface.name == lower_string, Interface.host_id != self.id
        ).first()
        if existing_interface:
            raise ValidationError("Host name matches an existing interface")
        return lower_string

    def stack_members(self):
        """Return all items part of the stack sorted by stack member number"""
        members = [item for item in self.items if item.stack_member is not None]
        return sorted(members, key=lambda x: x.stack_member)

    def stack_members_numbers(self):
        """Return the list of stack member numbers"""
        return [item.stack_member for item in self.stack_members()]

    def free_stack_members(self):
        """Return the list of free stack member numbers (among 0 to 9)"""
        # Compute the used numbers once instead of re-filtering and re-sorting
        # the items for every candidate number
        used = set(self.stack_members_numbers())
        return [nb for nb in range(0, 10) if nb not in used]

    def to_dict(self, recursive=False):
        """Serialize the host

        With recursive=True, interfaces are replaced by their full dict
        representation and items by dicts with ics_id, serial_number and
        stack_member, so that everything can be indexed.
        """

        # None can't be compared to not None values
        # This function replaces None by Inf so it is set at the end of the list
        # items are sorted by stack_member and then ics_id
        def none_to_inf(nb):
            return float("inf") if nb is None else int(nb)

        d = super().to_dict()
        d.update(
            {
                "name": self.name,
                "fqdn": self.fqdn,
                "is_ioc": self.is_ioc,
                "device_type": str(self.device_type),
                "model": self.model,
                "description": self.description,
                "items": [
                    str(item)
                    for item in sorted(
                        self.items,
                        key=lambda x: (none_to_inf(x.stack_member), x.ics_id),
                    )
                ],
                "interfaces": [str(interface) for interface in self.interfaces],
                "ansible_vars": self.ansible_vars,
                "ansible_groups": [str(group) for group in self.ansible_groups],
                "scope": utils.format_field(self.scope),
                "sensitive": self.sensitive,
            }
        )
        if recursive:
            # Replace the list of interface names by the full representation
            # so that we can index everything in elasticsearch
            d["interfaces"] = [interface.to_dict() for interface in self.interfaces]
            # Add extra info in items
            d["items"] = sorted(
                [
                    {
                        "ics_id": item.ics_id,
                        "serial_number": item.serial_number,
                        "stack_member": item.stack_member,
                    }
                    for item in self.items
                ],
                key=lambda x: (none_to_inf(x["stack_member"]), x["ics_id"]),
            )
        return d


class Interface(CreatedMixin, db.Model):
    """A host network interface with its IP address on a network."""

    network_id = db.Column(db.Integer, db.ForeignKey("network.id"), nullable=False)
    ip = db.Column(postgresql.INET, nullable=False, unique=True)
    name = db.Column(db.Text, nullable=False, unique=True)
    description = db.Column(db.Text)
    mac = db.Column(postgresql.MACADDR, nullable=True, unique=True)
    host_id = db.Column(db.Integer, db.ForeignKey("host.id"), nullable=False)

    # Add delete and delete-orphan options to automatically delete cnames when:
    # - deleting an interface
    # - de-associating a cname (removing it from the interface.cnames list)
    cnames = db.relationship(
        "Cname",
        backref=db.backref("interface", lazy="joined"),
        cascade="all, delete, delete-orphan",
        lazy="joined",
    )

    def __init__(self, **kwargs):
        # Always set self.host and not self.host_id to call validate_name
        host_id = kwargs.pop("host_id", None)
        if host_id is not None:
            host = Host.query.get(host_id)
        elif "host" in kwargs:
            # Automatically convert host to an instance of Host if it was passed
            # as a string
            host = utils.convert_to_model(kwargs.pop("host"), Host, "name")
        else:
            host = None
        # Always set self.network and not self.network_id to call validate_interfaces
        network_id = kwargs.pop("network_id", None)
        if network_id is not None:
            kwargs["network"] = Network.query.get(network_id)
        elif "network" in kwargs:
            # Automatically convert network to an instance of Network if it was passed
            # as a string
            kwargs["network"] = utils.convert_to_model(
                kwargs["network"], Network, "vlan_name"
            )
        # WARNING! Setting self.network will call validate_interfaces in the Network class
        # For the validation to work, self.ip must be set before!
        # Ensure that ip is passed before network
        try:
            ip = kwargs.pop("ip")
        except KeyError:
            # Assign first available IP
            # NOTE(review): raises KeyError if no network was given and IndexError
            # if the network has no free address
            ip = str(kwargs["network"].available_ips()[0])
        super().__init__(host=host, ip=ip, **kwargs)

    @validates("name")
    def validate_name(self, key, string):
        """Ensure the name matches the required format

        The name is lowercased. It must start with the host name (when a host is
        set) and must not clash with an existing cname or another host.
        """
        if string is None:
            return None
        # Force the string to lowercase
        lower_string = string.lower()
        if INTERFACE_NAME_RE.fullmatch(lower_string) is None:
            raise ValidationError(
                f"Interface name shall match {INTERFACE_NAME_RE.pattern}"
            )
        if self.host and not lower_string.startswith(self.host.name):
            raise ValidationError(
                f"Interface name shall start with the host name '{self.host}'"
            )
        existing_cname = Cname.query.filter_by(name=lower_string).first()
        if existing_cname:
            raise ValidationError("Interface name matches an existing cname")
        # The host can be None (see __init__); then there is no host to exclude.
        # Comparing with None makes SQLAlchemy emit "IS NOT NULL", i.e. all hosts.
        host_id = self.host.id if self.host is not None else None
        existing_host = Host.query.filter(
            Host.name == lower_string, Host.id != host_id
        ).first()
        if existing_host:
            raise ValidationError("Interface name matches an existing host")
        return lower_string

    @validates("mac")
    def validate_mac(self, key, string):
        """Ensure the mac is a valid MAC address (empty values are stored as None)"""
        if not string:
            return None
        if MAC_ADDRESS_RE.fullmatch(string) is None:
            raise ValidationError(f"'{string}' does not appear to be a MAC address")
        return string

    @validates("cnames")
    def validate_cnames(self, key, cname):
        """Ensure the cname is unique by domain"""
        existing_cnames = Cname.query.filter_by(name=cname.name).all()
        for existing_cname in existing_cnames:
            if existing_cname.domain == str(self.network.domain):
                raise ValidationError(
                    f"Duplicate cname on the {self.network.domain} domain"
                )
        return cname

    @property
    def address(self):
        """The interface IP as an ipaddress address object"""
        return ipaddress.ip_address(self.ip)

    @property
    def is_ioc(self):
        """True if this is the main interface of an IOC host"""
        return self.is_main and self.host.is_ioc

    @property
    def is_main(self):
        """Return True if this interface is the main interface of its host

        Returns False when the interface has no host or the host has no main
        interface.
        """
        if self.host is None:
            return False
        main_interface = self.host.main_interface
        return main_interface is not None and self.name == main_interface.name

    def __str__(self):
        return str(self.name)

    def __repr__(self):
        return f"Interface(id={self.id}, network_id={self.network_id}, ip={self.ip}, name={self.name}, mac={self.mac})"

    def to_dict(self, recursive=False):
        d = super().to_dict()
        d.update(
            {
                "is_main": self.is_main,
                "network": str(self.network),
                "ip": self.ip,
                "netmask": str(self.network.netmask),
                "name": self.name,
                "description": self.description,
                "mac": utils.format_field(self.mac),
                "host": utils.format_field(self.host),
                "cnames": [str(cname) for cname in self.cnames],
                "domain": str(self.network.domain),
            }
        )
        if self.host:
            d["device_type"] = str(self.host.device_type)
            d["model"] = utils.format_field(self.host.model)
        else:
            d["device_type"] = None
            d["model"] = None
        return d


class Mac(db.Model):
    # A MAC address, optionally linked to an item (item_id is nullable)
    id = db.Column(db.Integer, primary_key=True)
    address = db.Column(postgresql.MACADDR, nullable=False, unique=True)
    item_id = db.Column(db.Integer, db.ForeignKey("item.id"))

    def __str__(self):
        return str(self.address)

    @validates("address")
    def validate_address(self, key, string):
        """Reject values that are not MAC addresses; None is passed through."""
        if string is not None and not MAC_ADDRESS_RE.fullmatch(string):
            raise ValidationError(f"'{string}' does not appear to be a MAC address")
        return string

    def to_dict(self, recursive=False):
        """Serialize the MAC address with its (formatted) linked item."""
        serialized = dict(id=self.id, address=self.address)
        serialized["item"] = utils.format_field(self.item)
        return serialized


class Cname(CreatedMixin, db.Model):
    """An alias name attached to an interface

    Its fully qualified name uses the domain of the interface network.
    """

    name = db.Column(db.Text, nullable=False)
    interface_id = db.Column(db.Integer, db.ForeignKey("interface.id"), nullable=False)

    def __init__(self, **kwargs):
        # Always set self.interface and not self.interface_id to call validate_cnames:
        # an interface_id keyword is resolved to the Interface object first.
        target_id = kwargs.pop("interface_id", None)
        if target_id is not None:
            kwargs["interface"] = Interface.query.get(target_id)
        super().__init__(**kwargs)

    def __str__(self):
        return "{}".format(self.name)

    @property
    def domain(self):
        """Return the cname domain name (the interface network domain)"""
        network = self.interface.network
        return str(network.domain)

    @property
    def fqdn(self):
        """Return the cname fully qualified domain name"""
        return "{}.{}".format(self.name, self.domain)

    @validates("name")
    def validate_name(self, key, value):
        """Lowercase the name and ensure it is a valid, unused host name

        :returns: the lowercased name (None is passed through)
        :raises ValidationError: if the name doesn't match HOST_NAME_RE or
            is already used by an interface or a host
        """
        if value is None:
            return None
        candidate = value.lower()
        if HOST_NAME_RE.fullmatch(candidate) is None:
            raise ValidationError(f"cname shall match {HOST_NAME_RE.pattern}")
        if Interface.query.filter_by(name=candidate).first():
            raise ValidationError("cname matches an existing interface")
        if Host.query.filter_by(name=candidate).first():
            raise ValidationError("cname matches an existing host")
        return candidate

    def to_dict(self, recursive=False):
        """Serialize the cname (``recursive`` is not used)"""
        d = super().to_dict()
        d["name"] = self.name
        d["interface"] = str(self.interface)
        return d


class Domain(CreatedMixin, db.Model):
    """A domain that network scopes and networks belong to"""

    name = db.Column(db.Text, nullable=False, unique=True)

    scopes = db.relationship(
        "NetworkScope", backref=db.backref("domain", lazy="joined"), lazy=True
    )
    networks = db.relationship(
        "Network", backref=db.backref("domain", lazy="joined"), lazy=True
    )

    def __str__(self):
        return "{}".format(self.name)

    def to_dict(self, recursive=False):
        """Serialize the domain; scopes and networks are given by name"""
        d = super().to_dict()
        d.update(
            name=self.name,
            scopes=list(map(str, self.scopes)),
            networks=list(map(str, self.networks)),
        )
        return d


class NetworkScope(CreatedMixin, db.Model):
    """A named supernet grouping networks, with an optional vlan range

    The validators below check that attached networks are non-overlapping
    subnets of the supernet and that first_vlan/last_vlan (when set) bound
    the vlan ids of the attached networks.
    """

    __tablename__ = "network_scope"
    name = db.Column(CIText, nullable=False, unique=True)
    first_vlan = db.Column(db.Integer, nullable=True, unique=True)
    last_vlan = db.Column(db.Integer, nullable=True, unique=True)
    supernet = db.Column(postgresql.CIDR, nullable=False, unique=True)
    domain_id = db.Column(db.Integer, db.ForeignKey("domain.id"), nullable=False)
    description = db.Column(db.Text)

    networks = db.relationship(
        "Network", backref=db.backref("scope", lazy="joined"), lazy=True
    )

    __table_args__ = (
        sa.CheckConstraint(
            "first_vlan < last_vlan", name="first_vlan_less_than_last_vlan"
        ),
    )

    def __str__(self):
        return str(self.name)

    @validates("supernet")
    def validate_supernet(self, key, supernet):
        """Ensure the supernet doesn't overlap existing supernets

        Also ensure it's a supernet of all existing networks (when editing)

        :raises ValidationError: on overlap or if a network is not a subnet
        """
        supernet_address = ipaddress.ip_network(supernet)
        # Compare against every other scope (self.id is None for a new scope)
        existing_scopes = NetworkScope.query.filter(NetworkScope.id != self.id).all()
        for existing_scope in existing_scopes:
            if supernet_address.overlaps(existing_scope.supernet_ip):
                raise ValidationError(
                    f"{supernet} overlaps {existing_scope} ({existing_scope.supernet_ip})"
                )
        for network in self.networks:
            if not network.network_ip.subnet_of(supernet_address):
                raise ValidationError(
                    f"{network.network_ip} is not a subnet of {supernet}"
                )
        return supernet

    @validates("networks")
    def validate_networks(self, key, network):
        """Ensure the network is included in the supernet and doesn't overlap
        existing networks

        :raises ValidationError: if the network is outside the supernet or
            overlaps another network of this scope
        """
        if not network.network_ip.subnet_of(self.supernet_ip):
            raise ValidationError(
                f"{network.network_ip} is not a subnet of {self.supernet_ip}"
            )
        existing_networks = Network.query.filter_by(scope=self).all()
        for existing_network in existing_networks:
            if existing_network.id == network.id:
                # Same network added again during edit via admin interface
                continue
            if network.network_ip.overlaps(existing_network.network_ip):
                raise ValidationError(
                    f"{network.network_ip} overlaps {existing_network} ({existing_network.network_ip})"
                )
        return network

    @validates("first_vlan")
    def validate_first_vlan(self, key, value):
        """Ensure the first vlan is not greater than any network vlan id

        :raises ValidationError: if a network vlan id is below the value
        """
        if value is None:
            return value
        for network in self.networks:
            if int(value) > network.vlan_id:
                raise ValidationError(
                    f"First vlan shall be lower than {network.vlan_name} vlan: {network.vlan_id}"
                )
        return value

    @validates("last_vlan")
    def validate_last_vlan(self, key, value):
        """Ensure the last vlan is not lower than any network vlan id

        :raises ValidationError: if a network vlan id is above the value
        """
        if value is None:
            return value
        for network in self.networks:
            if int(value) < network.vlan_id:
                raise ValidationError(
                    f"Last vlan shall be greater than {network.vlan_name} vlan: {network.vlan_id}"
                )
        return value

    @property
    def supernet_ip(self):
        """Return the supernet as an ipaddress network object"""
        return ipaddress.ip_network(self.supernet)

    def prefix_range(self):
        """Return the list of subnet prefix that can be used for this network scope

        From one bit longer than the supernet prefix up to /30.
        NOTE(review): the upper bound 31 (exclusive) assumes an IPv4 supernet.
        """
        return list(range(self.supernet_ip.prefixlen + 1, 31))

    def vlan_range(self):
        """Return the list of vlan ids that can be assigned for this network scope

        The range is defined by the first and last vlan (both included).
        An empty list is returned when either bound is not set.
        """
        if self.first_vlan is None or self.last_vlan is None:
            return []
        return range(self.first_vlan, self.last_vlan + 1)

    def used_vlans(self):
        """Return the list of vlan ids in use

        The list is sorted. It is empty when the scope has no vlan range.
        """
        if self.first_vlan is None or self.last_vlan is None:
            return []
        return sorted(network.vlan_id for network in self.networks)

    def available_vlans(self):
        """Return the list of vlan ids available, in ascending order"""
        # Compute the used vlans once (as a set) instead of re-sorting all
        # networks for every candidate vlan of the range.
        used = set(self.used_vlans())
        return [vlan for vlan in self.vlan_range() if vlan not in used]

    def used_subnets(self):
        """Return the list of subnets in use

        The list is sorted
        """
        return sorted(network.network_ip for network in self.networks)

    def available_subnets(self, prefix):
        """Return the list of available subnets with the given prefix

        Overlapping subnets with existing networks are filtered"""
        used = self.used_subnets()
        return [
            str(subnet)
            for subnet in self.supernet_ip.subnets(new_prefix=prefix)
            if not utils.overlaps(subnet, used)
        ]

    def to_dict(self, recursive=False):
        """Serialize the network scope; domain and networks are given by name"""
        d = super().to_dict()
        d.update(
            {
                "name": self.name,
                "first_vlan": self.first_vlan,
                "last_vlan": self.last_vlan,
                "supernet": self.supernet,
                "description": self.description,
                "domain": str(self.domain),
                "networks": [str(network) for network in self.networks],
            }
        )
        return d


# Define RQ JobStatus as a Python enum.
# We can't use the one defined in rq/job.py: it's not a real enum
# (it's a custom one) and is not compatible with sqlalchemy.
class JobStatus(Enum):
    """Status of an RQ job, stored in the Task.status column

    Task.status maps this enum to a database enum type named ``job_status``.
    NOTE(review): adding, removing or reordering members presumably requires
    a matching database migration - TODO confirm.
    """

    QUEUED = "queued"
    FINISHED = "finished"
    FAILED = "failed"
    STARTED = "started"
    DEFERRED = "deferred"


class Task(db.Model):
    """Database record of a job enqueued with RQ (its id is the RQ job id)"""

    # Use job id generated by RQ
    id = db.Column(postgresql.UUID, primary_key=True)
    created_at = db.Column(db.DateTime, default=utcnow())
    ended_at = db.Column(db.DateTime)
    name = db.Column(db.Text, nullable=False, index=True)
    command = db.Column(db.Text)
    status = db.Column(db.Enum(JobStatus, name="job_status"))
    awx_resource = db.Column(db.Text)
    awx_job_id = db.Column(db.Integer)
    exception = db.Column(db.Text)
    user_id = db.Column(
        db.Integer,
        db.ForeignKey("user_account.id"),
        nullable=False,
        default=utils.fetch_current_user_id,
    )
    depends_on_id = db.Column(postgresql.UUID, db.ForeignKey("task.id"))

    # Tasks that depend on this one; the backref gives each task its parent
    reverse_dependencies = db.relationship(
        "Task", backref=db.backref("depends_on", remote_side=[id])
    )

    @property
    def awx_job_url(self):
        """Return the AWX web UI URL of the job, or None

        None is returned when no AWX job id is recorded, or when the AWX
        resource type is unknown (a warning is logged in that case).
        """
        if self.awx_job_id is None:
            return None
        routes = {
            "job": "jobs/playbook",
            "workflow_job": "workflows",
            "inventory_source": "jobs/inventory",
        }
        route = routes.get(self.awx_resource)
        if route is None:
            current_app.logger.warning(f"Unknown AWX resource: {self.awx_resource}")
            return None
        # The path starts with "/", so urljoin replaces any path of AWX_URL
        return urllib.parse.urljoin(
            current_app.config["AWX_URL"], f"/#/{route}/{self.awx_job_id}"
        )

    def update_reverse_dependencies(self):
        """Recursively set all reverse dependencies to FAILED

        When a RQ job is set to FAILED, its reverse dependencies will stay to DEFERRED.
        This method allows to easily update the corresponding tasks status.

        The tasks are modified but the session is not committed.
        """

        def set_reverse_dependencies_to_failed(task):
            for dependency in task.reverse_dependencies:
                current_app.logger.info(
                    f"Setting {dependency.id} ({dependency.name}) to FAILED due to failed dependency"
                )
                dependency.status = JobStatus.FAILED
                set_reverse_dependencies_to_failed(dependency)

        set_reverse_dependencies_to_failed(self)

    def __str__(self):
        return str(self.id)

    def to_dict(self, recursive=False):
        """Serialize the task (``recursive`` is not used)

        ``status`` is the JobStatus member name, or None when no status is set
        (the column is nullable).
        """
        return {
            "id": self.id,
            "name": self.name,
            "created_at": utils.format_field(self.created_at),
            "ended_at": utils.format_field(self.ended_at),
            "status": None if self.status is None else self.status.name,
            "awx_resource": self.awx_resource,
            "awx_job_id": self.awx_job_id,
            "awx_job_url": self.awx_job_url,
            "depends_on": self.depends_on_id,
            "command": self.command,
            "exception": self.exception,
            "user": str(self.user),
        }


def _session_has_modified_instance(session, classes):
    """Return True if the session holds a change to an instance of ``classes``

    Scans session.new, session.dirty and session.deleted (in that order) and
    stops at the first instance of one of ``classes`` that:

    - is in session.new or session.deleted: always considered a change
      (in session.deleted, session.is_modified(instance) is usually False,
      so it must not be checked), or
    - is in session.dirty and session.is_modified(instance) is True (the
      instance could have been added to the session without being modified).

    :param session: the SQLAlchemy session about to be flushed
    :param classes: class or tuple of classes, as accepted by isinstance
    :returns: True if such an instance was found, False otherwise
    """
    for kind in ("new", "dirty", "deleted"):
        for instance in getattr(session, kind):
            # isinstance is checked first so is_modified only runs on
            # relevant instances
            if isinstance(instance, classes) and (
                kind != "dirty" or session.is_modified(instance)
            ):
                return True
    return False


def trigger_core_services_update(session):
    """Trigger core services update on any Interface or Host modification.

    Called by before flush hook

    :returns: True if the update was triggered, False otherwise
    """
    if _session_has_modified_instance(session, (Host, Interface)):
        utils.trigger_core_services_update()
        return True
    return False


def trigger_inventory_update(session):
    """Trigger an inventory update in AWX

    Update on any AnsibleGroup/Cname/Domain/Host/Interface/Network/NetworkScope
    modification.

    Called by before flush hook

    :returns: True if the update was triggered, False otherwise
    """
    if _session_has_modified_instance(
        session,
        (AnsibleGroup, Cname, Domain, Host, Interface, Network, NetworkScope),
    ):
        utils.trigger_inventory_update()
        return True
    return False


def trigger_ansible_groups_reindex(session):
    """Trigger a reindex of Ansible groups

    Update on any Host or Interface modification.
    This is required for all dynamic groups.

    Called by before flush hook

    :returns: True if the reindex was triggered, False otherwise
    """
    if _session_has_modified_instance(session, (Host, Interface)):
        utils.trigger_ansible_groups_reindex()
        return True
    return False


@sa.event.listens_for(db.session, "before_flush")
def before_flush(session, flush_context, instances):
    """Before flush hook

    Runs, in order, the inventory update, core services update and Ansible
    groups reindex triggers; each one inspects the session itself.

    See http://docs.sqlalchemy.org/en/latest/orm/session_events.html#before-flush
    """
    for trigger in (
        trigger_inventory_update,
        trigger_core_services_update,
        trigger_ansible_groups_reindex,
    ):
        trigger(session)


@sa.event.listens_for(Network.sensitive, "set")
def update_host_sensitive_field(target, value, oldvalue, initiator):
    """Update the host sensitive field in elasticsearch based on the Network value

    Updating the network won't trigger any update of the hosts as sensitive is just
    a property (based on host.main_interface.network).
    We have to force the update in elasticsearch index.

    Only hosts whose *main* interface is on the network are updated: a host
    with just a secondary interface on this network keeps the sensitive
    value of its main interface network. Interfaces without a host are
    skipped.
    """
    if value == oldvalue:
        return
    current_app.logger.debug(f"Network {target} sensitive value changed to {value}")
    index = "host" + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"]
    for interface in target.interfaces:
        host = interface.host
        if host is None or not interface.is_main:
            continue
        current_app.logger.debug(f"Update sensitive to {value} for {host}")
        # We can't use host.to_dict() because the property host.sensitive
        # doesn't have the new value yet at this time
        search.update_document(index, host.id, {"sensitive": value})


# call configure_mappers after defining all the models
# required by sqlalchemy_continuum (it builds the version classes when the
# mappers are configured)
sa.orm.configure_mappers()
# Version class generated by sqlalchemy_continuum for the Item model; only
# available once the mappers have been configured above
ItemVersion = version_class(Item)
# Set SQLAlchemy event listeners
# Register the SearchableMixin session hooks (presumably used to keep the
# search index in sync with the database - confirm in SearchableMixin)
db.event.listen(db.session, "before_flush", SearchableMixin.before_flush)
db.event.listen(
    db.session, "after_flush_postexec", SearchableMixin.after_flush_postexec
)
db.event.listen(db.session, "after_commit", SearchableMixin.after_commit)