# -*- coding: utf-8 -*-
"""
app.models
~~~~~~~~~~

This module implements the models used in the app.

:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.

"""
import ipaddress
import re
import qrcode
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import validates
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy_continuum import make_versioned, version_class
from citext import CIText
from flask import current_app
from flask_login import UserMixin
from wtforms import ValidationError
from .extensions import db, login_manager, ldap_manager, jwt
from .plugins import FlaskUserPlugin
from . import utils


# Expected format of an ICS id: 3 uppercase letters followed by 3 digits
ICS_ID_RE = re.compile(r'[A-Z]{3}[0-9]{3}')
# Expected format of a host name: 2 to 20 lowercase letters, digits or dashes
# (raw string, so that "\-" is not parsed as an invalid string escape sequence)
HOST_NAME_RE = re.compile(r'^[a-z0-9\-]{2,20}$')
# Set up sqlalchemy_continuum versioning with the project's FlaskUserPlugin
make_versioned(plugins=[FlaskUserPlugin()])


@login_manager.user_loader
def load_user(user_id):
    """User loader callback for flask-login

    :param str user_id: unicode ID of a user
    :returns: corresponding user object or None if the id is malformed
              or doesn't match any user
    """
    # flask-login expects None (not an exception) when the id is not valid
    try:
        primary_key = int(user_id)
    except (TypeError, ValueError):
        return None
    return User.query.get(primary_key)


@jwt.user_loader_callback_loader
def user_loader_callback(identity):
    """User loader callback for flask-jwt-extended

    :param str identity: identity from the token (user_id)
    :returns: corresponding user object or None if the identity is malformed
              or doesn't match any user
    """
    # Return None instead of raising on an identity that is not an integer,
    # the same way an unknown user id is reported
    try:
        primary_key = int(identity)
    except (TypeError, ValueError):
        return None
    return User.query.get(primary_key)


@ldap_manager.save_user
def save_user(dn, username, data, memberships):
    """User saver for flask-ldap3-login

    Called each time a LDAPLoginForm() validates successfully.
    The user is looked up by username and created from the LDAP
    attributes ("cn" and "mail") when it doesn't exist yet.

    :param dn: not used by this function
    :param username: username used to look up / create the user
    :param data: mapping with the "cn" and "mail" LDAP attributes
    :param memberships: iterable of mappings with a "cn" attribute
    :returns: the saved user
    """
    account = User.query.filter_by(username=username).first()
    if account is None:
        full_name = utils.attribute_to_string(data['cn'])
        mail = utils.attribute_to_string(data['mail'])
        account = User(username=username, name=full_name, email=mail)
    # The groups are refreshed at every login so that they stay up-to-date
    group_names = []
    for membership in memberships:
        group_names.append(utils.attribute_to_string(membership['cn']))
    account.groups = group_names
    db.session.add(account)
    db.session.commit()
    return account


# Association table required for the Many-to-Many relationship between users and groups
# (passed as "secondary" to the User.grp relationship)
usergroups_table = db.Table(
    'usergroups',
    # The User model is stored in the 'user_account' table
    db.Column('user_id', db.Integer, db.ForeignKey('user_account.id')),
    db.Column('group_id', db.Integer, db.ForeignKey('group.id'))
)


class Group(db.Model):
    """Group of users, identified by a unique name

    Linked to the users via the ``usergroups`` association table
    (the ``members`` backref is declared by ``User.grp``).
    """
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.Text, nullable=False, unique=True)

    def __str__(self):
        """Return the name of the group"""
        return self.name


def find_or_create_group(name):
    """Return the group with the given name

    A new Group instance is built (it's not added to the session here)
    when no group with that name is found.
    """
    existing = Group.query.filter_by(name=name).first()
    if existing:
        return existing
    return Group(name=name)


