# -*- coding: utf-8 -*-
"""
app.models
~~~~~~~~~~

This module implements the models used in the app.

:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.

"""
import ipaddress
import string
import qrcode
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import validates
from sqlalchemy_continuum import make_versioned, version_class
from citext import CIText
from flask import current_app
from flask_login import UserMixin
from wtforms import ValidationError
from .extensions import db, login_manager, ldap_manager, cache
from .plugins import FlaskUserPlugin
from .validators import ICS_ID_RE, HOST_NAME_RE, VLAN_NAME_RE, MAC_ADDRESS_RE
from . import utils


# Enable SQLAlchemy-Continuum versioning; models opt in with __versioned__ (see Item).
# Must run before the versioned models are defined.
# NOTE(review): FlaskUserPlugin comes from .plugins; presumably it links each
# versioning transaction to the logged-in user — confirm there.
make_versioned(plugins=[FlaskUserPlugin()])


# See http://docs.sqlalchemy.org/en/latest/core/compiler.html#utc-timestamp-function
class utcnow(sa.sql.expression.FunctionElement):
    """SQL function element evaluating to the current UTC timestamp

    Only a PostgreSQL compilation is registered (pg_utcnow).
    Used as default/onupdate value of the CreatedMixin timestamps.
    """
    type = sa.types.DateTime()
    # The construct carries no per-instance state: allow SQLAlchemy (>= 1.4)
    # to cache its compiled form instead of warning that caching is disabled.
    # Older SQLAlchemy versions do not look at this attribute.
    inherit_cache = True


@sa.ext.compiler.compiles(utcnow, 'postgresql')
def pg_utcnow(element, compiler, **kw):
    """Render utcnow() for the PostgreSQL dialect

    CURRENT_TIMESTAMP is converted to UTC with PostgreSQL's TIMEZONE() function.
    """
    return "TIMEZONE('utc', CURRENT_TIMESTAMP)"


def temporary_ics_ids():
    """Generate every temporary ICS id, in order

    Each id is the TEMPORARY_ICS_ID prefix from the app config, followed by
    an uppercase letter (A-Z) and a 3-digit zero-padded number (000-999).
    """
    prefix = current_app.config["TEMPORARY_ICS_ID"]
    for letter in string.ascii_uppercase:
        for number in range(1000):
            yield f'{prefix}{letter}{number:03d}'


def used_temporary_ics_ids():
    """Return a set with the temporary ICS ids used

    An id is considered temporary when it starts with the TEMPORARY_ICS_ID
    prefix from the app config.
    """
    prefix = current_app.config['TEMPORARY_ICS_ID']
    # Only the ics_id column is needed: don't load complete Item instances
    rows = Item.query.with_entities(Item.ics_id).filter(
        Item.ics_id.startswith(prefix))
    return {ics_id for (ics_id,) in rows}


def get_temporary_ics_id():
    """Return the first temporary ICS id not used yet

    :raises ValueError: if all the temporary ICS ids are taken
    """
    taken = used_temporary_ics_ids()
    free_id = next(
        (candidate for candidate in temporary_ics_ids() if candidate not in taken),
        None)
    if free_id is None:
        raise ValueError('No temporary ICS id available')
    return free_id


@login_manager.user_loader
@cache.memoize(timeout=1800)
def load_user(user_id):
    """Load a user for flask-login (result memoized for 1800 seconds)

    :param str user_id: unicode ID of a user, as returned by User.get_id
    :returns: the matching User or None
    """
    numeric_id = int(user_id)
    return User.query.get(numeric_id)


@ldap_manager.save_user
def save_user(dn, username, data, memberships):
    """Persist the user authenticated by flask-ldap3-login

    Called whenever a LDAPLoginForm() successfully validates.
    A user unknown to the database is created from the LDAP 'cn' and 'mail'
    attributes; the groups are refreshed on every call.

    :returns: the saved User
    """
    existing = User.query.filter_by(username=username).first()
    if existing is not None:
        user = existing
    else:
        user = User(
            username=username,
            display_name=utils.attribute_to_string(data['cn']),
            email=utils.attribute_to_string(data['mail']),
        )
    # Keep the groups in sync with LDAP
    group_names = [utils.attribute_to_string(membership['cn'])
                   for membership in memberships]
    user.groups = group_names
    db.session.add(user)
    db.session.commit()
    return user


