Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
utils.py 17.35 KiB
# -*- coding: utf-8 -*-
"""
app.utils
~~~~~~~~~

This module implements utility functions.

:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.

"""
import base64
import datetime
import ipaddress
import io
import random
import sqlalchemy as sa
import dateutil.parser
import yaml
from pathlib import Path
from flask import current_app, jsonify, url_for
from flask.globals import _app_ctx_stack, _request_ctx_stack
from flask_login import current_user
from wtforms import ValidationError
from .extensions import db


def fetch_current_user_id():
    """Return the id of the logged-in user, or None if it is unavailable

    None is returned outside of an application/request context and when
    ``current_user`` has no ``id`` attribute.
    """
    if _app_ctx_stack.top is None or _request_ctx_stack.top is None:
        # current_user is only meaningful inside a request context.
        return None
    # getattr's default only replaces an AttributeError, exactly like a
    # try/except AttributeError around the attribute access.
    return getattr(current_user, "id", None)


class CSEntryError(Exception):
    """CSEntryError class

    Exception used to pass useful information to the client side (API or AJAX)

    :param message: human readable error message
    :param status_code: status code to report to the client; when None the
        class default (400) is used
    :param payload: optional mapping of extra fields merged into :meth:`to_dict`
    """

    status_code = 400

    def __init__(self, message, status_code=None, payload=None):
        # Give Exception the message, not the exception object itself:
        # passing ``self`` made ``args`` self-referential, which sends
        # ``repr()`` and pickling into unbounded recursion.
        super().__init__(message)
        self.message = message
        if status_code is not None:
            self.status_code = status_code
        self.payload = payload

    def to_dict(self):
        """Return a new dict holding the payload fields plus ``message``"""
        rv = dict(self.payload or ())
        rv["message"] = self.message
        return rv

    def __str__(self):
        return str(self.to_dict())


def image_to_base64(img, format="PNG"):
    """Encode a Pillow image as a base64 ASCII string

    :param img: Pillow image (anything with a ``save(fp, format=...)`` method)
    :param format: image format passed to ``img.save``
    :returns str: the encoded image bytes as base64 text
    """
    with io.BytesIO() as buffer:
        img.save(buffer, format=format)
        encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode("ascii")


def format_field(field):
    """Render a field value as text

    Datetimes use the ``YYYY-MM-DD HH:MM`` format, None stays None and every
    other value goes through ``str``.
    """
    if isinstance(field, datetime.datetime):
        return field.strftime("%Y-%m-%d %H:%M")
    return None if field is None else str(field)


def convert_to_model(item, model, filter_by="name"):
    """Resolve item into an instance of model

    None and values that already are instances of model are returned as is.
    Anything else is looked up with ``model.query.filter_by(<filter_by>=item)``.

    :param item: value to convert (e.g. a name)
    :param model: model class to look up
    :param filter_by: column used for the lookup
    :returns: an instance of model (or None when item is None)
    :raises CSEntryError: if no matching instance exists
    """
    already_converted = item is None or isinstance(item, model)
    if already_converted:
        return item
    found = model.query.filter_by(**{filter_by: item}).first()
    if found is None:
        raise CSEntryError(f"{item} is not a valid {model.__name__.lower()}")
    return found


def convert_to_models(d, fields):
    """Return a copy of d with model fields converted to model instances

    :param d: dictionary with the values to update
    :param fields: list of tuple (key, type_, filter_by); a type_ given as a
        one-element list means the value is a list of that model
    :returns: new updated dictionary (d itself is left untouched)
    """
    result = d.copy()
    for key, type_, filter_by in fields:
        # A None filter_by marks a standard type, not a db.Model: nothing to do.
        # Keys missing from d are skipped as well.
        if filter_by is not None and key in d:
            if isinstance(type_, list):
                current = result[key]
                items = current if isinstance(current, list) else [current]
                result[key] = [
                    convert_to_model(item, type_[0], filter_by) for item in items
                ]
            else:
                result[key] = convert_to_model(result[key], type_, filter_by)
    return result


