Newer
Older
# -*- coding: utf-8 -*-
"""
app.utils
~~~~~~~~~
This module implements utility functions.
:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.
"""
import base64
import datetime
import io
import ipaddress
import random
from pathlib import Path

import dateutil.parser
import sqlalchemy as sa
from flask import current_app, jsonify, url_for
from flask.globals import _app_ctx_stack, _request_ctx_stack
from flask_login import current_user
from wtforms import ValidationError

# NOTE(review): `yaml` (pretty_yaml) and `db` (retrieve_data_for_datatables)
# are used in this module but not imported here - confirm where they come from.
def fetch_current_user_id():
    """Retrieve the current user id

    :returns: the ``id`` of ``current_user``, or None when called outside of an
        application/request context or when the current user object has no
        ``id`` attribute
    """
    # Return None if we are outside of request context.
    if _app_ctx_stack.top is None or _request_ctx_stack.top is None:
        return None
    try:
        return current_user.id
    except AttributeError:
        # NOTE(review): presumably an anonymous user without an ``id`` - confirm
        return None
class CSEntryError(Exception):
    """CSEntryError class

    Exception used to pass useful information to the client side (API or AJAX)

    :param message: error message sent to the client
    :param status_code: status code to use instead of the class default (400)
    :param payload: optional mapping of extra data sent along with the message
    """

    status_code = 400

    def __init__(self, message, status_code=None, payload=None):
        # Pass the message (not the instance itself) to Exception so that
        # ``args`` and the default ``repr`` stay meaningful and non-recursive.
        super().__init__(message)
        self.message = message
        if status_code is not None:
            self.status_code = status_code
        self.payload = payload

    def to_dict(self):
        """Return the payload merged with the error message

        :returns: a new dict (the payload is copied, never mutated); the
            ``message`` key overrides any ``message`` entry of the payload
        """
        rv = dict(self.payload or ())
        rv["message"] = self.message
        return rv

    def __str__(self):
        return str(self.to_dict())
"""Convert a Pillow image to a base64 string
:param img: Pillow image
:param format: format of the image to use
:returns str: image as base64 string
"""
buf = io.BytesIO()
img.save(buf, format=format)
return base64.b64encode(buf.getvalue()).decode("ascii")
def format_field(field):
"""Format the given field to a string or None"""
if field is None:
return None
if isinstance(field, datetime.datetime):
def convert_to_model(item, model, filter_by="name"):
    """Convert item to an instance of model

    A value that is None or already an instance of ``model`` is returned as is.
    Otherwise the first instance whose ``filter_by`` attribute equals ``item``
    is looked up through ``model.query``.

    :param item: value to convert (typically a string)
    :param model: model class to look the value up in
    :param filter_by: name of the attribute used for the lookup
    :returns: an instance of model
    :raises CSEntryError: if no matching instance is found
    """
    if item is None or isinstance(item, model):
        return item
    found = model.query.filter_by(**{filter_by: item}).first()
    if found is None:
        raise CSEntryError(f"{item} is not a valid {model.__name__.lower()}")
    return found
def convert_to_models(d, fields):
    """Convert the values of the dictionary to the given type

    Entries whose ``filter_by`` is None (plain types) or whose key is absent
    from ``d`` are copied unchanged. For a list type ``[Model]`` the value is
    wrapped in a list if needed and every element is converted.

    :param d: dictionary with the values to update (not modified)
    :param fields: list of tuple (key, type_, filter_by)
    :returns: new updated dictionary
    """
    converted = d.copy()
    for key, type_, filter_by in fields:
        if key not in d or filter_by is None:
            # Not a db.Model field (or not provided): keep the value as is
            continue
        if not isinstance(type_, list):
            converted[key] = convert_to_model(converted[key], type_, filter_by)
            continue
        items = converted[key]
        if not isinstance(items, list):
            items = [items]
        converted[key] = [convert_to_model(item, type_[0], filter_by) for item in items]
    return converted