class User(db.Model, UserMixin):
    """Application user; groups hold the LDAP group names (set by save_user)"""
    # "user" is a reserved word in postgresql
    # so let's use another name
    __tablename__ = 'user_account'

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.Text, nullable=False, unique=True)
    display_name = db.Column(db.Text, nullable=False)
    email = db.Column(db.Text)
    # Nullable with no default: can be NULL for users not created through save_user
    groups = db.Column(postgresql.ARRAY(db.Text))
    tokens = db.relationship("Token", backref="user")

    def get_id(self):
        """Return the user id as unicode

        Required by flask-login
        """
        return str(self.id)

    def _group_names(self):
        """Return the user LDAP groups, treating a NULL column as no group"""
        return self.groups or []

    @property
    def csentry_groups(self):
        """Return the CSENTRY_LDAP_GROUPS keys whose LDAP group the user belongs to"""
        user_groups = self._group_names()
        return [key for key, value in current_app.config['CSENTRY_LDAP_GROUPS'].items()
                if value in user_groups]

    @property
    def is_admin(self):
        """Return True if the user belongs to the 'admin' LDAP group"""
        return current_app.config['CSENTRY_LDAP_GROUPS']['admin'] in self._group_names()

    def is_member_of_one_group(self, groups):
        """Return True if the user is at least member of one of the given groups

        :param groups: keys of the CSENTRY_LDAP_GROUPS config
        """
        names = {current_app.config['CSENTRY_LDAP_GROUPS'].get(group) for group in groups}
        return bool(set(self._group_names()) & names)

    def is_member_of_all_groups(self, groups):
        """Return True if the user is member of all the given groups

        :param groups: keys of the CSENTRY_LDAP_GROUPS config
        """
        names = [current_app.config['CSENTRY_LDAP_GROUPS'].get(group) for group in groups]
        return set(names).issubset(self._group_names())

    def __str__(self):
        return self.display_name

    def to_dict(self):
        return {
            'id': self.id,
            'username': self.username,
            'display_name': self.display_name,
            'email': self.email,
            'groups': self.csentry_groups,
        }


class Token(db.Model):
    """Table to store valid tokens"""
    id = db.Column(db.Integer, primary_key=True)
    jti = db.Column(postgresql.UUID, nullable=False)
    token_type = db.Column(db.Text, nullable=False)
    user_id = db.Column(db.Integer, db.ForeignKey('user_account.id'), nullable=False)
    issued_at = db.Column(db.DateTime, nullable=False)
    # expires can be set to None for tokens that never expire
    expires = db.Column(db.DateTime)
    description = db.Column(db.Text)

    __table_args__ = (
        sa.UniqueConstraint(jti, user_id),
    )

    def __str__(self):
        # __str__ must return a str: a postgresql.UUID column may be loaded
        # as a uuid.UUID depending on the SQLAlchemy version/configuration
        return str(self.jti)


class QRCodeMixin:
    """Columns and helpers shared by the simple named tables

    Each record has a unique (CIText) name, an optional description and can be
    rendered as a QR code.
    """
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)
    description = db.Column(db.Text)

    def image(self):
        """Build the QRCode image identifying the record

        The encoded text is made of the CSE prefix, the table name and the
        record name, separated by colons.
        """
        parts = ('CSE', self.__tablename__, self.name)
        return qrcode.make(':'.join(parts), version=1, box_size=5)

    @cache.memoize(timeout=0)
    def base64_image(self):
        """Return the QRCode image as base64 string (memoized, timeout=0)"""
        qr_image = self.image()
        return utils.image_to_base64(qr_image)

    def __str__(self):
        return self.name

    def __repr__(self):
        # cache.memoize builds its cache key from repr() of the arguments (self
        # included), so this string must uniquely identify the record
        # See https://flask-caching.readthedocs.io/en/latest/#memoization
        return f'{self.__class__.__name__}(id={self.id}, name={self.name})'

    def to_dict(self):
        return dict(
            id=self.id,
            name=self.name,
            description=self.description,
            qrcode=self.base64_image(),
        )