class User(db.Model, UserMixin):
    """User of the application

    The groups of the user are exposed as a list of group names via
    the ``groups`` association proxy.
    """
    # "user" is a reserved word in postgresql
    # so let's use another name
    __tablename__ = 'user_account'

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.Text, unique=True)
    name = db.Column(db.Text)
    email = db.Column(db.Text)
    grp = db.relationship('Group', secondary=usergroups_table,
                          backref=db.backref('members', lazy='dynamic'))
    # Proxy the 'name' attribute from the 'grp' relationship
    # See http://docs.sqlalchemy.org/en/latest/orm/extensions/associationproxy.html
    groups = association_proxy('grp', 'name',
                               creator=find_or_create_group)

    def get_id(self):
        """Return the user id as unicode

        Required by flask-login
        """
        return str(self.id)

    @property
    def is_admin(self):
        """True if the user is member of the group configured as 'admin'

        The group name is read from the CSENTRY_LDAP_GROUPS setting.
        """
        return current_app.config['CSENTRY_LDAP_GROUPS']['admin'] in self.groups

    def is_member_of_one_group(self, groups):
        """Return True if the user is at least member of one of the given groups

        :param groups: keys of the CSENTRY_LDAP_GROUPS setting
                       (unknown keys are mapped to None and never match)
        """
        names = [current_app.config['CSENTRY_LDAP_GROUPS'].get(group) for group in groups]
        return bool(set(self.groups) & set(names))

    def is_member_of_all_groups(self, groups):
        """Return True if the user is member of all the given groups

        :param groups: keys of the CSENTRY_LDAP_GROUPS setting
                       (an unknown key is mapped to None, which makes the result False)
        """
        names = [current_app.config['CSENTRY_LDAP_GROUPS'].get(group) for group in groups]
        return set(names).issubset(self.groups)

    def __str__(self):
        # The name column is nullable: __str__ must always return a string
        # (returning None would raise a TypeError)
        return str(self.name)


class QRCodeMixin:
    """Mixin for models identified by a unique name and a QRCode

    Provides the ``id`` and ``name`` columns, the QRCode generation
    and a basic dictionary representation.
    """
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)

    def image(self):
        """Build the QRCode image identifying the record

        The encoded text is made of 3 fields separated by ":":
             - the CSE prefix
             - the name of the table
             - the name of the record
        """
        parts = ('CSE', self.__tablename__, self.name)
        payload = ':'.join(parts)
        return qrcode.make(payload, version=1, box_size=5)

    def __str__(self):
        return self.name

    def to_dict(self, qrcode=False):
        """Return the record as a dictionary

        :param qrcode: when true, the QRCode image encoded by
                       utils.image_to_base64 is added under the "qrcode" key
        """
        result = dict(id=self.id, name=self.name)
        if not qrcode:
            return result
        result['qrcode'] = utils.image_to_base64(self.image())
        return result


class Action(QRCodeMixin, db.Model):
    """Action model: only defines what is inherited from QRCodeMixin"""
    pass


class Manufacturer(QRCodeMixin, db.Model):
    """Manufacturer of items (id, name and QRCode come from QRCodeMixin)"""
    # Items referencing this manufacturer (other side: Item.manufacturer)
    items = db.relationship('Item', back_populates='manufacturer')


class Model(QRCodeMixin, db.Model):
    """Model of items: a QRCodeMixin record with an optional description"""
    description = db.Column(db.Text)
    items = db.relationship('Item', back_populates='model')

    def to_dict(self, qrcode=False):
        """Return the QRCodeMixin dictionary extended with the description"""
        return {**super().to_dict(qrcode), 'description': self.description}


class Location(QRCodeMixin, db.Model):
    """Location of items (id, name and QRCode come from QRCodeMixin)"""
    # Items referencing this location (other side: Item.location)
    items = db.relationship('Item', back_populates='location')


class Status(QRCodeMixin, db.Model):
    """Status of items (id, name and QRCode come from QRCodeMixin)"""
    # Items referencing this status (other side: Item.status)
    items = db.relationship('Item', back_populates='status')


