# -*- coding: utf-8 -*-
"""
app.utils
~~~~~~~~~

This module implements utility functions.

:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.
"""
import base64
import datetime
import ipaddress
import io
import random
import sqlalchemy as sa
import dateutil.parser
import yaml
from pathlib import Path
from flask import current_app, jsonify, url_for
from flask.globals import _app_ctx_stack, _request_ctx_stack
from flask_login import current_user
from wtforms import ValidationError
from .extensions import db


def fetch_current_user_id():
    """Retrieve the current user id

    :returns: the id of the logged in user, or None when called outside of
        an application/request context or when the current user has no id
        (e.g. an anonymous user)
    """
    # Return None if we are outside of request context.
    # NOTE(review): _app_ctx_stack and _request_ctx_stack are private Flask
    # globals (deprecated in Flask 2.2, removed in 2.3). flask.has_request_context()
    # is the public alternative if Flask gets upgraded.
    if _app_ctx_stack.top is None or _request_ctx_stack.top is None:
        return None
    try:
        return current_user.id
    except AttributeError:
        return None


class CSEntryError(Exception):
    """CSEntryError class

    Exception used to pass useful information to the client side (API or AJAX)

    :param message: human readable message, returned under the "message" key
        of to_dict()
    :param status_code: HTTP status code to report (class default: 400)
    :param payload: optional mapping merged into the dictionary returned by
        to_dict()
    """

    status_code = 400

    def __init__(self, message, status_code=None, payload=None):
        # Give the message (not the instance itself) to Exception so that
        # e.args == (message,). Passing self made args reference the
        # exception itself, which makes repr() and pickling recurse.
        super().__init__(message)
        self.message = message
        if status_code is not None:
            self.status_code = status_code
        self.payload = payload

    def to_dict(self):
        """Return the payload merged with the message as a new dictionary"""
        rv = dict(self.payload or ())
        rv["message"] = self.message
        return rv

    def __str__(self):
        return str(self.to_dict())


def image_to_base64(img, format="PNG"):
    """Convert a Pillow image to a base64 string

    :param img: Pillow image
    :param format: format of the image to use
    :returns str: image as base64 string
    """
    buf = io.BytesIO()
    img.save(buf, format=format)
    return base64.b64encode(buf.getvalue()).decode("ascii")


def format_field(field):
    """Format the given field to a string or None

    datetime values are formatted as "YYYY-MM-DD HH:MM"; any other non-None
    value goes through str().
    """
    if field is None:
        return None
    if isinstance(field, datetime.datetime):
        return field.strftime("%Y-%m-%d %H:%M")
    return str(field)


def convert_to_model(item, model, filter_by="name"):
    """Convert item to an instance of model

    Allow to convert a string to an instance of model
    Raise an exception if the given name is not found

    :param item: None, an instance of model, or a value to look up
    :param model: model class to query
    :param filter_by: model attribute used for the lookup
    :returns: an instance of model (or None if item is None)
    :raises CSEntryError: if no instance matches item
    """
    if item is None or isinstance(item, model):
        return item
    kwarg = {filter_by: item}
    instance = model.query.filter_by(**kwarg).first()
    if instance is None:
        raise CSEntryError(f"{item} is not a valid {model.__name__.lower()}")
    return instance


def convert_to_models(d, fields):
    """Convert the values of the dictionary to the given type

    :param d: dictionary with the values to update
    :param fields: list of tuple (key, type_, filter_by). type_ can be a
        one-element list ([Model]) for multi-valued fields, in which case a
        single value is wrapped in a list before conversion.
    :returns: new updated dictionary (d itself is not modified)
    """
    new = d.copy()
    for key, type_, filter_by in fields:
        if filter_by is None or key not in d:
            # This is not an instance of db.Model but a standard type
            continue
        if isinstance(type_, list):
            values = new[key]
            if not isinstance(values, list):
                values = [values]
            new[key] = [
                convert_to_model(value, type_[0], filter_by) for value in values
            ]
        else:
            new[key] = convert_to_model(new[key], type_, filter_by)
    return new