class Action(QRCodeMixin, db.Model):
    """Named action record (columns and helpers come from QRCodeMixin)"""
    pass


class Manufacturer(QRCodeMixin, db.Model):
    """Named manufacturer, referenced by Item.manufacturer"""
    items = db.relationship('Item', back_populates='manufacturer')


class Model(QRCodeMixin, db.Model):
    """Named model, referenced by Item.model"""
    items = db.relationship('Item', back_populates='model')


class Location(QRCodeMixin, db.Model):
    """Named location, referenced by Item.location"""
    items = db.relationship('Item', back_populates='location')


class Status(QRCodeMixin, db.Model):
    """Named status, referenced by Item.status"""
    items = db.relationship('Item', back_populates='status')


class CreatedMixin:
    """Columns and helpers shared by the tables tracking their creation

    Adds an integer primary key, created_at/updated_at timestamps computed with
    utcnow() and a mandatory user (defaulting to utils.fetch_current_user_id).
    """
    id = db.Column(db.Integer, primary_key=True)
    created_at = db.Column(db.DateTime, default=utcnow())
    updated_at = db.Column(db.DateTime, default=utcnow(), onupdate=utcnow())

    # Using ForeignKey and relationship in mixin requires the @declared_attr decorator
    # See http://docs.sqlalchemy.org/en/latest/orm/extensions/declarative/mixins.html
    @declared_attr
    def user_id(cls):
        return db.Column(db.Integer, db.ForeignKey('user_account.id'),
                         nullable=False, default=utils.fetch_current_user_id)

    @declared_attr
    def user(cls):
        return db.relationship('User')

    def __init__(self, **kwargs):
        # created_at/updated_at can be given as strings: convert them to
        # datetime objects with utils.parse_to_utc
        for field in ('created_at', 'updated_at'):
            value = kwargs.get(field)
            if isinstance(value, str):
                kwargs[field] = utils.parse_to_utc(value)
        super().__init__(**kwargs)

    def to_dict(self):
        return dict(
            id=self.id,
            created_at=utils.format_field(self.created_at),
            updated_at=utils.format_field(self.updated_at),
            user=str(self.user),
        )


class Item(CreatedMixin, db.Model):
    """Item identified by a unique ICS id, versioned with SQLAlchemy-Continuum

    The versions recorded for the item are exposed by history().
    """
    __versioned__ = {
        'exclude': ['created_at', 'user_id', 'ics_id', 'serial_number',
                    'manufacturer_id', 'model_id']
    }

    # WARNING! Inheriting id from CreatedMixin doesn't play well with
    # SQLAlchemy-Continuum. It has to be defined here.
    id = db.Column(db.Integer, primary_key=True)
    ics_id = db.Column(db.Text, unique=True, nullable=False,
                       index=True, default=get_temporary_ics_id)
    serial_number = db.Column(db.Text, nullable=False)
    manufacturer_id = db.Column(db.Integer, db.ForeignKey('manufacturer.id'))
    model_id = db.Column(db.Integer, db.ForeignKey('model.id'))
    location_id = db.Column(db.Integer, db.ForeignKey('location.id'))
    status_id = db.Column(db.Integer, db.ForeignKey('status.id'))
    parent_id = db.Column(db.Integer, db.ForeignKey('item.id'))

    manufacturer = db.relationship('Manufacturer', back_populates='items')
    model = db.relationship('Model', back_populates='items')
    location = db.relationship('Location', back_populates='items')
    status = db.relationship('Status', back_populates='items')
    children = db.relationship('Item', backref=db.backref('parent', remote_side=[id]))
    macs = db.relationship('Mac', backref='item')
    host = db.relationship('Host', uselist=False, backref='item')
    comments = db.relationship('ItemComment', backref='item')

    def __init__(self, **kwargs):
        # manufacturer/model/location/status passed as a string are converted
        # to an instance of their class
        conversions = (('manufacturer', Manufacturer), ('model', Model),
                       ('location', Location), ('status', Status))
        for attr, model_cls in conversions:
            if attr in kwargs:
                kwargs[attr] = utils.convert_to_model(kwargs[attr], model_cls)
        super().__init__(**kwargs)

    def __str__(self):
        return str(self.ics_id)

    @validates('ics_id')
    def validate_ics_id(self, key, value):
        """Ensure the ICS id field matches the required format"""
        if value is not None and ICS_ID_RE.fullmatch(value) is None:
            raise ValidationError('ICS id shall match [A-Z]{3}[0-9]{3}')
        return value

    def to_dict(self):
        data = super().to_dict()
        data['ics_id'] = self.ics_id
        data['serial_number'] = self.serial_number
        for attr in ('manufacturer', 'model', 'location', 'status', 'parent'):
            data[attr] = utils.format_field(getattr(self, attr))
        data['children'] = [str(child) for child in self.children]
        data['macs'] = [str(mac) for mac in self.macs]
        data['history'] = self.history()
        data['comments'] = [str(comment) for comment in self.comments]
        return data

    def history(self):
        """Return one dict per version with updated_at, location, status and parent"""
        entries = []
        for version in self.versions:
            # version.parent is used by SQLAlchemy-Continuum and refers to an
            # ItemVersion (unrelated to the item parent_id): the parent item
            # has to be looked up by id
            parent = None if version.parent_id is None else Item.query.get(version.parent_id)
            entries.append({
                'updated_at': utils.format_field(version.updated_at),
                'location': utils.format_field(version.location),
                'status': utils.format_field(version.status),
                'parent': utils.format_field(parent),
            })
        return entries