def attribute_to_string(value):
    """Return an LDAP attribute value as a single string

    Multi-valued attributes come back as a list
    (see http://ldap3.readthedocs.io/tutorial_searches.html#entries-retrieval);
    for those, the first element is returned, or "" when the list is empty.
    Non-list values are returned unchanged.

    :param value: string or list
    :returns: string
    """
    if not isinstance(value, list):
        return value
    return value[0] if value else ""


def get_choices(iterable, allow_blank=False, allow_null=False):
    """Return a list of (value, label) pairs where label equals value

    :param iterable: values to offer
    :param allow_blank: prepend a blank ("", "") choice
    :param allow_null: prepend a ("null", "not set") choice (after the blank one)
    """
    leading = []
    if allow_blank:
        leading.append(("", ""))
    if allow_null:
        leading.append(("null", "not set"))
    return leading + [(item, item) for item in iterable]


def get_model_choices(model, allow_none=False, attr="name", query=None, order_by=None):
    """Return a list of (str(instance.id), instance.<attr>) pairs

    :param model: model class whose instances are listed
    :param allow_none: prepend a (None, "") choice
    :param attr: attribute used as the label
    :param query: base query to use instead of ``model.query``
    :param order_by: attribute name to sort on (defaults to attr)
    """
    base_query = model.query if query is None else query
    sort_column = getattr(model, order_by or attr)
    rows = base_query.order_by(sort_column).all()
    options = [(None, "")] if allow_none else []
    options += [(str(row.id), getattr(row, attr)) for row in rows]
    return options


def get_query(query, model, **kwargs):
    """Filter query with ``model.<key> == value`` for every keyword argument

    Filtering is always applied on the given model: ``filter_by(**kwargs)``
    would target the primary entity of the query, or the last entity that was
    the target of a ``Query.join()``, which might not be what we want.

    :param query: sqlalchemy base query
    :param model: model class
    :param kwargs: kwargs from a request
    :returns: query filtered by the arguments (unchanged when kwargs is empty)
    :raises CSEntryError: with status 422 for unknown attributes or invalid
        filters
    """
    if kwargs:
        try:
            for attribute, expected in kwargs.items():
                query = query.filter(getattr(model, attribute) == expected)
        except (sa.exc.InvalidRequestError, AttributeError) as err:
            current_app.logger.warning(f"Invalid query arguments: {err}")
            raise CSEntryError("Invalid query arguments", status_code=422)
    return query


def lowercase_field(value):
    """Form filter returning value in lowercase

    Values without a ``lower()`` method (e.g. None) pass through unchanged.
    """
    try:
        lowered = value.lower()
    except AttributeError:
        lowered = value
    return lowered


# Coerce functions to use with a SelectField that can accept a None value.
# By default, wtforms always coerces values to strings.
# Values returned from the form are usually strings, but None is returned
# when a field is disabled.
# To pass wtforms validation, the returned value must be one of the choices.
def coerce_to_str_or_none(value):
    """Map '', None and 'None' to None; stringify anything else"""
    is_empty = value in ("", "None") or value is None
    return None if is_empty else str(value)


def coerce_to_int_or_none(value):
    """Return value converted to an int, or None if it cannot be converted

    Besides non-numeric strings (ValueError), this covers the None returned
    for disabled form fields: ``int(None)`` raises TypeError, which must be
    caught too.
    """
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


def parse_to_utc(string):
    """Parse string into a naive datetime expressed in UTC

    Timezone-aware results are converted to UTC and their tzinfo removed.
    Naive results are returned as parsed, i.e. assumed to already be UTC.
    """
    parsed = dateutil.parser.parse(string)
    if parsed.tzinfo is not None:
        parsed = parsed.astimezone(datetime.timezone.utc).replace(tzinfo=None)
    return parsed


def random_mac():
    """Return a random MAC address

    The prefix is ``current_app.config["MAC_OUI"]``; it is followed by three
    random octets, each written as two lowercase hex digits.
    """
    random_octets = [f"{random.randint(0x00, 0xFF):02x}" for _ in range(3)]
    return ":".join([current_app.config["MAC_OUI"], *random_octets])


