Forked from
ICS Control System Infrastructure / csentry
60 commits behind the upstream repository.
-
Benjamin Bertrand authored
The device_type has an impact on the DHCP configuration. JIRA INFRA-1846 #action In Progress
Benjamin Bertrand authoredThe device_type has an impact on the DHCP configuration. JIRA INFRA-1846 #action In Progress
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
utils.py 17.35 KiB
# -*- coding: utf-8 -*-
"""
app.utils
~~~~~~~~~
This module implements utility functions.
:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.
"""
import base64
import datetime
import ipaddress
import io
import random
import sqlalchemy as sa
import dateutil.parser
import yaml
from pathlib import Path
from flask import current_app, jsonify, url_for
from flask.globals import _app_ctx_stack, _request_ctx_stack
from flask_login import current_user
from wtforms import ValidationError
from .extensions import db
def fetch_current_user_id():
    """Return the id of the current user, or None when it can't be retrieved

    :returns: ``current_user.id``, or None outside of a request context
        or when ``current_user`` has no ``id`` attribute
    """
    # Both the application and the request contexts must be active,
    # otherwise we are outside of a request and there is nothing to return.
    in_request = (
        _app_ctx_stack.top is not None and _request_ctx_stack.top is not None
    )
    if not in_request:
        return None
    try:
        user_id = current_user.id
    except AttributeError:
        user_id = None
    return user_id
class CSEntryError(Exception):
    """CSEntryError class

    Exception used to pass useful information to the client side (API or AJAX)

    :param message: description of the error
    :param status_code: status code overriding the class default (400)
    :param payload: optional mapping (or iterable of key/value pairs) merged
        into the dictionary returned by :meth:`to_dict`
    """

    status_code = 400

    def __init__(self, message, status_code=None, payload=None):
        # Give the message (not the instance itself) to the base class:
        # storing self in self.args makes repr() of the exception recurse
        # until a RecursionError is raised.
        super().__init__(message)
        self.message = message
        if status_code is not None:
            self.status_code = status_code
        self.payload = payload

    def to_dict(self):
        """Return the payload, completed with the message, as a dictionary"""
        rv = dict(self.payload or ())
        rv["message"] = self.message
        return rv

    def __str__(self):
        return str(self.to_dict())
def image_to_base64(img, format="PNG"):
    """Save a Pillow image in memory and return it base64 encoded

    :param img: Pillow image
    :param format: image format passed to ``img.save``
    :returns str: image as base64 string
    """
    stream = io.BytesIO()
    img.save(stream, format=format)
    raw = stream.getvalue()
    encoded = base64.b64encode(raw)
    return encoded.decode("ascii")
def format_field(field):
    """Return the field as a string (None is returned unchanged)

    Datetime values are rendered with the "%Y-%m-%d %H:%M" format,
    anything else is converted with ``str``.
    """
    if field is None:
        return None
    is_datetime = isinstance(field, datetime.datetime)
    return field.strftime("%Y-%m-%d %H:%M") if is_datetime else str(field)
def convert_to_model(item, model, filter_by="name"):
    """Return the instance of model matching item

    None and values that are already an instance of model are returned as is.
    Any other value is looked up in the database.

    :param item: value to convert
    :param model: model class to query
    :param filter_by: name of the model attribute compared to item
    :returns: an instance of model (or None if item is None)
    :raises: CSEntryError if no instance matches item
    """
    nothing_to_convert = item is None or isinstance(item, model)
    if nothing_to_convert:
        return item
    found = model.query.filter_by(**{filter_by: item}).first()
    if found is not None:
        return found
    raise CSEntryError(f"{item} is not a valid {model.__name__.lower()}")
def convert_to_models(d, fields):
    """Return a copy of the dictionary with its values converted to models

    :param d: dictionary with the values to update
    :param fields: list of tuple (key, type_, filter_by)
                   A type_ given as a list means that the value is a list
                   of instances of its first item.
    :returns: new updated dictionary
    """
    converted = d.copy()
    for key, type_, filter_by in fields:
        # No filter_by: this is a standard type and not a db.Model
        if filter_by is None:
            continue
        if key not in d:
            continue
        current = converted[key]
        if not isinstance(type_, list):
            converted[key] = convert_to_model(current, type_, filter_by)
            continue
        # Multi valued field: a single value is handled as a one item list
        items = current if isinstance(current, list) else [current]
        converted[key] = [
            convert_to_model(item, type_[0], filter_by) for item in items
        ]
    return converted