class ItemComment(CreatedMixin, db.Model):
    """Free-text comment attached to an Item"""
    body = db.Column(db.Text, nullable=False)
    item_id = db.Column(db.Integer, db.ForeignKey('item.id'), nullable=False)

    def __str__(self):
        return self.body

    def to_dict(self):
        result = super().to_dict()
        result['body'] = self.body
        result['item'] = str(self.item)
        return result


class Network(CreatedMixin, db.Model):
    """Network (CIDR address) associated to a VLAN and belonging to a NetworkScope

    Interfaces attached to the network must have an IP in [first_ip, last_ip].
    """
    vlan_name = db.Column(CIText, nullable=False, unique=True)
    vlan_id = db.Column(db.Integer, nullable=False, unique=True)
    address = db.Column(postgresql.CIDR, nullable=False, unique=True)
    first_ip = db.Column(postgresql.INET, nullable=False, unique=True)
    last_ip = db.Column(postgresql.INET, nullable=False, unique=True)
    description = db.Column(db.Text)
    admin_only = db.Column(db.Boolean, nullable=False, default=False)
    scope_id = db.Column(db.Integer, db.ForeignKey('network_scope.id'), nullable=False)

    interfaces = db.relationship('Interface', backref='network')

    __table_args__ = (
        sa.CheckConstraint('first_ip < last_ip', name='first_ip_less_than_last_ip'),
        sa.CheckConstraint('first_ip << address', name='first_ip_in_network'),
        sa.CheckConstraint('last_ip << address', name='last_ip_in_network'),
    )

    def __init__(self, **kwargs):
        # Automatically convert scope to an instance of NetworkScope if it was passed
        # as a string
        if 'scope' in kwargs:
            kwargs['scope'] = utils.convert_to_model(kwargs['scope'], NetworkScope, 'name')
        # WARNING! Attributes are set in keyword order and the first_ip, last_ip
        # and interfaces validators read self.address (last_ip and interfaces
        # also read self.first/self.last): set address, first_ip and last_ip
        # before anything else.
        ordered = {key: kwargs.pop(key)
                   for key in ('address', 'first_ip', 'last_ip') if key in kwargs}
        ordered.update(kwargs)
        super().__init__(**ordered)

    def __str__(self):
        return str(self.vlan_name)

    @property
    def network_ip(self):
        """Return the network address as an ipaddress network object"""
        return ipaddress.ip_network(self.address)

    @property
    def first(self):
        """Return first_ip as an ipaddress address object"""
        return ipaddress.ip_address(self.first_ip)

    @property
    def last(self):
        """Return last_ip as an ipaddress address object"""
        return ipaddress.ip_address(self.last_ip)

    def ip_range(self):
        """Return the list of IP addresses that can be assigned for this network

        The range is defined by the first and last IP
        """
        return [addr for addr in self.network_ip.hosts()
                if self.first <= addr <= self.last]

    def used_ips(self):
        """Return the list of IP addresses in use

        The list is sorted
        """
        return sorted(interface.address for interface in self.interfaces)

    def available_ips(self):
        """Return the list of IP addresses available

        These are the addresses of ip_range() not used by any interface.
        """
        # Build the lookup set once instead of re-sorting used_ips() for
        # every candidate address (quadratic on large networks)
        used = set(self.used_ips())
        return [addr for addr in self.ip_range() if addr not in used]

    def gateway(self):
        """Return the first interface tagged 'gateway', or None if there is none"""
        for interface in self.interfaces:
            if 'gateway' in [tag.name for tag in interface.tags]:
                return interface
        return None

    @staticmethod
    def ip_in_network(ip, address):
        """Ensure the IP is in the network

        :param ip: IP address (anything accepted by ipaddress.ip_address)
        :param address: network address (anything accepted by ipaddress.ip_network)
        :returns: a tuple with the IP and network as (IPv4Address, IPv4Network)
        :raises: ValidationError if the IP is not in the network
        """
        addr = ipaddress.ip_address(ip)
        net = ipaddress.ip_network(address)
        if addr not in net:
            raise ValidationError(f'IP address {ip} is not in network {address}')
        return (addr, net)

    @validates('first_ip')
    def validate_first_ip(self, key, ip):
        """Ensure the first IP is in the network"""
        self.ip_in_network(ip, self.address)
        return ip

    @validates('last_ip')
    def validate_last_ip(self, key, ip):
        """Ensure the last IP is in the network and greater than first_ip"""
        addr, _ = self.ip_in_network(ip, self.address)
        # Strict comparison, like the first_ip_less_than_last_ip constraint:
        # an equal address would otherwise only be rejected by the database
        if addr <= self.first:
            raise ValidationError(f'Last IP address {ip} is not greater than the first address {self.first}')
        return ip

    @validates('interfaces')
    def validate_interfaces(self, key, interface):
        """Ensure the interface IP is in the network range"""
        addr, _ = self.ip_in_network(interface.ip, self.address)
        if addr < self.first or addr > self.last:
            raise ValidationError(f'IP address {interface.ip} is not in range {self.first} - {self.last}')
        return interface

    @validates('vlan_name')
    def validate_vlan_name(self, key, value):
        """Ensure the name matches the required format"""
        if value is None:
            return None
        if VLAN_NAME_RE.fullmatch(value) is None:
            raise ValidationError(r'Vlan name shall match [A-Za-z0-9\-]{3,25}')
        return value

    def to_dict(self):
        d = super().to_dict()
        d.update({
            'vlan_name': self.vlan_name,
            'vlan_id': self.vlan_id,
            'address': self.address,
            'first_ip': self.first_ip,
            'last_ip': self.last_ip,
            'description': self.description,
            'admin_only': self.admin_only,
            'scope': utils.format_field(self.scope),
            'interfaces': [str(interface) for interface in self.interfaces],
        })
        return d