def pluralize(singular):
    """Return the plural form of the given word

    Only meant for API endpoint names, not for arbitrary English words:
    words ending in "s" get "es", all others get "s".
    """
    suffix = "es" if singular.endswith("s") else "s"
    return singular + suffix


def format_datetime(value, format="%Y-%m-%d %H:%M"):
    """Jinja2 filter rendering value with ``strftime``

    :param value: object with a ``strftime`` method (e.g. a datetime)
    :param format: strftime format string
    """
    formatted = value.strftime(format)
    return formatted


def pretty_yaml(value):
    """Jinja2 filter dumping value as block-style YAML

    Falsy values (None, empty dict/list, ...) render as an empty string.
    """
    if not value:
        return ""
    return yaml.safe_dump(value, default_flow_style=False)


def trigger_job_once(name, queue_name="low", **kwargs):
    """Launch the named job unless one is already waiting

    At most one job runs and one waits in the queue (the waiting one picks up
    the latest changes). If a job with this name is already waiting, it is
    returned and nothing new is launched. If one is running, the new job is
    made to depend on it.

    :returns: the waiting task, or the newly launched one
    """
    pending = current_user.get_task_waiting(name)
    if pending is None:
        running = current_user.get_task_started(name)
        if running:
            # Only start once the running task is done.
            kwargs["depends_on"] = running.id
        current_app.logger.info(f"Launch new {name} job")
        return current_user.launch_task(name, queue_name=queue_name, **kwargs)
    current_app.logger.info(
        f'Already one "{name}" task waiting. No need to trigger a new one.'
    )
    return pending


def trigger_core_services_update():
    """Trigger a job to update the core services (DNS/DHCP/radius)

    To be called every time an interface or host is created or edited.

    The AWX template refreshes its own inventory on launch, so it does not
    block the main inventory update and no inventory update is triggered
    here. At most one job runs and one waits (see ``trigger_job_once``).
    """
    config = current_app.config
    return trigger_job_once(
        "trigger_core_services_update",
        queue_name="normal",
        func="launch_awx_job",
        job_template=config["AWX_CORE_SERVICES_UPDATE"],
        resource=config.get("AWX_CORE_SERVICES_UPDATE_RESOURCE", "job"),
    )


def trigger_inventory_update():
    """Trigger a job to update the inventory in AWX

    To be called every time something impacting the inventory changes
    (AnsibleGroup, Host, Interface, Network...). At most one job runs and one
    waits (see ``trigger_job_once``).
    """
    # The "high" queue gives this job priority over all the other jobs.
    return trigger_job_once(
        "trigger_inventory_update",
        queue_name="high",
        func="launch_awx_job",
        resource="inventory_source",
        inventory_source=current_app.config["AWX_INVENTORY_SOURCE"],
    )


def trigger_ansible_groups_reindex():
    """Trigger a job to reindex the Ansible groups

    To be called every time a Host or Interface is modified, so that the
    dynamic Ansible groups stay properly indexed in Elasticsearch. At most one
    job runs and one waits (see ``trigger_job_once``).
    """
    job_name = "reindex_ansible_groups"
    return trigger_job_once(job_name, queue_name="low", func=job_name)