def attribute_to_string(value):
    """Return the attribute as a string

    If the attribute is defined in the schema as multi valued
    then the attribute value is returned as a list
    See http://ldap3.readthedocs.io/tutorial_searches.html#entries-retrieval

    This function returns the first item of the list if it's a list
    (or an empty string for an empty list)

    :param value: string or list
    :returns: string
    """
    if isinstance(value, list):
        try:
            return value[0]
        except IndexError:
            return ""
    else:
        return value


def get_choices(iterable, allow_blank=False, allow_null=False):
    """Return a list of (value, label)

    :param iterable: values used both as value and label
    :param allow_blank: prepend a ("", "") choice
    :param allow_null: then add a ("null", "not set") choice
    """
    choices = []
    if allow_blank:
        choices = [("", "")]
    if allow_null:
        choices.append(("null", "not set"))
    choices.extend([(val, val) for val in iterable])
    return choices


def get_model_choices(model, allow_none=False, attr="name", query=None, order_by=None):
    """Return a list of (value, label)

    The value is the instance id as a string and the label the given attribute.

    :param model: model class
    :param allow_none: prepend a (None, "") choice
    :param attr: attribute used as label
    :param query: base query (defaults to model.query)
    :param order_by: attribute to sort on (defaults to attr)
    """
    choices = []
    if allow_none:
        choices = [(None, "")]
    if query is None:
        query = model.query
    query = query.order_by(getattr(model, order_by or attr))
    choices.extend(
        [(str(instance.id), getattr(instance, attr)) for instance in query.all()]
    )
    return choices
def get_query(query, model, **kwargs):
    """Retrieve the query from the arguments

    :param query: sqlalchemy base query
    :param model: model class
    :param kwargs: kwargs from a request
    :returns: query filtered by the arguments
    :raises CSEntryError: (status 422) if an argument can't be applied to model
    """
    if not kwargs:
        return query
    try:
        # With filter_by(**kwargs), the keyword expressions are extracted
        # from the primary entity of the query, or the last entity that was
        # the target of a call to Query.join()
        # This might not be what we want when join() is used.
        # Always apply filtering on the given model
        for key, value in kwargs.items():
            query = query.filter(getattr(model, key) == value)
    except (sa.exc.InvalidRequestError, AttributeError) as e:
        current_app.logger.warning(f"Invalid query arguments: {e}")
        raise CSEntryError("Invalid query arguments", status_code=422) from e
    return query


def lowercase_field(value):
    """Filter to force form value to lowercase

    Values without a lower() method (e.g. None) are returned unchanged.
    """
    try:
        return value.lower()
    except AttributeError:
        return value


# coerce functions to use with SelectField that can accept a None value
# wtforms always coerce to string by default
# Values returned from the form are usually strings but if a field is disabled
# None is returned
# To pass wtforms validation, the value returned must be part of choices
def coerce_to_str_or_none(value):
    """Convert '', None and 'None' to None"""
    if value in ("", "None") or value is None:
        return None
    return str(value)


def coerce_to_int_or_none(value):
    """Return None if the value is not an integer

    Handles None (disabled field) as well as non-numeric strings.
    """
    try:
        return int(value)
    except (TypeError, ValueError):
        # int(None) raises TypeError, int("abc") raises ValueError
        return None


def parse_to_utc(string):
    """Convert a string to a datetime object with no timezone

    Naive datetimes are assumed to already be in UTC; aware ones are
    converted to UTC before the timezone is dropped.
    """
    d = dateutil.parser.parse(string)
    if d.tzinfo is None:
        # Assume this is UTC
        return d
    # Convert to UTC and remove timezone
    d = d.astimezone(datetime.timezone.utc)
    return d.replace(tzinfo=None)