# Table required for Many-to-Many relationships between interfaces and tags
# (used as secondary of Interface.tags). The (tag_id, interface_id) pair is the
# composite primary key, so a tag can be attached only once to a given interface.
interfacetags_table = db.Table(
    'interfacetags',
    db.Column('tag_id', db.Integer, db.ForeignKey('tag.id'), primary_key=True),
    db.Column('interface_id', db.Integer, db.ForeignKey('interface.id'), primary_key=True)
)


class Tag(QRCodeMixin, db.Model):
    """Named tag attached to interfaces through interfacetags_table"""
    # NOTE(review): presumably restricts the tag to admin users — confirm where it is checked
    admin_only = db.Column(db.Boolean, nullable=False, default=False)


class Host(CreatedMixin, db.Model):
    """Host with a unique lowercase name, optionally linked to an Item"""
    name = db.Column(db.Text, nullable=False, unique=True)
    type = db.Column(db.Text)
    description = db.Column(db.Text)
    item_id = db.Column(db.Integer, db.ForeignKey('item.id'))

    interfaces = db.relationship('Interface', backref='host')

    def __str__(self):
        return str(self.name)

    @validates('name')
    def validate_name(self, key, value):
        """Ensure the name matches the required format

        The name is converted to lowercase before validation and stored lowercase.
        """
        if value is None:
            return None
        # Force the string to lowercase
        lower_name = value.lower()
        if HOST_NAME_RE.fullmatch(lower_name) is None:
            # Raw string: '\-' is an invalid escape sequence in a normal literal
            raise ValidationError(r'Host name shall match [a-z0-9\-]{2,20}')
        return lower_name

    def to_dict(self):
        d = super().to_dict()
        d.update({
            'name': self.name,
            'type': self.type,
            'description': self.description,
            'item': utils.format_field(self.item),
            'interfaces': [str(interface) for interface in self.interfaces],
        })
        return d