def attribute_to_string(value):
    """Return the attribute as a string

    Attributes defined as multi valued in the schema are returned as a list.
    See http://ldap3.readthedocs.io/tutorial_searches.html#entries-retrieval
    Only the first item is kept in that case (an empty string
    if the list is empty).

    :param value: string or list
    :returns: string
    """
    if not isinstance(value, list):
        return value
    try:
        first = value[0]
    except IndexError:
        first = ""
    return first
def get_choices(iterable, allow_blank=False, allow_null=False):
    """Return a list of (value, label)

    :param iterable: values used both as value and label
    :param allow_blank: start the list with the ("", "") choice
    :param allow_null: add the ("null", "not set") choice before the values
    :returns: list of tuples
    """
    choices = [("", "")] if allow_blank else []
    if allow_null:
        choices += [("null", "not set")]
    choices += [(val, val) for val in iterable]
    return choices
def get_model_choices(model, allow_none=False, attr="name", query=None, order_by=None):
    """Return a list of (value, label) built from the instances of model

    The value is the instance id (as a string) and the label its attr attribute.

    :param model: model class
    :param allow_none: start the list with the (None, "") choice
    :param attr: name of the attribute used as label
    :param query: query to use instead of ``model.query``
    :param order_by: name of the attribute to sort by (attr by default)
    :returns: list of tuples
    """
    base_query = model.query if query is None else query
    sort_column = getattr(model, order_by or attr)
    instances = base_query.order_by(sort_column).all()
    choices = [(None, "")] if allow_none else []
    for instance in instances:
        choices.append((str(instance.id), getattr(instance, attr)))
    return choices
def get_query(query, model, **kwargs):
    """Return the query filtered by the given arguments

    :param query: sqlalchemy base query
    :param model: model class
    :param kwargs: kwargs from a request
    :returns: query filtered by the arguments
    :raises: CSEntryError (422) if the arguments are invalid
    """
    if not kwargs:
        return query
    filtered = query
    try:
        # filter_by(**kwargs) is not used on purpose: it extracts the keyword
        # expressions from the primary entity of the query, or from the last
        # entity targeted by Query.join(), which might not be the one we want.
        # The filters are always built from the attributes of the given model.
        for name, expected in kwargs.items():
            column = getattr(model, name)
            filtered = filtered.filter(column == expected)
    except (sa.exc.InvalidRequestError, AttributeError) as e:
        current_app.logger.warning(f"Invalid query arguments: {e}")
        raise CSEntryError("Invalid query arguments", status_code=422)
    return filtered
def lowercase_field(value):
    """Filter to force form value to lowercase

    Values without a usable ``lower`` method are returned unchanged.
    """
    try:
        lowered = value.lower()
    except AttributeError:
        return value
    else:
        return lowered
# Coerce functions to use with a SelectField that can accept a None value.
# wtforms always coerces to string by default.
# Values returned from the form are usually strings, but if a field is
# disabled, None is returned.
# To pass wtforms validation, the value returned must be part of the choices.
def coerce_to_str_or_none(value):
    """Convert '', None and 'None' to None

    Any other value is converted to a string.
    """
    is_empty = value is None or value in ("", "None")
    return None if is_empty else str(value)
def coerce_to_int_or_none(value):
    """Return None if the value is not an integer

    :param value: value to convert (usually a string coming from a form)
    :returns: the value as an integer, or None if it can't be converted
    """
    try:
        return int(value)
    except (ValueError, TypeError):
        # ValueError: string that doesn't represent an integer
        # TypeError: None (returned for a disabled field) or any other type
        # not supported by int()
        return None
def parse_to_utc(string):
    """Convert a string to a datetime object with no timezone

    :param string: date/time string parsed with ``dateutil.parser.parse``
    :returns: naive datetime (converted to UTC if a timezone was given)
    """
    parsed = dateutil.parser.parse(string)
    if parsed.tzinfo is not None:
        # Convert to UTC before removing the timezone
        return parsed.astimezone(datetime.timezone.utc).replace(tzinfo=None)
    # No timezone: assume this is UTC
    return parsed