def attribute_to_string(value):
    """Return the attribute as a string

    If the attribute is defined in the schema as multi valued
    then the attribute value is returned as a list
    See http://ldap3.readthedocs.io/tutorial_searches.html#entries-retrieval
    This function returns the first item of the list if it's a list
    (or an empty string for an empty list); any other value is returned as is.

    :param value: string or list
    :returns: string
    """
    if isinstance(value, list):
        return value[0] if value else ""
    # Single-valued attribute: previously this fell through and returned None
    return value
def get_choices(iterable, allow_blank=False, allow_null=False):
    """Return a list of (value, label)

    :param iterable: values to offer; each one is used as both value and label
    :param allow_blank: prepend an empty ("", "") choice
    :param allow_null: prepend a choice representing "not set"
    :returns: list of (value, label) tuples
    """
    choices = []
    if allow_blank:
        choices.append(("", ""))
    if allow_null:
        # NOTE(review): the "null" value / "not set" label is a reconstruction -
        # confirm against the forms that use allow_null=True
        choices.append(("null", "not set"))
    # The values themselves must always be offered, not only with allow_blank
    choices.extend((val, val) for val in iterable)
    return choices
def get_model_choices(model, allow_none=False, attr="name", query=None, order_by=None):
    """Return a list of (value, label)

    :param model: model class whose instances are offered
    :param allow_none: prepend a (None, "") choice
    :param attr: instance attribute used as label (and default sort key)
    :param query: base query to use instead of ``model.query``
    :param order_by: name of the model attribute to sort on (defaults to attr)
    :returns: list of (str(instance.id), label) tuples
    """
    choices = [(None, "")] if allow_none else []
    # Honour the query given by the caller instead of always overriding it
    if query is None:
        query = model.query
    query = query.order_by(getattr(model, order_by or attr))
    choices.extend(
        (str(instance.id), getattr(instance, attr)) for instance in query.all()
    )
    return choices
def get_query(query, model, **kwargs):
    """Retrieve the query from the arguments

    :param query: sqlalchemy base query
    :param model: model class
    :param kwargs: attribute name / value pairs; each adds an equality filter
        on the matching attribute of ``model``
    :returns: query filtered by the arguments
    :raises CSEntryError: with status code 422 if an argument is not an
        attribute of ``model`` or SQLAlchemy rejects the filter
    """
    if not kwargs:
        return query
    try:
        # With filter_by(**kwargs), the keyword expressions are extracted
        # from the primary entity of the query, or the last entity that was
        # the target of a call to Query.join()
        # This might not be what we want when join() is used.
        # Always apply filtering on the given model
        for key, value in kwargs.items():
            query = query.filter(getattr(model, key) == value)
    except (sa.exc.InvalidRequestError, AttributeError) as e:
        current_app.logger.warning(f"Invalid query arguments: {e}")
        raise CSEntryError("Invalid query arguments", status_code=422) from e
    # Previously missing: the filtered query was never returned
    return query
def lowercase_field(value):
    """Filter to force form value to lowercase

    Values without a ``lower`` method (e.g. None) are returned unchanged.
    """
    try:
        lowered = value.lower()
    except AttributeError:
        return value
    return lowered
# Coerce functions to use with a SelectField that can accept a None value.
# wtforms always coerces to string by default.
# Values returned from the form are usually strings, but if a field is disabled,
# None is returned.
# To pass wtforms validation, the returned value must be part of the choices.
def coerce_to_str_or_none(value):
    """Convert '', None and 'None' to None

    Any other value is converted with ``str``.
    """
    # The guard was missing: the function returned None unconditionally
    if value is None or value in ("", "None"):
        return None
    return str(value)
def coerce_to_int_or_none(value):
    """Return None if the value is not an integer

    :param value: value returned by the form (a string, or None for a disabled
        field)
    :returns: ``int(value)`` or None if the conversion fails
    """
    try:
        return int(value)
    except (TypeError, ValueError):
        # TypeError covers None (disabled field), ValueError covers bad strings
        return None
def parse_to_utc(string):
    """Convert a string to a datetime object with no timezone

    Naive results are assumed to already be in UTC and returned as is; aware
    results are converted to UTC and then stripped of their tzinfo.
    """
    parsed = dateutil.parser.parse(string)
    if parsed.tzinfo is not None:
        parsed = parsed.astimezone(datetime.timezone.utc).replace(tzinfo=None)
    return parsed