def random_mac():
    """Return a random MAC address

    The first half is the configured MAC_OUI, the last three octets are random.
    """
    octets = [
        random.randint(0x00, 0xFF),
        random.randint(0x00, 0xFF),
        random.randint(0x00, 0xFF),
    ]
    octets = [f"{nb:02x}" for nb in octets]
    return ":".join((current_app.config["MAC_OUI"], *octets))


def pluralize(singular):
    """Return the plural form of the given word

    Used to pluralize API endpoints (not any given english word)
    """
    if not singular.endswith("s"):
        return singular + "s"
    else:
        return singular + "es"


def format_datetime(value, format="%Y-%m-%d %H:%M"):
    """Format a datetime to string

    Function used as a jinja2 filter
    """
    return value.strftime(format)


def pretty_yaml(value):
    """Pretty print yaml

    Function used as a jinja2 filter. Falsy values render as an empty string.
    """
    if value:
        return yaml.safe_dump(value, default_flow_style=False)
    else:
        return ""


def trigger_job_once(name, queue_name="low", **kwargs):
    """Trigger a job only once

    We can have one running job + one in queue (to apply the latest changes).
    Make sure that we don't have more than one in queue.

    :param name: task name
    :param queue_name: queue to put the task on
    :param kwargs: extra arguments passed to current_user.launch_task
    :returns: the already waiting task, or the newly launched one
    """
    waiting_task = current_user.get_task_waiting(name)
    if waiting_task is not None:
        current_app.logger.info(
            f'Already one "{name}" task waiting. No need to trigger a new one.'
        )
        return waiting_task
    started = current_user.get_task_started(name)
    if started:
        # There is already one running task. Trigger a new one when it's done.
        kwargs["depends_on"] = started.id
    current_app.logger.info(f"Launch new {name} job")
    task = current_user.launch_task(name, queue_name=queue_name, **kwargs)
    return task


def trigger_core_services_update():
    """Trigger a job to update the core services (DNS/DHCP/radius)

    This function should be called every time an interface or host is created/edited

    The AWX template uses its own inventory that is updated on launch to avoid
    blocking the main inventory update when running.
    There is no need to trigger an inventory update.

    We can have one running job + one in queue to apply the latest changes.
    Make sure that we don't have more than one in queue.
    """
    job_template = current_app.config["AWX_CORE_SERVICES_UPDATE"]
    resource = current_app.config.get("AWX_CORE_SERVICES_UPDATE_RESOURCE", "job")
    return trigger_job_once(
        "trigger_core_services_update",
        queue_name="normal",
        func="launch_awx_job",
        job_template=job_template,
        resource=resource,
    )


def trigger_inventory_update():
    """Trigger a job to update the inventory in AWX

    This function should be called every time something impacting the inventory
    is updated (AnsibleGroup, Host, Interface, Network...)

    We can have one running job + one in queue to apply the latest changes.
    Make sure that we don't have more than one in queue.
    """
    inventory_source = current_app.config["AWX_INVENTORY_SOURCE"]
    # Put it on the "high" queue so that it has higher priority than all other jobs
    return trigger_job_once(
        "trigger_inventory_update",
        queue_name="high",
        func="launch_awx_job",
        resource="inventory_source",
        inventory_source=inventory_source,
    )


def trigger_ansible_groups_reindex():
    """Trigger a job to reindex the Ansible groups

    This function should be called every time a Host or Interface is modified
    to make sure the dynamic Ansible groups are indexed properly in Elasticsearch.

    We can have one running job + one in queue to apply the latest changes.
    Make sure that we don't have more than one in queue.
    """
    return trigger_job_once(
        "reindex_ansible_groups",
        queue_name="low",
        func="reindex_ansible_groups",
    )