def random_mac():
    """Return a random MAC address

    The address is made of the MAC_OUI value from the application config
    followed by three random octets.
    """
    suffix = [f"{random.randint(0x00, 0xFF):02x}" for _ in range(3)]
    oui = current_app.config["MAC_OUI"]
    return ":".join((oui, *suffix))
def pluralize(singular):
    """Return the plural form of the given word

    Used to pluralize API endpoints (not any given english word):
    "es" is appended to words ending with "s", "s" to all the others.
    """
    suffix = "es" if singular.endswith("s") else "s"
    return singular + suffix
def format_datetime(value, format="%Y-%m-%d %H:%M"):
    """Format a datetime to string

    Function used as a jinja2 filter

    :param value: object with a ``strftime`` method
    :param format: format passed to ``strftime``
    :returns: formatted string
    """
    formatted = value.strftime(format)
    return formatted
def pretty_yaml(value):
    """Pretty print yaml

    Function used as a jinja2 filter

    :param value: data to dump
    :returns: value dumped as block style yaml or "" if value is empty
    """
    if not value:
        return ""
    return yaml.safe_dump(value, default_flow_style=False)
def trigger_job_once(name, queue_name="low", **kwargs):
    """Trigger a job only once

    We can have one running job + one in queue (to apply the latest changes).
    Make sure that we don't have more than one in queue.

    :param name: name of the task
    :param queue_name: name of the queue used to launch the task
    :param kwargs: extra arguments passed to ``current_user.launch_task``
    :returns: the task already waiting or the newly launched one
    """
    queued = current_user.get_task_waiting(name)
    if queued is not None:
        current_app.logger.info(
            f'Already one "{name}" task waiting. No need to trigger a new one.'
        )
        return queued
    running = current_user.get_task_started(name)
    if running:
        # One task is already running: the new one shall wait until it's done.
        kwargs["depends_on"] = running.id
    current_app.logger.info(f"Launch new {name} job")
    return current_user.launch_task(name, queue_name=queue_name, **kwargs)
def trigger_core_services_update():
    """Trigger a job to update the core services (DNS/DHCP/radius)

    This function should be called every time an interface or host is created/edited

    The AWX template uses its own inventory that is updated on launch to avoid
    blocking the main inventory update when running.
    There is no need to trigger an inventory update.

    We can have one running job + one in queue to apply the latest changes.
    Make sure that we don't have more than one in queue.

    :returns: the task returned by ``trigger_job_once``
    """
    config = current_app.config
    job_template = config["AWX_CORE_SERVICES_UPDATE"]
    # The resource falls back to "job" when not set in the configuration
    resource = config.get("AWX_CORE_SERVICES_UPDATE_RESOURCE", "job")
    task = trigger_job_once(
        "trigger_core_services_update",
        queue_name="normal",
        func="launch_awx_job",
        job_template=job_template,
        resource=resource,
    )
    return task
def trigger_inventory_update():
    """Trigger a job to update the inventory in AWX

    This function should be called every time something impacting the inventory is updated
    (AnsibleGroup, Host, Interface, Network...)

    We can have one running job + one in queue to apply the latest changes.
    Make sure that we don't have more than one in queue.

    :returns: the task returned by ``trigger_job_once``
    """
    # The "high" queue gives this job a higher priority than all other jobs
    task = trigger_job_once(
        "trigger_inventory_update",
        queue_name="high",
        func="launch_awx_job",
        resource="inventory_source",
        inventory_source=current_app.config["AWX_INVENTORY_SOURCE"],
    )
    return task
def trigger_ansible_groups_reindex():
    """Trigger a job to reindex the Ansible groups

    This function should be called every time a Host or Interface is modified
    to make sure the dynamic Ansible groups are indexed properly in Elasticsearch.

    We can have one running job + one in queue to apply the latest changes.
    Make sure that we don't have more than one in queue.

    :returns: the task returned by ``trigger_job_once``
    """
    task = trigger_job_once(
        "reindex_ansible_groups",
        queue_name="low",
        func="reindex_ansible_groups",
    )
    return task