def random_mac():
    """Return a random MAC address

    The address is the configured ``MAC_OUI`` prefix followed by three random
    lowercase hex octets, joined with ":".
    """
    random_octets = [f"{random.randint(0x00, 0xFF):02x}" for _ in range(3)]
    return ":".join([current_app.config["MAC_OUI"], *random_octets])
def pluralize(singular):
    """Return the plural form of the given word

    Used to pluralize API endpoints (not any given english word): an "s" is
    appended unless the word already ends with one.
    """
    if singular.endswith("s"):
        # Previously fell through and returned None
        return singular
    return singular + "s"
def format_datetime(value, format="%Y-%m-%d %H:%M"):
    """Format a datetime to string

    Function used as a jinja2 filter.

    :param value: object providing ``strftime``
    :param format: strftime pattern
    """
    pattern = format
    formatted = value.strftime(pattern)
    return formatted
def pretty_yaml(value):
    """Pretty print yaml

    Function used as a jinja2 filter. Falsy values render as an empty string;
    anything else is dumped in block style with ``yaml.safe_dump``.
    """
    if not value:
        return ""
    return yaml.safe_dump(value, default_flow_style=False)
def trigger_job_once(name, queue_name="low", **kwargs):
    """Trigger a job only once

    We can have one running job + one in queue (to apply the latest changes).
    Make sure that we don't have more than one in queue.

    :returns: the already waiting task if there is one, otherwise the newly
        launched task
    """
    already_waiting = current_user.get_task_waiting(name)
    if already_waiting is not None:
        current_app.logger.info(
            f'Already one "{name}" task waiting. No need to trigger a new one.'
        )
        return already_waiting
    running = current_user.get_task_started(name)
    if running:
        # Chain the new job after the running one so it starts when that is done
        kwargs["depends_on"] = running.id
    current_app.logger.info(f"Launch new {name} job")
    return current_user.launch_task(name, queue_name=queue_name, **kwargs)
def trigger_core_services_update():
    """Trigger a job to update the core services (DNS/DHCP/radius)

    This function should be called every time an interface or host is created/edited.
    The AWX template uses its own inventory that is updated on launch to avoid
    blocking the main inventory update when running, so there is no need to
    trigger an inventory update here.
    At most one job is kept waiting in the "normal" queue (see trigger_job_once).
    """
    config = current_app.config
    return trigger_job_once(
        "trigger_core_services_update",
        queue_name="normal",
        func="launch_awx_job",
        job_template=config["AWX_CORE_SERVICES_UPDATE"],
        resource=config.get("AWX_CORE_SERVICES_UPDATE_RESOURCE", "job"),
    )
def trigger_inventory_update():
    """Trigger a job to update the inventory in AWX

    This function should be called every time something impacting the inventory is updated
    (AnsibleGroup, Host, Interface, Network...)
    We can have one running job + one in queue to apply the latest changes.
    Make sure that we don't have more than one in queue.
    """
    inventory_source = current_app.config["AWX_INVENTORY_SOURCE"]
    # Put it on the "high" queue so that it has higher priority than all other jobs
    return trigger_job_once(
        "trigger_inventory_update",
        queue_name="high",
        func="launch_awx_job",
        resource="inventory_source",
        # The configured source was read but never forwarded to the job.
        # NOTE(review): kwarg name assumed to match launch_awx_job - confirm
        inventory_source=inventory_source,
    )
def trigger_ansible_groups_reindex():
    """Trigger a job to reindex the Ansible groups

    This function should be called every time a Host or Interface is modified
    to make sure the dynamic Ansible groups are indexed properly in Elasticsearch.
    We can have one running job + one in queue to apply the latest changes.
    Make sure that we don't have more than one in queue.
    """
    return trigger_job_once(
        "reindex_ansible_groups",
        queue_name="low",
        func="reindex_ansible_groups",
    )