def trigger_vm_creation(
    host, vm_disk_size, vm_cores, vm_memory, vm_osversion, skip_post_install_job
):
    """Trigger a job to create a virtual machine or virtual IOC

    The VM parameters are passed to the AWX template as extra_vars.
    A post-install job (AWX_POST_INSTALL, looked up by domain) is chained after
    the creation job unless skip_post_install_job is set, no template is
    configured for the domain, or vm_osversion starts with "windows".

    :returns: the creation task
    """
    domain = str(host.main_interface.network.domain)
    extra_vars = [
        f"vmname={host.name}",
        f"domain={domain}",
        f"vm_disk_size={vm_disk_size}",
        f"vm_cores={vm_cores}",
        f"vm_memory={vm_memory}",
        f"vm_osversion={vm_osversion}",
    ]
    if host.is_ioc:
        task_name = "trigger_vioc_creation"
        job_template = current_app.config["AWX_CREATE_VIOC"]
        post_job_template = current_app.config["AWX_POST_INSTALL"]["VIOC"].get(domain)
    else:
        task_name = "trigger_vm_creation"
        job_template = current_app.config["AWX_CREATE_VM"]
        post_job_template = current_app.config["AWX_POST_INSTALL"]["VM"].get(domain)
    current_app.logger.info(
        f"Launch new job to create the {host} VM: {job_template} with {extra_vars}"
    )
    task = current_user.launch_task(
        task_name,
        resource="job",
        queue_name="low",
        func="launch_awx_job",
        job_template=job_template,
        extra_vars=extra_vars,
    )
    if (
        post_job_template
        and (not skip_post_install_job)
        and (not vm_osversion.startswith("windows"))
    ):
        # Only run once the creation job is done
        current_user.launch_task(
            "trigger_post_install_job",
            resource="job",
            queue_name="low",
            func="launch_awx_job",
            job_template=post_job_template,
            limit=f"{host.fqdn}",
            depends_on=task.id,
        )
        current_app.logger.info(
            f"Trigger post install job: run {post_job_template} on {host.fqdn}"
        )
    return task


def trigger_set_network_boot_profile(host, boot_profile):
    """Trigger a job to set the boot profile for host

    The profile and the host main interface MAC address are passed as extra_vars.
    """
    extra_vars = [
        f"autoinstall_boot_profile={boot_profile}",
        f"autoinstall_pxe_mac_addr={host.main_interface.mac}",
    ]
    job_template = current_app.config["AWX_SET_NETWORK_BOOT_PROFILE"]
    current_app.logger.info(
        f"Launch new job to set the network boot profile for {host.name}: {job_template} with {extra_vars}"
    )
    task = current_user.launch_task(
        "trigger_set_network_boot_profile",
        resource="job",
        queue_name="low",
        func="launch_awx_job",
        job_template=job_template,
        extra_vars=extra_vars,
    )
    return task


def redirect_to_job_status(job_id):
    """
    The answer to a client request, leading it to regularly poll a job status.

    :param job_id: The id of the job started, which needs to be pulled by the client
    :type job_id: rq.job_id
    :return: HTTP response (empty JSON body, 202, Location header to the job status)
    """
    return jsonify({}), 202, {"Location": url_for("main.job_status", job_id=job_id)}


def unique_filename(filename):
    """Return an unique filename

    If filename already exists, "-1", "-2", ... is inserted before the suffix
    until a non-existing path is found.

    :param filename: filename that should be unique (str or path-like)
    :returns: unique filename, as a str if a str was given, else as a Path
    """
    p = Path(filename)
    if not p.exists():
        return filename
    base = p.with_suffix("")
    nb = 1
    while True:
        unique = Path(f"{base}-{nb}{p.suffix}")
        if not unique.exists():
            break
        nb += 1
    # Keep the return type consistent with the non-colliding case above
    if isinstance(filename, str):
        return str(unique)
    return unique