def trigger_vm_creation(
    host, vm_disk_size, vm_cores, vm_memory, vm_osversion, skip_post_install_job
):
    """Trigger a job to create a virtual machine or virtual IOC

    A creation job is queued with the VM parameters as extra vars. A second,
    dependent post-install job is queued on the host only when all of these
    hold: a post-install template is configured for the host's domain,
    ``skip_post_install_job`` is false, and ``vm_osversion`` does not start
    with "windows".

    :returns: the creation task
    """
    domain = str(host.main_interface.network.domain)
    extra_vars = [
        f"vmname={host.name}",
        f"domain={domain}",
        f"vm_disk_size={vm_disk_size}",
        f"vm_cores={vm_cores}",
        f"vm_memory={vm_memory}",
        f"vm_osversion={vm_osversion}",
    ]
    if host.is_ioc:
        task_name, template_key, post_install_kind = (
            "trigger_vioc_creation",
            "AWX_CREATE_VIOC",
            "VIOC",
        )
    else:
        task_name, template_key, post_install_kind = (
            "trigger_vm_creation",
            "AWX_CREATE_VM",
            "VM",
        )
    config = current_app.config
    job_template = config[template_key]
    post_job_template = config["AWX_POST_INSTALL"][post_install_kind].get(domain)
    current_app.logger.info(
        f"Launch new job to create the {host} VM: {job_template} with {extra_vars}"
    )
    creation_task = current_user.launch_task(
        task_name,
        resource="job",
        queue_name="low",
        func="launch_awx_job",
        job_template=job_template,
        extra_vars=extra_vars,
    )
    if not post_job_template or skip_post_install_job:
        return creation_task
    if vm_osversion.startswith("windows"):
        # No post-install job for Windows machines.
        return creation_task
    current_user.launch_task(
        "trigger_post_install_job",
        resource="job",
        queue_name="low",
        func="launch_awx_job",
        job_template=post_job_template,
        limit=f"{host.fqdn}",
        depends_on=creation_task.id,
    )
    current_app.logger.info(
        f"Trigger post install job: run {post_job_template} on {host.fqdn}"
    )
    return creation_task


def trigger_set_network_boot_profile(host, boot_profile):
    """Trigger a job to set the boot profile for host

    :param host: host whose main interface MAC address is passed to the job
    :param boot_profile: boot profile to apply
    :returns: the launched task
    """
    mac_address = host.main_interface.mac
    extra_vars = [
        f"autoinstall_boot_profile={boot_profile}",
        f"autoinstall_pxe_mac_addr={mac_address}",
    ]
    job_template = current_app.config["AWX_SET_NETWORK_BOOT_PROFILE"]
    current_app.logger.info(
        f"Launch new job to set the network boot profile for {host.name}: {job_template} with {extra_vars}"
    )
    return current_user.launch_task(
        "trigger_set_network_boot_profile",
        resource="job",
        queue_name="low",
        func="launch_awx_job",
        job_template=job_template,
        extra_vars=extra_vars,
    )


def redirect_to_job_status(job_id):
    """Answer a client request with a 202 pointing at the job status endpoint

    The ``Location`` header gives the URL the client should poll.

    :param job_id: The id of the started job, which the client needs to poll
    :type job_id: rq.job_id
    :return: HTTP response as a (body, status, headers) tuple
    """
    body = jsonify({})
    headers = {"Location": url_for("main.job_status", job_id=job_id)}
    return body, 202, headers


def unique_filename(filename):
    """Return a filename that does not exist yet

    If ``filename`` does not exist, it is returned unchanged. Otherwise a
    ``-<n>`` counter is inserted before the suffix (``report.txt`` ->
    ``report-1.txt``, ``report-2.txt``, ...), and the first candidate that
    does not exist is returned.

    A ``str`` argument always yields a ``str``. Any other path-like argument is
    returned unchanged when free, and otherwise yields a :class:`~pathlib.Path`.

    Note: existence is only checked, not reserved, so another process could
    still create the returned name before it is used.

    :param filename: filename that should be unique
    :returns: unique filename
    """
    p = Path(filename)
    if not p.exists():
        return filename
    base = p.with_suffix("")
    nb = 1
    unique = Path(f"{base}-{nb}{p.suffix}")
    while unique.exists():
        nb += 1
        unique = Path(f"{base}-{nb}{p.suffix}")
    # Keep the caller's type: str in, str out (the free-name path above
    # already returns the caller's own str).
    return str(unique) if isinstance(filename, str) else unique


def _datatables_sort(values):
    """Build the sort string requested by datatables, or None if there is none

    Fields of type keyword/date/long (created_at, updated_at, quantity) are
    sorted on directly. Text fields are sorted on their extra ``.keyword``
    field.
    """
    column_index = values.get("order[0][column]")
    if column_index is None:
        return None
    field = values.get(f"columns[{column_index}][data]")
    direction = values.get("order[0][dir]", "asc")
    if field in ("created_at", "updated_at", "quantity"):
        return f"{field}:{direction}"
    return f"{field}.keyword:{direction}"