def trigger_vm_creation(
    host, vm_disk_size, vm_cores, vm_memory, vm_osversion, skip_post_install_job
):
    """Trigger a job to create a virtual machine or virtual IOC

    A post install job, depending on the creation task, is also launched when
    a template is configured for the domain of the host, unless it was asked
    to skip it or the OS version starts with "windows".

    :param host: host to create
    :param vm_disk_size: value passed as vm_disk_size extra variable
    :param vm_cores: value passed as vm_cores extra variable
    :param vm_memory: value passed as vm_memory extra variable
    :param vm_osversion: value passed as vm_osversion extra variable
    :param skip_post_install_job: don't launch the post install job if True
    :returns: the creation task
    """
    domain = str(host.main_interface.network.domain)
    extra_vars = [
        f"vmname={host.name}",
        f"domain={domain}",
        f"vm_disk_size={vm_disk_size}",
        f"vm_cores={vm_cores}",
        f"vm_memory={vm_memory}",
        f"vm_osversion={vm_osversion}",
    ]
    # Virtual IOC and VM use different task names and AWX templates
    if host.is_ioc:
        task_name, template_key, post_install_key = (
            "trigger_vioc_creation",
            "AWX_CREATE_VIOC",
            "VIOC",
        )
    else:
        task_name, template_key, post_install_key = (
            "trigger_vm_creation",
            "AWX_CREATE_VM",
            "VM",
        )
    job_template = current_app.config[template_key]
    post_job_template = current_app.config["AWX_POST_INSTALL"][post_install_key].get(
        domain
    )
    current_app.logger.info(
        f"Launch new job to create the {host} VM: {job_template} with {extra_vars}"
    )
    task = current_user.launch_task(
        task_name,
        resource="job",
        queue_name="low",
        func="launch_awx_job",
        job_template=job_template,
        extra_vars=extra_vars,
    )
    if (
        not post_job_template
        or skip_post_install_job
        or vm_osversion.startswith("windows")
    ):
        # No post install job to run
        return task
    current_user.launch_task(
        "trigger_post_install_job",
        resource="job",
        queue_name="low",
        func="launch_awx_job",
        job_template=post_job_template,
        limit=f"{host.fqdn}",
        depends_on=task.id,
    )
    current_app.logger.info(
        f"Trigger post install job: run {post_job_template} on {host.fqdn}"
    )
    return task
def trigger_set_network_boot_profile(host, boot_profile):
    """Trigger a job to set the boot profile for host

    :param host: host to update (its main interface MAC address is used)
    :param boot_profile: boot profile to set
    :returns: the launched task
    """
    mac = host.main_interface.mac
    extra_vars = [
        f"autoinstall_boot_profile={boot_profile}",
        f"autoinstall_pxe_mac_addr={mac}",
    ]
    job_template = current_app.config["AWX_SET_NETWORK_BOOT_PROFILE"]
    current_app.logger.info(
        f"Launch new job to set the network boot profile for {host.name}: {job_template} with {extra_vars}"
    )
    return current_user.launch_task(
        "trigger_set_network_boot_profile",
        resource="job",
        queue_name="low",
        func="launch_awx_job",
        job_template=job_template,
        extra_vars=extra_vars,
    )
def redirect_to_job_status(job_id):
    """
    The answer to a client request, leading it to regularly poll a job status.

    The response has an empty JSON body, the 202 status code and a Location
    header set to the URL of the job status.

    :param job_id: The id of the job started, which needs to be pulled by the client
    :type job_id: rq.job_id
    :return: HTTP response
    """
    body = jsonify({})
    headers = {"Location": url_for("main.job_status", job_id=job_id)}
    return body, 202, headers
def unique_filename(filename):
    """Return an unique filename

    If the file already exists, "-<nb>" is inserted before the suffix, using
    the first number (starting from 1) that gives a path that doesn't exist.

    :param filename: filename that should be unique
    :returns: filename unchanged if it doesn't exist, otherwise the unique
              filename as a Path
    """
    original = Path(filename)
    if not original.exists():
        return filename
    stem = original.with_suffix("")
    nb = 1
    candidate = Path(f"{stem}-{nb}{original.suffix}")
    while candidate.exists():
        nb += 1
        candidate = Path(f"{stem}-{nb}{original.suffix}")
    return candidate