def retrieve_data_for_datatables(values, model, filter_sensitive=False):
    """Return the filtered data of model to datatables

    This function is supposed to be called when using datatables with serverSide processing.

    :param values: a `~werkzeug.datastructures.MultiDict`, typically request.values
    :param model: class of the model to query
    :param bool filter_sensitive: filter out sensitive data if set to True
    :return: json object required by datatables
    """
    # Get the parameters from the post data sent by datatables
    draw = values.get("draw", 0, type=int)
    per_page = values.get("length", 20, type=int)
    # page starts at 1 in elasticsearch
    # NOTE(review): a length of 0 raises ZeroDivisionError here, and -1
    # (datatables "All") gives a negative per_page; confirm how model.search
    # should handle those.
    page = int(values.get("start", 0, type=int) / per_page) + 1
    search = values.get("search[value]", "")
    if search == "":
        search = "*"
    order_column = values.get("order[0][column]")
    if order_column is None:
        sort = None
    else:
        name = values.get(f"columns[{order_column}][data]")
        order_dir = values.get("order[0][dir]", "asc")
        # Sorting can be done directly on all fields of type
        # keyword/date/long
        # To sort on fields of type text, we use the extra .keyword field
        if name in ("created_at", "updated_at", "quantity"):
            sort = f"{name}:{order_dir}"
        else:
            sort = f"{name}.keyword:{order_dir}"
    instances, nb_filtered = model.search(
        search,
        page=page,
        per_page=per_page,
        sort=sort,
        filter_sensitive=filter_sensitive,
    )
    data = [instance.to_dict(recursive=True) for instance in instances]
    # Total number of items before filtering
    nb_total = db.session.query(sa.func.count(model.id)).scalar()
    response = {
        "draw": draw,
        "recordsTotal": nb_total,
        "recordsFiltered": nb_filtered,
        "data": data,
    }
    return jsonify(response)


def minutes_ago(minutes):
    """Return the datetime x minutes ago

    :returns: naive datetime in UTC
    """
    # Same result as datetime.utcnow() (naive UTC) without the API that is
    # deprecated since Python 3.12
    now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    return now - datetime.timedelta(minutes=minutes)


def update_ansible_vars(host, vars):
    """Update the host ansible_vars

    Return False if no variables were changed, True otherwise

    :param host: object with an ansible_vars attribute (dict or None)
    :param vars: mapping of variables to add/overwrite
    """
    # Check against None (not truthiness) so that an existing empty dict
    # updated with no variables is correctly reported as unchanged
    if host.ansible_vars is not None:
        local_ansible_vars = host.ansible_vars.copy()
        local_ansible_vars.update(vars)
        if local_ansible_vars == host.ansible_vars:
            # No change
            return False
        host.ansible_vars.update(vars)
        # If we don't flag the field as modified, it's not saved to the database
        # Probably because we update an existing dictionary
        sa.orm.attributes.flag_modified(host, "ansible_vars")
    else:
        # Store a copy: later in-place update() calls on host.ansible_vars
        # must not mutate the caller's dictionary
        host.ansible_vars = dict(vars)
    return True


def ip_in_network(ip, address):
    """Ensure the IP is in the network

    :returns: a tuple with the IP and network as (IPv4Address, IPv4Network)
    :raises: ValidationError if the IP is not in the network
    """
    addr = ipaddress.ip_address(ip)
    net = ipaddress.ip_network(address)
    if addr not in net:
        raise ValidationError(f"IP address {ip} is not in network {address}")
    return (addr, net)


def validate_ip(ip, network):
    """Ensure the IP is in the network range

    The IP must be in network.address; non-admin users are also restricted
    to the [network.first, network.last] range.

    :raises: ValidationError if a check fails
    """
    addr, net = ip_in_network(ip, network.address)
    # Admin user can create IP outside the defined range
    try:
        # current_user is a local proxy and is not
        # valid outside of a request context.
        is_admin = current_user.is_admin
    except AttributeError:
        is_admin = False
    if not is_admin:
        if addr < network.first or addr > network.last:
            raise ValidationError(
                f"IP address {ip} is not in range {network.first} - {network.last}"
            )


def overlaps(subnet, subnets):
    """Return True if the subnet overlaps with any of the subnets"""
    return any(subnet.overlaps(network) for network in subnets)