def trigger_vm_creation(
    host, vm_disk_size, vm_cores, vm_memory, vm_osversion, skip_post_install_job
):
    """Trigger a job to create a virtual machine or virtual IOC

    The creation job is launched on the "low" queue. A post-install job,
    depending on the creation task, is launched as well when a post-install
    template is configured for the host's domain, ``skip_post_install_job`` is
    falsy and ``vm_osversion`` does not start with "windows".

    :returns: the creation task returned by ``current_user.launch_task``
    """
    domain = str(host.main_interface.network.domain)
    extra_vars = [
        f"vmname={host.name}",
        f"domain={domain}",
        f"vm_disk_size={vm_disk_size}",
        f"vm_cores={vm_cores}",
        f"vm_memory={vm_memory}",
        f"vm_osversion={vm_osversion}",
    ]
    # Virtual IOCs and plain VMs use distinct task names and AWX templates
    if host.is_ioc:
        task_name, template_key, post_install_kind = (
            "trigger_vioc_creation",
            "AWX_CREATE_VIOC",
            "VIOC",
        )
    else:
        task_name, template_key, post_install_kind = (
            "trigger_vm_creation",
            "AWX_CREATE_VM",
            "VM",
        )
    config = current_app.config
    job_template = config[template_key]
    post_job_template = config["AWX_POST_INSTALL"][post_install_kind].get(domain)
    current_app.logger.info(
        f"Launch new job to create the {host} VM: {job_template} with {extra_vars}"
    )
    creation_task = current_user.launch_task(
        task_name,
        resource="job",
        queue_name="low",
        func="launch_awx_job",
        job_template=job_template,
        extra_vars=extra_vars,
    )
    wants_post_install = (
        post_job_template
        and not skip_post_install_job
        and not vm_osversion.startswith("windows")
    )
    if wants_post_install:
        current_user.launch_task(
            "trigger_post_install_job",
            resource="job",
            queue_name="low",
            func="launch_awx_job",
            job_template=post_job_template,
            limit=f"{host.fqdn}",
            depends_on=creation_task.id,
        )
        current_app.logger.info(
            f"Trigger post install job: run {post_job_template} on {host.fqdn}"
        )
    return creation_task
def trigger_set_network_boot_profile(host, boot_profile):
    """Trigger a job to set the boot profile for host

    The AWX job receives the boot profile and the MAC address of the host's
    main interface as extra variables.

    :returns: the launched task
    """
    job_template = current_app.config["AWX_SET_NETWORK_BOOT_PROFILE"]
    extra_vars = [
        f"autoinstall_boot_profile={boot_profile}",
        f"autoinstall_pxe_mac_addr={host.main_interface.mac}",
    ]
    current_app.logger.info(
        f"Launch new job to set the network boot profile for {host.name}: {job_template} with {extra_vars}"
    )
    return current_user.launch_task(
        "trigger_set_network_boot_profile",
        resource="job",
        queue_name="low",
        func="launch_awx_job",
        job_template=job_template,
        extra_vars=extra_vars,
    )
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
def redirect_to_job_status(job_id):
    """Answer a client request with a 202 pointing to the job status

    The client is expected to poll the URL given in the ``Location`` header.

    :param job_id: The id of the job started, which needs to be polled by the client
    :type job_id: rq.job_id
    :return: HTTP response (empty JSON body, 202 status, Location header)
    """
    body = jsonify({})
    headers = {"Location": url_for("main.job_status", job_id=job_id)}
    return body, 202, headers
def unique_filename(filename):
    """Return an unique filename

    If ``filename`` does not exist it is returned unchanged. Otherwise a
    ``-N`` counter (starting at 1) is inserted before the last suffix until a
    non-existing path is found; that path is returned as a ``pathlib.Path``.

    :param filename: filename that should be unique
    :returns: unique filename
    """
    original = Path(filename)
    if not original.exists():
        return filename
    stem_path = original.with_suffix("")
    counter = 1
    candidate = Path(f"{stem_path}-{counter}{original.suffix}")
    while candidate.exists():
        counter += 1
        candidate = Path(f"{stem_path}-{counter}{original.suffix}")
    return candidate