class Item(db.Model):
    """Inventory item

    The model is versioned (see ``__versioned__``): the columns listed
    under ``exclude`` are not part of the recorded versions.
    """
    __versioned__ = {
        'exclude': ['_created', 'ics_id', 'serial_number',
                    'manufacturer_id', 'model_id']
    }

    id = db.Column(db.Integer, primary_key=True)
    _created = db.Column(db.DateTime, default=db.func.now())
    _updated = db.Column(db.DateTime, default=db.func.now(), onupdate=db.func.now())
    ics_id = db.Column(db.Text, unique=True, index=True)
    serial_number = db.Column(db.Text, nullable=False)
    manufacturer_id = db.Column(db.Integer, db.ForeignKey('manufacturer.id'))
    model_id = db.Column(db.Integer, db.ForeignKey('model.id'))
    location_id = db.Column(db.Integer, db.ForeignKey('location.id'))
    status_id = db.Column(db.Integer, db.ForeignKey('status.id'))
    parent_id = db.Column(db.Integer, db.ForeignKey('item.id'))

    manufacturer = db.relationship('Manufacturer', back_populates='items')
    model = db.relationship('Model', back_populates='items')
    location = db.relationship('Location', back_populates='items')
    status = db.relationship('Status', back_populates='items')
    children = db.relationship('Item', backref=db.backref('parent', remote_side=[id]))
    macs = db.relationship('Mac', backref='item')

    def __init__(self, **kwargs):
        """Create an item

        manufacturer/model/location/status are passed through
        utils.convert_to_model, so that they can be given as a string.
        """
        conversions = (
            ('manufacturer', Manufacturer),
            ('model', Model),
            ('location', Location),
            ('status', Status),
        )
        for field, model_cls in conversions:
            try:
                raw_value = kwargs[field]
            except KeyError:
                # Field not given: nothing to convert
                continue
            kwargs[field] = utils.convert_to_model(raw_value, model_cls)
        super().__init__(**kwargs)

    def __str__(self):
        return str(self.ics_id)

    @validates('ics_id')
    def validate_ics_id(self, key, string):
        """Check that the ICS id (when not None) matches the required format

        :raises: ValidationError if the format is wrong
        """
        if string is None:
            return string
        if ICS_ID_RE.fullmatch(string) is None:
            raise ValidationError('ICS id shall match [A-Z]{3}[0-9]{3}')
        return string

    def to_dict(self, long=False):
        """Return the item as a dictionary

        :param long: when true, the children, macs and history are included
        """
        d = {
            'id': self.id,
            'ics_id': self.ics_id,
            'serial_number': self.serial_number,
            'manufacturer': utils.format_field(self.manufacturer),
            'model': utils.format_field(self.model),
            'location': utils.format_field(self.location),
            'status': utils.format_field(self.status),
            'updated': utils.format_field(self._updated),
            'created': utils.format_field(self._created),
            'parent': utils.format_field(self.parent),
        }
        if not long:
            return d
        d.update(
            children=[utils.format_field(child) for child in self.children],
            macs=[mac.to_dict(long=True) for mac in self.macs],
            history=self.history(),
        )
        return d

    def history(self):
        """Return the list of recorded versions of the item

        Each version is a dictionary with the updated, location, status
        and parent fields.
        """
        records = []
        for version in self.versions:
            # SQLAlchemy-Continuum uses the "parent" attribute for its own needs:
            # version.parent is an ItemVersion instance, unrelated to the
            # item parent_id. The parent item has to be fetched "manually".
            parent_id = version.parent_id
            parent = None if parent_id is None else Item.query.get(parent_id)
            record = {
                'updated': utils.format_field(version._updated),
                'location': utils.format_field(version.location),
                'status': utils.format_field(version.status),
                'parent': utils.format_field(parent),
            }
            records.append(record)
        return records