def retrieve_data_for_datatables(values, model, filter_sensitive=False):
    """Return the filtered data of model to datatables

    This function is supposed to be called when using datatables
    with serverSide processing.

    :param values: a `~werkzeug.datastructures.MultiDict`, typically request.values
    :param model: class of the model to query
    :param bool filter_sensitive: filter out sensitive data if set to True
    :return: json object required by datatables
    """
    # Parameters from the post data sent by datatables
    draw = values.get("draw", 0, type=int)
    per_page = values.get("length", 20, type=int)
    start = values.get("start", 0, type=int)
    # page starts at 1 in elasticsearch
    page = int(start / per_page) + 1
    search = values.get("search[value]", "")
    # An empty search is replaced with the "*" wildcard
    search = "*" if search == "" else search
    sort = None
    order_column = values.get("order[0][column]")
    if order_column is not None:
        name = values.get(f"columns[{order_column}][data]")
        order_dir = values.get("order[0][dir]", "asc")
        # Fields of type keyword/date/long can be sorted directly.
        # Fields of type text are sorted on their extra .keyword field.
        directly_sortable = name in ("created_at", "updated_at", "quantity")
        suffix = "" if directly_sortable else ".keyword"
        sort = f"{name}{suffix}:{order_dir}"
    instances, nb_filtered = model.search(
        search,
        page=page,
        per_page=per_page,
        sort=sort,
        filter_sensitive=filter_sensitive,
    )
    data = [instance.to_dict(recursive=True) for instance in instances]
    # Total number of items before filtering
    nb_total = db.session.query(sa.func.count(model.id)).scalar()
    return jsonify(
        {
            "draw": draw,
            "recordsTotal": nb_total,
            "recordsFiltered": nb_filtered,
            "data": data,
        }
    )
def minutes_ago(minutes):
    """Return the datetime x minutes ago

    :param minutes: number of minutes to subtract from the current time
    :returns: naive datetime expressed in UTC
    """
    # datetime.utcnow() is deprecated since Python 3.12: take the current
    # time as an aware UTC datetime and drop the timezone so that a naive
    # UTC value is still returned.
    now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    return now - datetime.timedelta(minutes=minutes)
def update_ansible_vars(host, vars):
    """Update the host ansible_vars

    :param host: host to update
    :param vars: variables to add to (or to override in) the host ansible_vars
    :returns: False if no variables were changed, True otherwise
    """
    current = host.ansible_vars
    if not current:
        # Nothing defined yet: the given variables become the host variables
        host.ansible_vars = vars
        return True
    # Apply the update on a copy first to detect if anything would change
    candidate = current.copy()
    candidate.update(vars)
    if candidate == current:
        # No change
        return False
    current.update(vars)
    # If we don't flag the field as modified, it's not saved to the database
    # Probably because we update an existing dictionary
    sa.orm.attributes.flag_modified(host, "ansible_vars")
    return True
def ip_in_network(ip, address):
    """Ensure the IP is in the network

    :param ip: IP address
    :param address: network address
    :returns: a tuple with the IP and network as (IPv4Address, IPv4Network)
    :raises: ValidationError if the IP is not in the network
    """
    addr = ipaddress.ip_address(ip)
    net = ipaddress.ip_network(address)
    if addr in net:
        return (addr, net)
    raise ValidationError(f"IP address {ip} is not in network {address}")
def validate_ip(ip, network):
    """Ensure the IP is in the network range

    :param ip: IP address to validate
    :param network: network whose address, first and last attributes
                    are used for the checks
    :raises: ValidationError if the IP is not in the network, or not between
             network.first and network.last for a non admin user
    """
    addr, _ = ip_in_network(ip, network.address)
    try:
        # current_user is a local proxy and is not
        # valid outside of a request context.
        is_admin = current_user.is_admin
    except AttributeError:
        is_admin = False
    if is_admin:
        # Admin user can create IP outside the defined range
        return
    out_of_range = addr < network.first or addr > network.last
    if out_of_range:
        raise ValidationError(
            f"IP address {ip} is not in range {network.first} - {network.last}"
        )
def overlaps(subnet, subnets):
    """Return True if the subnet overlaps with any of the subnets

    :param subnet: network to check (object with an ``overlaps`` method)
    :param subnets: iterable of networks to compare with
    :returns: bool
    """
    return any(subnet.overlaps(other) for other in subnets)