class Interface(CreatedMixin, db.Model):
    """Interface with a unique IP and name in a Network (optional MAC and host)"""
    network_id = db.Column(db.Integer, db.ForeignKey('network.id'), nullable=False)
    ip = db.Column(postgresql.INET, nullable=False, unique=True)
    name = db.Column(db.Text, nullable=False, unique=True)
    mac_id = db.Column(db.Integer, db.ForeignKey('mac.id'))
    host_id = db.Column(db.Integer, db.ForeignKey('host.id'))

    # Add delete and delete-orphan options to automatically delete cnames when:
    # - deleting an interface
    # - de-associating a cname (removing it from the interface.cnames list)
    cnames = db.relationship('Cname', backref='interface',
                             cascade='all, delete, delete-orphan')
    tags = db.relationship('Tag', secondary=interfacetags_table, lazy='subquery',
                           backref=db.backref('interfaces', lazy=True))

    def __init__(self, **kwargs):
        # Automatically convert network to an instance of Network if it was passed
        # as a string
        if 'network' in kwargs:
            kwargs['network'] = utils.convert_to_model(kwargs['network'], Network, 'vlan_name')
        # WARNING! Setting self.network will call validate_interfaces in the Network class
        # For the validation to work, self.ip must be set before!
        # Attributes are set in keyword order, so move ip to the front
        if 'ip' in kwargs:
            kwargs = {'ip': kwargs.pop('ip'), **kwargs}
        super().__init__(**kwargs)

    @validates('name')
    def validate_name(self, key, value):
        """Ensure the name matches the required format

        The name is converted to lowercase before validation and stored lowercase.
        """
        if value is None:
            return None
        # Force the string to lowercase
        lower_name = value.lower()
        if HOST_NAME_RE.fullmatch(lower_name) is None:
            # Raw string: '\-' is an invalid escape sequence in a normal literal
            raise ValidationError(r'Interface name shall match [a-z0-9\-]{2,20}')
        return lower_name

    @property
    def address(self):
        """Return the IP as an ipaddress address object"""
        return ipaddress.ip_address(self.ip)

    def __str__(self):
        return str(self.name)

    def __repr__(self):
        return f'Interface(id={self.id}, network_id={self.network_id}, ip={self.ip}, name={self.name}, mac={self.mac})'

    def to_dict(self):
        d = super().to_dict()
        d.update({
            'network': str(self.network),
            'ip': self.ip,
            'name': self.name,
            'mac': utils.format_field(self.mac),
            'host': utils.format_field(self.host),
            'cnames': [str(cname) for cname in self.cnames],
            'tags': [str(tag) for tag in self.tags],
        })
        return d