class Network(db.Model):
    """Network (vlan) with the range of IP addresses that can be assigned to hosts

    The range is delimited by first_ip and last_ip, which shall both be
    part of the network address.
    """
    id = db.Column(db.Integer, primary_key=True)
    vlan_name = db.Column(CIText, nullable=False, unique=True)
    vlan_id = db.Column(db.Integer, nullable=False, unique=True)
    address = db.Column(postgresql.CIDR, nullable=False, unique=True)
    first_ip = db.Column(postgresql.INET, nullable=False, unique=True)
    last_ip = db.Column(postgresql.INET, nullable=False, unique=True)
    gateway = db.Column(postgresql.INET)
    description = db.Column(db.Text)
    admin_only = db.Column(db.Boolean, nullable=False, default=False)
    scope_id = db.Column(db.Integer, db.ForeignKey('network_scope.id'), nullable=False)

    hosts = db.relationship('Host', backref='network')

    __table_args__ = (
        sa.CheckConstraint('first_ip < last_ip', name='first_ip_less_than_last_ip'),
        sa.CheckConstraint('first_ip << address', name='first_ip_in_network'),
        sa.CheckConstraint('last_ip << address', name='last_ip_in_network'),
    )

    def __init__(self, **kwargs):
        # Automatically convert scope to an instance of NetworkScope if it was passed
        # as a string
        if 'scope' in kwargs:
            kwargs['scope'] = utils.convert_to_model(kwargs['scope'], NetworkScope, 'name')
        super().__init__(**kwargs)

    def __str__(self):
        return str(self.vlan_name)

    @property
    def network_ip(self):
        """The network address as an ipaddress network object"""
        return ipaddress.ip_network(self.address)

    @property
    def first(self):
        """The first IP of the range as an ipaddress address object"""
        return ipaddress.ip_address(self.first_ip)

    @property
    def last(self):
        """The last IP of the range as an ipaddress address object"""
        return ipaddress.ip_address(self.last_ip)

    def ip_range(self):
        """Return the list of IP addresses that can be assigned for this network

        The range is defined by the first and last IP
        """
        return [addr for addr in self.network_ip.hosts()
                if self.first <= addr <= self.last]

    def used_ips(self):
        """Return the list of IP addresses in use

        The list is sorted
        """
        return sorted(host.address for host in self.hosts)

    def available_ips(self):
        """Return the list of IP addresses available

        The addresses are in the same order as in ip_range()
        """
        # Retrieve the used addresses only once (and not for every address
        # of the range) and use a set for fast membership tests
        used = set(self.used_ips())
        return [addr for addr in self.ip_range()
                if addr not in used]

    @staticmethod
    def ip_in_network(ip, address):
        """Ensure the IP is in the network

        :param ip: IP address (any value accepted by ipaddress.ip_address)
        :param address: network address (any value accepted by ipaddress.ip_network)
        :returns: a tuple with the IP and network as (IPv4Address, IPv4Network)
        :raises: ValidationError if the IP is not in the network
        """
        addr = ipaddress.ip_address(ip)
        net = ipaddress.ip_network(address)
        if addr not in net:
            raise ValidationError(f'IP address {ip} is not in network {address}')
        return (addr, net)

    @validates('first_ip')
    def validate_first_ip(self, key, ip):
        """Ensure the first IP is in the network"""
        # self.address is read here: it must be set before first_ip
        self.ip_in_network(ip, self.address)
        return ip

    @validates('last_ip')
    def validate_last_ip(self, key, ip):
        """Ensure the last IP is in the network and not less than first_ip"""
        # NOTE(review): an IP equal to first_ip is accepted here, but the
        # first_ip_less_than_last_ip table constraint requires first_ip < last_ip
        addr, net = self.ip_in_network(ip, self.address)
        if addr < self.first:
            raise ValidationError(f'Last IP address {ip} is less than the first address {self.first}')
        return ip

    @validates('hosts')
    def validate_hosts(self, key, host):
        """Ensure the host IP is in the network range"""
        # host.ip is read here: it must be set before the host is linked to the network
        addr, net = self.ip_in_network(host.ip, self.address)
        if addr < self.first or addr > self.last:
            raise ValidationError(f'IP address {host.ip} is not in range {self.first} - {self.last}')
        return host

    def to_dict(self):
        """Return the network as a dictionary"""
        return {
            'id': self.id,
            'vlan_name': self.vlan_name,
            'address': self.address,
            'first_ip': self.first_ip,
            'last_ip': self.last_ip,
            'gateway': self.gateway,
            'vlan_id': self.vlan_id,
            'description': self.description,
            'admin_only': self.admin_only,
            'scope': utils.format_field(self.scope),
        }


# Association table required for the Many-to-Many relationship between hosts and tags
# (passed as "secondary" to the Host.tags relationship)
hosttags_table = db.Table(
    'hosttags',
    # Both columns are part of the primary key
    db.Column('tag_id', db.Integer, db.ForeignKey('tag.id'), primary_key=True),
    db.Column('host_id', db.Integer, db.ForeignKey('host.id'), primary_key=True)
)


class Tag(QRCodeMixin, db.Model):
    """Tag model: only defines what is inherited from QRCodeMixin

    The ``hosts`` backref is declared by the ``Host.tags`` relationship.
    """
    pass