def retrieve_data_for_datatables(values, model, filter_sensitive=False):
    """Return the filtered data of model to datatables

    To be used with datatables serverSide processing.

    :param values: a `~werkzeug.datastructures.MultiDict`, typically request.values
    :param model: class of the model to query
    :param bool filter_sensitive: filter out sensitive data if set to True
    :return: json object required by datatables
    """
    draw = values.get("draw", 0, type=int)
    page_size = values.get("length", 20, type=int)
    offset = values.get("start", 0, type=int)
    # datatables sends a row offset, while pages start at 1 in elasticsearch.
    # NOTE(review): a "length" of 0 raises ZeroDivisionError here; confirm
    # whether clients can send it.
    page_number = int(offset / page_size) + 1
    raw_search = values.get("search[value]", "")
    query_string = "*" if raw_search == "" else raw_search
    sort = _datatables_sort(values)
    instances, nb_filtered = model.search(
        query_string,
        page=page_number,
        per_page=page_size,
        sort=sort,
        filter_sensitive=filter_sensitive,
    )
    rows = [instance.to_dict(recursive=True) for instance in instances]
    # Total number of items before filtering.
    nb_total = db.session.query(sa.func.count(model.id)).scalar()
    return jsonify(
        {
            "draw": draw,
            "recordsTotal": nb_total,
            "recordsFiltered": nb_filtered,
            "data": rows,
        }
    )


def minutes_ago(minutes):
    """Return the naive UTC datetime ``minutes`` minutes before now

    The result has no tzinfo, matching the naive UTC datetimes produced by
    ``parse_to_utc``.
    """
    # datetime.utcnow() is deprecated since Python 3.12: take an aware UTC
    # "now" and drop the tzinfo to keep returning a naive value.
    now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    return now - datetime.timedelta(minutes=minutes)


def update_ansible_vars(host, vars):
    """Merge vars into host.ansible_vars

    :param host: mapped object with an ``ansible_vars`` attribute (a dict, or
        a falsy value when unset)
    :param vars: variables to merge
    :returns: False if no variables were changed, True otherwise
    """
    if host.ansible_vars:
        local_ansible_vars = host.ansible_vars.copy()
        local_ansible_vars.update(vars)
        if local_ansible_vars == host.ansible_vars:
            # No change
            return False
        host.ansible_vars.update(vars)
        # If we don't flag the field as modified, it's not saved to the database
        # Probably because we update an existing dictionary
        sa.orm.attributes.flag_modified(host, "ansible_vars")
    else:
        # Store a copy. Keeping the caller's dict would let a later in-place
        # update() (branch above) silently mutate that dict.
        host.ansible_vars = vars.copy() if isinstance(vars, dict) else vars
    return True


def ip_in_network(ip, address):
    """Check that ip belongs to the network address

    :param ip: IP address (anything ``ipaddress.ip_address`` accepts)
    :param address: network (anything ``ipaddress.ip_network`` accepts)
    :returns: a tuple with the IP and network as (IPv4Address, IPv4Network)
    :raises: ValidationError if the IP is not in the network
    """
    ip_addr = ipaddress.ip_address(ip)
    network = ipaddress.ip_network(address)
    if ip_addr in network:
        return ip_addr, network
    raise ValidationError(f"IP address {ip} is not in network {address}")


def validate_ip(ip, network):
    """Ensure the IP is in the network and, for non-admins, in its range

    :param ip: IP address to validate
    :param network: object exposing ``address``, ``first`` and ``last``
    :raises ValidationError: if the IP is outside ``network.address``, or
        (for non-admin users) outside ``network.first`` - ``network.last``
    """
    addr, _net = ip_in_network(ip, network.address)
    # current_user is a local proxy and is not valid outside of a request
    # context; getattr's default only replaces an AttributeError.
    is_admin = getattr(current_user, "is_admin", False)
    if is_admin:
        # Admin users can create IPs outside the defined range.
        return
    if addr < network.first or addr > network.last:
        raise ValidationError(
            f"IP address {ip} is not in range {network.first} - {network.last}"
        )


def overlaps(subnet, subnets):
    """Return True if subnet overlaps any network in subnets (stops at the first hit)"""
    return any(subnet.overlaps(other) for other in subnets)