def retrieve_data_for_datatables(values, model, filter_sensitive=False):
    """Return the filtered data of model to datatables

    This function is supposed to be called when using datatables
    with serverSide processing.

    :param values: a `~werkzeug.datastructures.MultiDict`, typically request.values
    :param model: class of the model to query
    :param bool filter_sensitive: filter out sensitive data if set to True
    :return: json object required by datatables
    """
    # Get the parameters from the post data sent by datatables
    draw = values.get("draw", 0, type=int)
    # NOTE(review): a "length" of 0 would raise ZeroDivisionError below
    per_page = values.get("length", 20, type=int)
    # page starts at 1 in elasticsearch
    page = int(values.get("start", 0, type=int) / per_page) + 1
    search = values.get("search[value]", "")
    if search == "":
        search = "*"
    order_column = values.get("order[0][column]")
    if order_column is None:
        sort = None
    else:
        name = values.get(f"columns[{order_column}][data]")
        order_dir = values.get("order[0][dir]", "asc")
        # Sorting can be done directly on all fields of type
        # keyword/date/long
        # To sort on fields of type text, we use the extra .keyword field
        if name in ("created_at", "updated_at", "quantity"):
            sort = f"{name}:{order_dir}"
        else:
            sort = f"{name}.keyword:{order_dir}"
    instances, nb_filtered = model.search(
        search,
        page=page,
        per_page=per_page,
        sort=sort,
        filter_sensitive=filter_sensitive,
    )
    data = [instance.to_dict(recursive=True) for instance in instances]
    # Total number of items before filtering
    nb_total = db.session.query(sa.func.count(model.id)).scalar()
    response = {
        "draw": draw,
        "recordsTotal": nb_total,
        "recordsFiltered": nb_filtered,
        "data": data,
    }
    return jsonify(response)
def minutes_ago(minutes):
    """Return the naive UTC datetime ``minutes`` minutes before now"""
    now = datetime.datetime.utcnow()
    delta = datetime.timedelta(minutes=minutes)
    return now - delta
def update_ansible_vars(host, vars):
    """Update the host ansible_vars

    When the host has no (or empty) ansible_vars, ``vars`` is assigned directly.
    Otherwise the existing dictionary is updated in place, only if that
    changes its content.

    :returns: False if no variables were changed, True otherwise
    """
    if not host.ansible_vars:
        host.ansible_vars = vars
        return True
    merged = host.ansible_vars.copy()
    merged.update(vars)
    if merged == host.ansible_vars:
        # Nothing new: leave the host untouched
        return False
    host.ansible_vars.update(vars)
    # The dictionary is mutated in place, so SQLAlchemy must be told explicitly
    # that the column changed, otherwise it is not saved to the database
    sa.orm.attributes.flag_modified(host, "ansible_vars")
    return True
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
def ip_in_network(ip, address):
    """Ensure the IP is in the network

    :param ip: IP address (anything accepted by ``ipaddress.ip_address``)
    :param address: network (anything accepted by ``ipaddress.ip_network``)
    :returns: a tuple with the IP and network as (IPv4Address, IPv4Network)
    :raises: ValidationError if the IP is not in the network
    """
    ip_obj = ipaddress.ip_address(ip)
    network_obj = ipaddress.ip_network(address)
    if ip_obj in network_obj:
        return (ip_obj, network_obj)
    raise ValidationError(f"IP address {ip} is not in network {address}")
def validate_ip(ip, network):
    """Ensure the IP is in the network range

    The IP must always be inside ``network.address``; non-admin users must
    additionally stay within ``network.first`` - ``network.last``.

    :raises: ValidationError if one of the checks fails
    """
    addr, _ = ip_in_network(ip, network.address)
    try:
        # current_user is a local proxy and is not
        # valid outside of a request context.
        is_admin = current_user.is_admin
    except AttributeError:
        is_admin = False
    if is_admin:
        # Admin user can create IP outside the defined range
        return
    if addr < network.first or addr > network.last:
        raise ValidationError(
            f"IP address {ip} is not in range {network.first} - {network.last}"
        )
def overlaps(subnet, subnets):
    """Return True if the subnet overlaps with any of the subnets"""
    return any(subnet.overlaps(other) for other in subnets)