class Host(db.Model):
    """Host with a unique name and IP address on a network"""
    id = db.Column(db.Integer, primary_key=True)
    network_id = db.Column(db.Integer, db.ForeignKey('network.id'), nullable=False)
    ip = db.Column(postgresql.INET, nullable=False, unique=True)
    name = db.Column(db.Text, nullable=False, unique=True)
    description = db.Column(db.Text)
    mac_id = db.Column(db.Integer, db.ForeignKey('mac.id'))

    cnames = db.relationship('Cname', backref='host')
    tags = db.relationship('Tag', secondary=hosttags_table, lazy='subquery',
                           backref=db.backref('hosts', lazy=True))

    def __init__(self, **kwargs):
        # Automatically convert network to an instance of Network if it was passed
        # as an address string
        if 'network' in kwargs:
            kwargs['network'] = utils.convert_to_model(kwargs['network'], Network, 'address')
        # WARNING! Setting self.network will call validates_hosts in the Network class
        # For the validation to work, self.ip must be set before!
        # Ensure that ip is passed before network
        try:
            ip = kwargs.pop('ip')
        except KeyError:
            super().__init__(**kwargs)
        else:
            super().__init__(ip=ip, **kwargs)

    @validates('name')
    def validate_name(self, key, string):
        """Ensure the hostname matches the required format

        :returns: the name forced to lowercase (or None if None was given)
        :raises: ValidationError if the name doesn't match the format
        """
        if string is None:
            return None
        # Force the string to lowercase
        lower_string = string.lower()
        if HOST_NAME_RE.fullmatch(lower_string) is None:
            # Raw string: "\-" is not a valid escape sequence in a normal string
            # (the message itself is unchanged)
            raise ValidationError(r'Host name shall match [a-z0-9\-]{2,20}')
        return lower_string

    @property
    def address(self):
        """The IP of the host as an ipaddress address object"""
        return ipaddress.ip_address(self.ip)

    def __str__(self):
        return str(self.name)

    def __repr__(self):
        return f'Host(id={self.id}, network_id={self.network_id}, ip={self.ip}, name={self.name}, description={self.description}, mac={self.mac})'

    def to_dict(self, long=False):
        """Return the host as a dictionary

        :param long: when true, the address of the linked mac (or None) is included
        """
        d = {
            'id': self.id,
            'ip': self.ip,
            'name': self.name,
            'description': self.description,
            'network_id': self.network_id,
        }
        if long:
            # self.mac can be None: mac_id is nullable
            d['mac'] = getattr(self.mac, 'address', None)
        return d


class Mac(db.Model):
    """MAC address belonging to an item and used by hosts"""
    id = db.Column(db.Integer, primary_key=True)
    address = db.Column(postgresql.MACADDR, nullable=False, unique=True)
    item_id = db.Column(db.Integer, db.ForeignKey('item.id'), nullable=False)

    hosts = db.relationship('Host', backref='mac')

    def __str__(self):
        return str(self.address)

    def to_dict(self, long=False):
        """Return the mac as a dictionary

        :param long: when true, the ICS id of the item and the hosts
                     using the mac are included
        """
        result = dict(id=self.id, address=self.address, item_id=self.item_id)
        if not long:
            return result
        result.update(
            item_ics_id=self.item.ics_id,
            hosts=[host.to_dict() for host in self.hosts],
        )
        return result


class Cname(db.Model):
    """Cname linked to a host"""
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.Text, nullable=False, unique=True)
    host_id = db.Column(db.Integer, db.ForeignKey('host.id'), nullable=False, unique=True)

    def __str__(self):
        return str(self.name)

    def to_dict(self):
        """Return the id, name and host_id of the cname as a dictionary"""
        fields = ('id', 'name', 'host_id')
        return {field: getattr(self, field) for field in fields}


class NetworkScope(db.Model):
    """Scope grouping networks: a range of vlans and a subnet"""
    __tablename__ = 'network_scope'
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(CIText, nullable=False, unique=True)
    first_vlan = db.Column(db.Integer, nullable=False, unique=True)
    last_vlan = db.Column(db.Integer, nullable=False, unique=True)
    subnet = db.Column(postgresql.CIDR, nullable=False, unique=True)

    networks = db.relationship('Network', backref='scope')

    def __str__(self):
        return str(self.name)

    def to_dict(self):
        """Return the columns of the scope as a dictionary"""
        fields = ('id', 'name', 'first_vlan', 'last_vlan', 'subnet')
        return {field: getattr(self, field) for field in fields}


# call configure_mappers after defining all the models
# required by sqlalchemy_continuum
sa.orm.configure_mappers()
# Version class associated to Item, as returned by sqlalchemy_continuum
ItemVersion = version_class(Item)