class Mac(db.Model):
    """Unique MAC address, optionally linked to an Item and used by interfaces"""
    id = db.Column(db.Integer, primary_key=True)
    address = db.Column(postgresql.MACADDR, nullable=False, unique=True)
    item_id = db.Column(db.Integer, db.ForeignKey('item.id'))

    interfaces = db.relationship('Interface', backref='mac')

    def __str__(self):
        return str(self.address)

    @validates('address')
    def validate_address(self, key, value):
        """Reject values not matching MAC_ADDRESS_RE (None is accepted as is)"""
        if value is not None and MAC_ADDRESS_RE.fullmatch(value) is None:
            raise ValidationError(f"'{value}' does not appear to be a MAC address")
        return value

    def to_dict(self):
        return dict(
            id=self.id,
            address=self.address,
            item=utils.format_field(self.item),
            interfaces=[str(interface) for interface in self.interfaces],
        )


class Cname(CreatedMixin, db.Model):
    """Unique alias name pointing to an Interface"""
    name = db.Column(db.Text, nullable=False, unique=True)
    interface_id = db.Column(db.Integer, db.ForeignKey('interface.id'), nullable=False)

    def __str__(self):
        return str(self.name)

    def to_dict(self):
        result = super().to_dict()
        result['name'] = self.name
        result['interface'] = str(self.interface)
        return result


class NetworkScope(CreatedMixin, db.Model):
    """Supernet and VLAN range from which networks are allocated"""
    __tablename__ = 'network_scope'
    name = db.Column(CIText, nullable=False, unique=True)
    first_vlan = db.Column(db.Integer, nullable=False, unique=True)
    last_vlan = db.Column(db.Integer, nullable=False, unique=True)
    supernet = db.Column(postgresql.CIDR, nullable=False, unique=True)
    description = db.Column(db.Text)

    networks = db.relationship('Network', backref='scope')

    __table_args__ = (
        sa.CheckConstraint('first_vlan < last_vlan', name='first_vlan_less_than_last_vlan'),
    )

    def __str__(self):
        return str(self.name)

    @property
    def supernet_ip(self):
        """Return the supernet as an ipaddress network object"""
        return ipaddress.ip_network(self.supernet)

    def prefix_range(self):
        """Return the list of subnet prefix that can be used for this network scope

        From the supernet prefix + 1 up to /30 (31 is excluded by range).
        """
        # NOTE(review): the hard-coded 31 assumes an IPv4 supernet
        return list(range(self.supernet_ip.prefixlen + 1, 31))

    def vlan_range(self):
        """Return the range of vlan ids that can be assigned for this network scope

        The range is defined by the first and last vlan (both included)
        """
        return range(self.first_vlan, self.last_vlan + 1)

    def used_vlans(self):
        """Return the list of vlan ids in use

        The list is sorted
        """
        return sorted(network.vlan_id for network in self.networks)

    def available_vlans(self):
        """Return the list of vlan ids available"""
        # Build the lookup set once instead of re-sorting used_vlans() per vlan
        used = set(self.used_vlans())
        return [vlan for vlan in self.vlan_range() if vlan not in used]

    def used_subnets(self):
        """Return the list of subnets in use

        The list is sorted
        """
        return sorted(network.network_ip for network in self.networks)

    def available_subnets(self, prefix):
        """Return the list of available subnets with the given prefix

        Subnets are returned as strings, in supernet order.
        """
        # Build the lookup set once instead of re-sorting used_subnets() for
        # every candidate subnet (there can be thousands of them)
        used = set(self.used_subnets())
        # NOTE(review): only exact matches are excluded; a subnet overlapping a
        # used network with a different prefix is still listed — confirm intended
        return [str(subnet) for subnet in self.supernet_ip.subnets(new_prefix=prefix)
                if subnet not in used]

    def to_dict(self):
        d = super().to_dict()
        d.update({
            'name': self.name,
            'first_vlan': self.first_vlan,
            'last_vlan': self.last_vlan,
            'supernet': self.supernet,
            'description': self.description,
            'networks': [str(network) for network in self.networks],
        })
        return d


# call configure_mappers after defining all the models
# required by sqlalchemy_continuum (the version classes are built at that point)
sa.orm.configure_mappers()
# Version class created by SQLAlchemy-Continuum for Item; Item.history()
# reads the item versions through Item.versions
ItemVersion = version_class(Item)