# NOTE(review): removed stray "Newer" / "Older" diff-viewer artifacts that were
# pasted into the file; as bare expressions they would raise NameError at import.
# -*- coding: utf-8 -*-
"""
app.models
~~~~~~~~~~
This module implements the models used in the app.
:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.
"""
import sqlalchemy as sa
from operator import attrgetter
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import validates
from sqlalchemy_continuum import make_versioned, version_class
from citext import CIText
from flask_login import UserMixin, current_user
from wtforms import ValidationError
from .extensions import db, login_manager, ldap_manager, cache
from .plugins import FlaskUserPlugin
from .validators import (
ICS_ID_RE,
HOST_NAME_RE,
VLAN_NAME_RE,
MAC_ADDRESS_RE,
DEVICE_TYPE_RE,
)
make_versioned(plugins=[FlaskUserPlugin()])
# See http://docs.sqlalchemy.org/en/latest/core/compiler.html#utc-timestamp-function
class utcnow(sa.sql.expression.FunctionElement):
    """SQL function expression rendering the current UTC timestamp.

    Used as column default for created_at/updated_at (see CreatedMixin);
    compiled per dialect via @compiles hooks.
    """

    # Tells SQLAlchemy the expression yields a DateTime value.
    type = sa.types.DateTime()
# NOTE(review): this compiler hook is truncated -- its @compiles(utcnow, ...)
# decorator and its body (the SQL string to emit) appear to be missing.
def pg_utcnow(element, compiler, **kw):
def temporary_ics_ids():
    """Generator that yields the full list of temporary ICS ids.

    Ids are the configured TEMPORARY_ICS_ID prefix followed by one uppercase
    letter and a zero-padded 3-digit number (A000 ... Z999).
    """
    prefix = current_app.config["TEMPORARY_ICS_ID"]
    for letter in string.ascii_uppercase:
        for number in range(1000):
            yield f"{prefix}{letter}{number:0=3d}"
def used_temporary_ics_ids():
    """Return the set of temporary ICS ids already assigned to items."""
    prefix = current_app.config["TEMPORARY_ICS_ID"]
    matching_items = Item.query.filter(Item.ics_id.startswith(prefix)).all()
    return {item.ics_id for item in matching_items}
def get_temporary_ics_id():
    """Return a temporary ICS id that is available.

    :raises ValueError: when every temporary id is already taken
    """
    used_temp_ics_ids = used_temporary_ics_ids()
    for ics_id in temporary_ics_ids():
        # BUG FIX: the availability check was missing -- the first id was
        # returned unconditionally and the used set was never consulted
        # (which also made the "no id available" error unreachable).
        if ics_id not in used_temp_ics_ids:
            return ics_id
    raise ValueError("No temporary ICS id available")
@login_manager.user_loader
def load_user(user_id):
    """User loader callback for flask-login.

    :param str user_id: unicode ID of a user
    :returns: corresponding user object or None
    """
    numeric_id = int(user_id)
    return User.query.get(numeric_id)
@ldap_manager.save_user
def save_user(dn, username, data, memberships):
"""User saver for flask-ldap3-login
This method is called whenever a LDAPLoginForm()
successfully validates.
"""
user = User.query.filter_by(username=username).first()
if user is None:
# NOTE(review): the constructor call (likely "user = User(username=username,")
# appears to be missing here -- the keyword arguments below are orphaned.
# TODO: restore from version control.
display_name=utils.attribute_to_string(data["cn"]) or username,
email=utils.attribute_to_string(data["mail"]),
)
# Always update the user groups to keep them up-to-date
user.groups = sorted(
[utils.attribute_to_string(group["cn"]) for group in memberships]
)
db.session.add(user)
db.session.commit()
# Tables required for Many-to-Many relationships between users and favorites attributes
# Each association table uses a composite primary key so a given
# (user, attribute) pair can only be stored once.
favorite_manufacturers_table = db.Table(
    "favorite_manufacturers",
    db.Column(
        "user_id", db.Integer, db.ForeignKey("user_account.id"), primary_key=True
    ),
    db.Column(
        "manufacturer_id",
        db.Integer,
        db.ForeignKey("manufacturer.id"),
        primary_key=True,
    ),
)

# Users <-> favorite models.
favorite_models_table = db.Table(
    "favorite_models",
    db.Column(
        "user_id", db.Integer, db.ForeignKey("user_account.id"), primary_key=True
    ),
    db.Column("model_id", db.Integer, db.ForeignKey("model.id"), primary_key=True),
)

# Users <-> favorite locations.
favorite_locations_table = db.Table(
    "favorite_locations",
    db.Column(
        "user_id", db.Integer, db.ForeignKey("user_account.id"), primary_key=True
    ),
    db.Column(
        "location_id", db.Integer, db.ForeignKey("location.id"), primary_key=True
    ),
)

# Users <-> favorite statuses.
favorite_statuses_table = db.Table(
    "favorite_statuses",
    db.Column(
        "user_id", db.Integer, db.ForeignKey("user_account.id"), primary_key=True
    ),
    db.Column("status_id", db.Integer, db.ForeignKey("status.id"), primary_key=True),
)
# Users <-> favorite actions.
favorite_actions_table = db.Table(
    "favorite_actions",
    db.Column(
        "user_id", db.Integer, db.ForeignKey("user_account.id"), primary_key=True
    ),
    db.Column("action_id", db.Integer, db.ForeignKey("action.id"), primary_key=True),
    # BUG FIX: the closing parenthesis of this db.Table(...) call was missing,
    # which made the module a syntax error.
)
class User(db.Model, UserMixin):
# NOTE(review): the __tablename__ assignment (= "user_account", judging by the
# ForeignKey targets used elsewhere) and the "id"/"email" column definitions
# appear to be missing from this class header. TODO: restore from VCS.
# "user" is a reserved word in postgresql
# so let's use another name
username = db.Column(db.Text, nullable=False, unique=True)
display_name = db.Column(db.Text, nullable=False)
groups = db.Column(postgresql.ARRAY(db.Text), default=[])
tokens = db.relationship("Token", backref="user")
tasks = db.relationship("Task", backref="user")
# The favorites won't be accessed very often so we load them
# only when necessary (lazy=True)
favorite_manufacturers = db.relationship(
# NOTE(review): the first positional argument (likely "Manufacturer") appears
# to be missing from this relationship call.
secondary=favorite_manufacturers_table,
lazy=True,
backref=db.backref("favorite_users", lazy=True),
)
# NOTE(review): the assignments "favorite_models = db.relationship(...",
# "favorite_locations = ...", "favorite_statuses = ..." and
# "favorite_actions = ..." appear to be missing before the four orphaned
# keyword-argument groups below (names inferred from favorite_attributes()).
secondary=favorite_models_table,
lazy=True,
backref=db.backref("favorite_users", lazy=True),
)
secondary=favorite_locations_table,
lazy=True,
backref=db.backref("favorite_users", lazy=True),
)
secondary=favorite_statuses_table,
lazy=True,
backref=db.backref("favorite_users", lazy=True),
)
secondary=favorite_actions_table,
lazy=True,
backref=db.backref("favorite_users", lazy=True),
)
def get_id(self):
    """Return this user's database id as a unicode string.

    Required by flask-login.
    """
    user_id = self.id
    return str(user_id)
@property
def csentry_groups(self):
    """Return the list of CSEntry groups the user belongs to.

    Groups are assigned based on the CSENTRY_LDAP_GROUPS mapping with LDAP
    groups; the special "network" group is added when the user is in any
    group of CSENTRY_NETWORK_SCOPES_LDAP_GROUPS. The result is cached on
    the instance (computed once per user object).
    """
    if not hasattr(self, "_csentry_groups"):
        user_ldap_groups = set(self.groups)
        groups = [
            csentry_group
            for csentry_group, ldap_groups in current_app.config[
                "CSENTRY_LDAP_GROUPS"
            ].items()
            if user_ldap_groups & set(ldap_groups)
        ]
        # Add the network group based on CSENTRY_NETWORK_SCOPES_LDAP_GROUPS
        network_ldap_groups = set(
            itertools.chain.from_iterable(
                current_app.config["CSENTRY_NETWORK_SCOPES_LDAP_GROUPS"].values()
            )
        )
        if user_ldap_groups & network_ldap_groups:
            groups.append("network")
        self._csentry_groups = groups
    return self._csentry_groups
@property
def csentry_network_scopes(self):
"""Return the list of CSEntry network scopes the user has access to
Network scopes are assigned based on the CSENTRY_NETWORK_SCOPES_LDAP_GROUPS mapping with LDAP groups
# NOTE(review): the closing docstring quotes appear to be missing above.
if not hasattr(self, "_csentry_network_scopes"):
self._csentry_network_scopes = []
for network_scope, ldap_groups in current_app.config[
"CSENTRY_NETWORK_SCOPES_LDAP_GROUPS"
].items():
if set(self.groups) & set(ldap_groups):
self._csentry_network_scopes.append(network_scope)
return self._csentry_network_scopes
# NOTE(review): an "is_admin" property header (@property / def is_admin(self):
# with its docstring) appears to be missing before this return statement --
# is_admin is used throughout the class (e.g. in can_view_network).
return "admin" in self.csentry_groups
@property
def is_auditor(self):
    """Return True if the user belongs to the CSEntry "auditor" group."""
    return "auditor" in self.csentry_groups
def is_member_of_one_group(self, groups):
    """Return True if the user is at least member of one of the given CSEntry groups."""
    # Equivalent to bool(set(groups) & set(self.csentry_groups)).
    return not set(groups).isdisjoint(self.csentry_groups)
def has_access_to_network(self, network):
    """Return True if the user has access to the network

    - admin users have access to all networks
    - normal users must have access to the network scope
    - normal users don't have access to admin_only networks (whatever the network scope)
    - LOGIN_DISABLED can be set to True to turn off authentication check when testing.
      In this case, this function always returns True.
    """
    # BUG FIX: the admin/LOGIN_DISABLED case returned False, contradicting the
    # docstring and the "True is already returned for admin users" comment --
    # the early "return True" branch had been lost.
    if current_app.config.get("LOGIN_DISABLED") or self.is_admin:
        return True
    if network is None or network.admin_only:
        # True is already returned for admin users
        return False
    return str(network.scope) in self.csentry_network_scopes
def can_view_network(self, network):
    """Return True if the user can view the network

    - non sensitive networks can be viewed by anyone
    - normal users must have access to the network scope to view sensitive networks
    - LOGIN_DISABLED can be set to True to turn off authentication check when testing.
      In this case, this function always returns True.
    """
    auth_disabled = current_app.config.get("LOGIN_DISABLED")
    if auth_disabled or self.is_admin or not network.sensitive:
        return True
    return str(network.scope) in self.csentry_network_scopes
@property
def sensitive_filter(self):
    """Return the elasticsearch query to use to filter sensitive hosts.

    Non-sensitive hosts always match; sensitive hosts match only when their
    scope is one of the user's network scopes.
    """
    base = "sensitive:false"
    scopes = self.csentry_network_scopes
    if not scopes:
        return base
    scopes_filter = " OR ".join(f"scope:{scope}" for scope in scopes)
    return f"{base} OR (sensitive:true AND ({scopes_filter}))"
def can_view_host(self, host):
    """Return True if the user can view the host

    - non sensitive hosts can be viewed by anyone
    - normal users must have access to the network scope to view sensitive hosts
    - LOGIN_DISABLED can be set to True to turn off authentication check when testing.
      In this case, this function always returns True.
    """
    auth_disabled = current_app.config.get("LOGIN_DISABLED")
    if auth_disabled or self.is_admin or not host.sensitive:
        return True
    return str(host.scope) in self.csentry_network_scopes
def can_create_vm(self, host):
    """Return True if the user can create the VM

    - host.device_type shall be VirtualMachine
    - admin users can create anything
    - normal users must have access to the network to create VIOC
    - normal users can only create a VM if the host is in one of the allowed network scopes
    - LOGIN_DISABLED can be set to True to turn off authentication check when testing.
      In this case, this function always returns True.
    """
    if str(host.device_type) != "VirtualMachine":
        return False
    if current_app.config.get("LOGIN_DISABLED") or self.is_admin:
        return True
    if not self.has_access_to_network(host.main_network):
        # True is already returned for admin users
        return False
    if host.is_ioc:
        # VIOC can be created by anyone having access to the network
        return True
    # VM can only be created if the network scope is allowed
    # BUG FIX: this final expression was missing its "return", so the method
    # fell through and returned None instead of the boolean.
    return (
        str(host.scope) in current_app.config["ALLOWED_VM_CREATION_NETWORK_SCOPES"]
    )
def can_set_boot_profile(self, host):
    """Return True if the user can set the network boot profile

    - host.device_type shall be in ALLOWED_SET_BOOT_PROFILE_DEVICE_TYPES
    - admin users can always set the profile
    - normal users must have access to the network
    - normal users can only set the boot profile if the host is in one of the allowed network scopes
    - LOGIN_DISABLED can be set to True to turn off authentication check when testing.
      In this case, this function always returns True.
    """
    if (
        str(host.device_type)
        not in current_app.config["ALLOWED_SET_BOOT_PROFILE_DEVICE_TYPES"]
    ):
        return False
    if current_app.config.get("LOGIN_DISABLED") or self.is_admin:
        return True
    if not self.has_access_to_network(host.main_network):
        # True is already returned for admin users
        return False
    # Boot profile can only be set if the network scope is allowed
    # BUG FIX: this final expression was missing its "return", so the method
    # fell through and returned None instead of the boolean.
    return (
        str(host.scope)
        in current_app.config["ALLOWED_SET_BOOT_PROFILE_NETWORK_SCOPES"]
    )
def can_delete_host(self, host):
    """Return True if the user can delete the host

    - admin users can delete any host
    - normal users must be creator of the host
    - LOGIN_DISABLED can be set to True to turn off authentication check when testing.
      In this case, this function always returns True.
    """
    if current_app.config.get("LOGIN_DISABLED"):
        return True
    if self.is_admin:
        return True
    return self.id == host.user.id
def favorite_attributes(self):
    """Return all user's favorite attributes as one flat list."""
    return (
        list(self.favorite_manufacturers)
        + list(self.favorite_models)
        + list(self.favorite_locations)
        + list(self.favorite_statuses)
        + list(self.favorite_actions)
    )
def launch_task(self, name, func, queue_name="normal", **kwargs):
"""Launch a task in the background using RQ
The task is added to the session but not committed.
"""
q = Queue(queue_name, default_timeout=current_app.config["RQ_DEFAULT_TIMEOUT"])
job = q.enqueue(f"app.tasks.{func}", **kwargs)
# NOTE(review): the Task constructor call (likely "task = Task(name=name, ...")
# appears to be missing here -- the keyword arguments below are orphaned and
# "task" is otherwise undefined. TODO: restore from VCS.
awx_resource=kwargs.get("resource", None),
status=JobStatus(job.get_status()),
depends_on_id=kwargs.get("depends_on", None),
db.session.add(task)
return task
def get_tasks(self, all=False):
    """Return all tasks created by the current user

    If the user is admin or auditor and all is set to True, will return all tasks
    """
    # NOTE(review): the condition guarding the "all tasks" branch had been
    # lost (both returns were unconditional); reconstructed from the
    # docstring -- confirm against version control.
    if all and self.is_member_of_one_group(["admin", "auditor"]):
        return Task.query.order_by(Task.created_at).all()
    return Task.query.filter_by(user=self).order_by(Task.created_at).all()
def get_tasks_in_progress(self, name):
    """Return all the <name> tasks not finished or failed."""
    terminal_statuses = [JobStatus.FINISHED, JobStatus.FAILED]
    query = Task.query.filter_by(name=name)
    return query.filter(~Task.status.in_(terminal_statuses)).all()
def get_task_started(self, name):
    """Return the <name> task currently running or None"""
    # A task is "running" when its RQ job status is STARTED.
    return Task.query.filter_by(name=name, status=JobStatus.STARTED).first()
def is_task_waiting(self, name):
    """Return True if a <name> task is waiting

    Waiting means:
    - deferred
    - queued
    A deferred task will be set to failed if the task it depends on fails.
    """
    # NOTE(review): this method body had lost its return/query scaffolding
    # (only the .filter(...) line survived); reconstructed to mirror
    # get_task_waiting below -- confirm against version control.
    return (
        Task.query.filter_by(name=name)
        .filter(Task.status.in_([JobStatus.DEFERRED, JobStatus.QUEUED]))
        .first()
        is not None
    )
def get_task_waiting(self, name):
    """Return the latest <name> task currently waiting or None

    Waiting means:
    - deferred
    - queued
    A deferred task will be set to failed if the task it depends on fails.
    """
    waiting_statuses = Task.status.in_([JobStatus.DEFERRED, JobStatus.QUEUED])
    query = Task.query.filter_by(name=name).filter(waiting_statuses)
    return query.order_by(Task.created_at.desc()).first()
# NOTE(review): a "def __str__(self):" header appears to be missing above this
# return statement.
return self.username
# NOTE(review): a "def to_dict(self):" header and the opening "return {"
# appear to be missing before the key/value lines below (and possibly the
# closing brace / further keys after them).
"id": self.id,
"username": self.username,
"display_name": self.display_name,
"email": self.email,
"groups": self.csentry_groups,
class SearchableMixin(object):
"""Add search capability to a class"""
@classmethod
def search(cls, query, page=1, per_page=20, sort=None, filter_sensitive=False):
# Restrict results for non admin/auditor users so sensitive hosts stay hidden.
if filter_sensitive and not (current_user.is_admin or current_user.is_auditor):
# NOTE(review): lines appear to be missing here -- the two assignments
# below overwrite each other (probably an "if not query: ... else: ..."
# branch was lost). TODO: restore from VCS.
query = current_user.sensitive_filter
query = f"({query}) AND ({current_user.sensitive_filter})"
try:
ids, total = search.query_index(
cls.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"],
query,
page,
per_page,
sort,
)
except elasticsearch.ElasticsearchException as e:
# Invalid query
current_app.logger.warning(e)
return cls.query.filter_by(id=0), 0
if total == 0:
return cls.query.filter_by(id=0), 0
# Preserve elasticsearch ranking: order the SQL rows by their position in ids.
when = [(value, i) for i, value in enumerate(ids)]
return (
cls.query.filter(cls.id.in_(ids)).order_by(db.case(when, value=cls.id)),
total,
)
@classmethod
def before_flush(cls, session, flush_context, instances):
    """Save the new/modified/deleted objects

    The session.new / dirty / deleted lists are empty in the
    after_flush_postexec event, so they are recorded here.
    """
    # BUG FIX: two closing parentheses of the "index = (...)" and
    # logger.debug(...) calls were missing, breaking the module syntax.
    session._changes = {"add_obj": [], "delete": []}
    for obj in itertools.chain(session.new, session.dirty):
        if isinstance(obj, SearchableMixin):
            index = (
                obj.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"]
            )
            current_app.logger.debug(
                f"object to add/update in the {index} index: {obj}"
            )
            session._changes["add_obj"].append((index, obj))
    for obj in session.deleted:
        if isinstance(obj, SearchableMixin):
            index = (
                obj.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"]
            )
            current_app.logger.debug(
                f"object to remove from the {index} index: {obj}"
            )
            session._changes["delete"].append((index, obj.id))
@classmethod
def after_flush_postexec(cls, session, flush_context):
    """Retrieve the new and updated objects representation"""
    if not hasattr(session, "_changes") or session._changes is None:
        # BUG FIX: this guard had lost its "return" -- without it the method
        # fell through and raised on the missing/None _changes attribute.
        return
    # - We can't call obj.to_dict() in the before_flush event because the id
    #   hasn't been allocated yet (for new objects) and other fields haven't been
    #   updated (default values like created_at/updated_at and some relationships).
    # - We can't call obj.to_dict() in the after_commit event because it would raise:
    #   sqlalchemy.exc.InvalidRequestError:
    #   This session is in 'committed' state; no further SQL can be emitted within this transaction.
    session._changes["add"] = [
        (index, obj.to_dict(recursive=True))
        for index, obj in session._changes["add_obj"]
    ]
@classmethod
def after_commit(cls, session):
    """Update the elasticsearch index"""
    if not hasattr(session, "_changes") or session._changes is None:
        # BUG FIX: this guard had lost its "return" -- without it the method
        # fell through and raised on the missing/None _changes attribute.
        return
    for index, body in session._changes["add"]:
        search.add_to_index(index, body)
    for index, id in session._changes["delete"]:
        search.remove_from_index(index, id)
    # Reset so a later flush-less commit doesn't replay the same changes.
    session._changes = None
@classmethod
def delete_index(cls, **kwargs):
    """Delete the elasticsearch index of the class.

    Extra keyword arguments are forwarded to search.delete_index
    (e.g. ignore_unavailable=True).
    """
    current_app.logger.info(f"Delete the {cls.__tablename__} index")
    index_name = cls.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"]
    search.delete_index(index_name, **kwargs)
@classmethod
def create_index(cls, **kwargs):
    """Create the elasticsearch index of the class.

    Does nothing (apart from logging) when the class defines no __mapping__.
    """
    if not hasattr(cls, "__mapping__"):
        current_app.logger.info(
            f"No mapping defined for {cls.__tablename__}. No index created."
        )
        return
    current_app.logger.info(f"Create the {cls.__tablename__} index")
    index_name = cls.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"]
    search.create_index(index_name, cls.__mapping__, **kwargs)
@classmethod
def reindex(cls, delete=True):
    """Force to reindex all instances of the class

    :param bool delete: when True, delete and recreate the index first
    """
    # BUG FIX: the @classmethod decorator and the closing parenthesis of the
    # add_to_index(...) call were missing (the method uses "cls" and is
    # called like its sibling classmethods).
    current_app.logger.info(f"Force to re-index all {cls.__tablename__} instances")
    if delete:
        # Ignore index_not_found_exception
        cls.delete_index(ignore_unavailable=True)
    cls.create_index()
    for obj in cls.query:
        search.add_to_index(
            cls.__tablename__ + current_app.config["ELASTICSEARCH_INDEX_SUFFIX"],
            obj.to_dict(recursive=True),
        )
class Token(db.Model):
    """Table to store valid tokens"""

    id = db.Column(db.Integer, primary_key=True)
    # Token's unique identifier (UUID); "jti" is the JWT ID claim name.
    jti = db.Column(postgresql.UUID, nullable=False)
    # NOTE(review): the set of token_type values isn't visible here -- confirm.
    token_type = db.Column(db.Text, nullable=False)
    user_id = db.Column(db.Integer, db.ForeignKey("user_account.id"), nullable=False)
    issued_at = db.Column(db.DateTime, nullable=False)
    # expires can be set to None for tokens that never expire
    expires = db.Column(db.DateTime)
    description = db.Column(db.Text)
    # A given jti may only be stored once per user.
    __table_args__ = (sa.UniqueConstraint(jti, user_id),)

    def __str__(self):
        return self.jti
class QRCodeMixin:
"""Mixin adding id/name/description columns and a QRCode representation."""
id = db.Column(db.Integer, primary_key=True)
name = db.Column(CIText, nullable=False, unique=True)
description = db.Column(db.Text)
# NOTE(review): a method header (likely "def image(self):") appears to be
# missing above this docstring, together with the QRCode-building/return code
# after the "data" assignment. TODO: restore from VCS.
"""Return a QRCode image to identify a record
The QRCode includes:
- the table name
- the name of the record
"""
data = ":".join(["CSE", self.__tablename__, self.name])
@cache.memoize(timeout=0)
def base64_image(self):
"""Return the QRCode image as base64 string"""
return utils.image_to_base64(self.image())
def is_user_favorite(self):
"""Return True if the attribute is part of the current user favorites"""
return current_user in self.favorite_users
def __repr__(self):
# The cache.memoize decorator performs a repr() on the passed in arguments
# __repr__ is used as part of the cache key and shall be a uniquely identifying string
# See https://flask-caching.readthedocs.io/en/latest/#memoization
return f"{self.__class__.__name__}(id={self.id}, name={self.name})"
# NOTE(review): a "def to_dict(self):" header and the opening "return {"
# appear to be missing before the key/value lines below.
"id": self.id,
"name": self.name,
"description": self.description,
"qrcode": self.base64_image(),
class Manufacturer(QRCodeMixin, db.Model):
items = db.relationship("Item", back_populates="manufacturer")
# NOTE(review): the class headers for Model, Location and Status (each
# "(QRCodeMixin, db.Model)", judging by Item's relationships) appear to be
# missing before the three orphaned lines below.
items = db.relationship("Item", back_populates="model")
items = db.relationship("Item", back_populates="location")
items = db.relationship("Item", back_populates="status")
class CreatedMixin:
"""Mixin adding id/created_at/updated_at and creator bookkeeping columns."""
id = db.Column(db.Integer, primary_key=True)
created_at = db.Column(db.DateTime, default=utcnow())
updated_at = db.Column(db.DateTime, default=utcnow(), onupdate=utcnow())
# Using ForeignKey and relationship in mixin requires the @declared_attr decorator
# See http://docs.sqlalchemy.org/en/latest/orm/extensions/declarative/mixins.html
@declared_attr
def user_id(cls):
return db.Column(
db.Integer,
db.ForeignKey("user_account.id"),
nullable=False,
default=utils.fetch_current_user_id,
)
@declared_attr
def user(cls):
# NOTE(review): the body of this declared_attr (presumably returning a
# db.relationship to User) appears to be missing. TODO: restore from VCS.
def __init__(self, **kwargs):
# Automatically convert created_at/updated_at strings
# to datetime object
# NOTE(review): the loop header (likely 'for key in ("created_at",
# "updated_at"):') appears to be missing above the "if key in kwargs" test.
if key in kwargs and isinstance(kwargs[key], str):
kwargs[key] = utils.parse_to_utc(kwargs[key])
super().__init__(**kwargs)
# NOTE(review): a "def to_dict(self):" header and the opening "return {"
# appear to be missing before the key/value lines below.
"id": self.id,
"created_at": utils.format_field(self.created_at),
"updated_at": utils.format_field(self.updated_at),
"user": str(self.user),
class Item(CreatedMixin, SearchableMixin, db.Model):
# NOTE(review): the opening of this dict (likely "__versioned__ = {", the
# SQLAlchemy-Continuum configuration) appears to be missing before the
# "exclude" key below, together with its closing brace.
"exclude": [
"created_at",
"user_id",
"ics_id",
"serial_number",
"manufacturer_id",
"model_id",
]
# Elasticsearch mapping used by SearchableMixin.create_index for this model.
__mapping__ = {
"created_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"},
"updated_at": {"type": "date", "format": "yyyy-MM-dd HH:mm"},
"user": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
"ics_id": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
"serial_number": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
"quantity": {"type": "long"},
"manufacturer": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
"model": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
"location": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
"status": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
"parent": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
"children": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
"macs": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
"host": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
"stack_member": {"type": "byte"},
"history": {"enabled": False},
"comments": {"type": "text"},
}
# WARNING! Inheriting id from CreatedMixin doesn't play well with
# SQLAlchemy-Continuum. It has to be defined here.
# NOTE(review): the "id" column definition the comment above refers to
# appears to be missing here.
ics_id = db.Column(
db.Text, unique=True, nullable=False, index=True, default=get_temporary_ics_id
)
serial_number = db.Column(db.Text, nullable=False)
quantity = db.Column(db.Integer, nullable=False, default=1)
manufacturer_id = db.Column(db.Integer, db.ForeignKey("manufacturer.id"))
model_id = db.Column(db.Integer, db.ForeignKey("model.id"))
location_id = db.Column(db.Integer, db.ForeignKey("location.id"))
status_id = db.Column(db.Integer, db.ForeignKey("status.id"))
parent_id = db.Column(db.Integer, db.ForeignKey("item.id"))
host_id = db.Column(db.Integer, db.ForeignKey("host.id"))
stack_member = db.Column(db.SmallInteger)
manufacturer = db.relationship(
"Manufacturer", back_populates="items", lazy="joined"
)
model = db.relationship("Model", back_populates="items", lazy="joined")
location = db.relationship("Location", back_populates="items", lazy="joined")
status = db.relationship("Status", back_populates="items", lazy="joined")
children = db.relationship("Item", backref=db.backref("parent", remote_side=[id]))
macs = db.relationship("Mac", backref="item", lazy="joined")
comments = db.relationship(
"ItemComment", backref="item", cascade="all, delete-orphan", lazy="joined"
)
# NOTE(review): the opening "__table_args__ = (" appears to be missing before
# the constraints below (and the closing ")" after them).
sa.CheckConstraint(
"stack_member >= 0 AND stack_member <=9", name="stack_member_range"
),
sa.UniqueConstraint(host_id, stack_member, name="uq_item_host_id_stack_member"),
def __init__(self, **kwargs):
# Automatically convert manufacturer/model/location/status to an
# instance of their class if passed as a string
for key, cls in [
("manufacturer", Manufacturer),
("model", Model),
("location", Location),
("status", Status),
]:
if key in kwargs:
kwargs[key] = utils.convert_to_model(kwargs[key], cls)
super().__init__(**kwargs)
# NOTE(review): a "def __str__(self):" header appears to be missing above
# this return statement.
return str(self.ics_id)
@validates("ics_id")
def validate_ics_id(self, key, string):
    """Ensure the ICS id field matches the required format

    :raises ValidationError: when the value doesn't match ICS_ID_RE
    :returns: the validated string (stored by SQLAlchemy)
    """
    # BUG FIX: without the @validates registration this hook was never
    # invoked (same pattern as validate_vlan_id below).
    if string is not None and ICS_ID_RE.fullmatch(string) is None:
        raise ValidationError("ICS id shall match [A-Z]{3}[0-9]{3}")
    return string
def to_dict(self, recursive=False):
# NOTE(review): the initial assignment (likely "d = super().to_dict()") and
# the final "return d" appear to be missing from this method -- as written
# "d" is undefined and the method returns None, which would break
# to_row_dict below. TODO: restore from VCS.
d.update(
{
"ics_id": self.ics_id,
"serial_number": self.serial_number,
"quantity": self.quantity,
"manufacturer": utils.format_field(self.manufacturer),
"model": utils.format_field(self.model),
"location": utils.format_field(self.location),
"status": utils.format_field(self.status),
"parent": utils.format_field(self.parent),
"children": [str(child) for child in self.children],
"macs": [str(mac) for mac in self.macs],
"host": utils.format_field(self.host),
"stack_member": utils.format_field(self.stack_member),
"history": self.history(),
"comments": [str(comment) for comment in self.comments],
}
)
def to_row_dict(self):
    """Convert to a dict that can easily be exported to an excel row

    All values should be a string
    """
    row = dict(self.to_dict())
    row["children"] = " ".join(row["children"])
    row["macs"] = " ".join(row["macs"])
    row["comments"] = "\n\n".join(row["comments"])
    row["history"] = "\n".join(str(version) for version in row["history"])
    return row
def history(self):
    """Return the list of versions of the item.

    Each entry is a dict with updated_at/quantity/location/status/parent.
    """
    versions = []
    for version in self.versions:
        # parent is an attribute used by SQLAlchemy-Continuum
        # version.parent refers to an ItemVersion instance (and has no link with
        # the item parent_id)
        # We need to retrieve the parent "manually"
        if version.parent_id is None:
            parent = None
        else:
            parent = Item.query.get(version.parent_id)
        versions.append(
            {
                "updated_at": utils.format_field(version.updated_at),
                "quantity": version.quantity,
                "location": utils.format_field(version.location),
                "status": utils.format_field(version.status),
                "parent": utils.format_field(parent),
            }
        )
    # BUG FIX: the list was built but never returned, so callers
    # (to_dict "history" key, iterated by to_row_dict) got None.
    return versions
class ItemComment(CreatedMixin, db.Model):
"""Free-text comment attached to an Item."""
body = db.Column(db.Text, nullable=False)
item_id = db.Column(db.Integer, db.ForeignKey("item.id"), nullable=False)
# NOTE(review): a "def to_dict(self):" header (with "d = super().to_dict()"
# and "return d") appears to be missing around the line below.
d.update({"body": self.body, "item": str(self.item)})
class Network(CreatedMixin, db.Model):
"""IP network (vlan) with its address range, gateway, scope and domain."""
vlan_name = db.Column(CIText, nullable=False, unique=True)
vlan_id = db.Column(db.Integer, nullable=True, unique=True)
address = db.Column(postgresql.CIDR, nullable=False, unique=True)
first_ip = db.Column(postgresql.INET, nullable=False, unique=True)
last_ip = db.Column(postgresql.INET, nullable=False, unique=True)
gateway = db.Column(postgresql.INET, nullable=False, unique=True)
description = db.Column(db.Text)
admin_only = db.Column(db.Boolean, nullable=False, default=False)
sensitive = db.Column(db.Boolean, nullable=False, default=False)
scope_id = db.Column(db.Integer, db.ForeignKey("network_scope.id"), nullable=False)
domain_id = db.Column(db.Integer, db.ForeignKey("domain.id"), nullable=False)
interfaces = db.relationship(
"Interface", backref=db.backref("network", lazy="joined"), lazy=True
)
# NOTE(review): the opening "__table_args__ = (" appears to be missing before
# the constraints below (and the closing ")" after them).
sa.CheckConstraint("first_ip < last_ip", name="first_ip_less_than_last_ip"),
sa.CheckConstraint("first_ip << address", name="first_ip_in_network"),
sa.CheckConstraint("last_ip << address", name="last_ip_in_network"),
sa.CheckConstraint("gateway << address", name="gateway_in_network"),
def __init__(self, **kwargs):
    """Create a network, accepting scope passed as a plain string.

    Statement order matters: per the WARNING below, scope validation reads
    self.address, so vlan_name/address must be bound before scope.
    """
    # Automatically convert scope to an instance of NetworkScope if it was passed
    # as a string
    if "scope" in kwargs:
        kwargs["scope"] = utils.convert_to_model(
            kwargs["scope"], NetworkScope, "name"
        )
    # If domain_id is not passed, we set it to the network scope value
    # NOTE(review): assumes "scope" is always provided when domain_id isn't
    # (otherwise this raises KeyError) -- confirm against callers.
    if "domain_id" not in kwargs:
        kwargs["domain_id"] = kwargs["scope"].domain_id
    # WARNING! Setting self.scope will call validate_networks in the NetworkScope class
    # For the validation to work, self.address must be set before!
    # Ensure that address and vlan_name are passed before scope
    vlan_name = kwargs.pop("vlan_name")
    address = kwargs.pop("address")
    super().__init__(vlan_name=vlan_name, address=address, **kwargs)

@property
def network_ip(self):
    """Return the network as an ipaddress.ip_network object."""
    return ipaddress.ip_network(self.address)

@property
def netmask(self):
    """Return the netmask of the network."""
    return self.network_ip.netmask

@property
def broadcast(self):
    """Return the broadcast address of the network."""
    return self.network_ip.broadcast_address
@property
def first(self):
    """Return the first assignable IP as an ipaddress object."""
    # BUG FIX: restored @property -- ip_range() compares
    # "self.first <= addr <= self.last" as values; without the decorator the
    # comparison would be against bound methods and raise TypeError.
    return ipaddress.ip_address(self.first_ip)

@property
def last(self):
    """Return the last assignable IP as an ipaddress object."""
    return ipaddress.ip_address(self.last_ip)
def ip_range(self):
    """Return the list of IP addresses that can be assigned for this network

    The range is defined by the first and last IP
    """
    lower, upper = self.first, self.last
    return [addr for addr in self.network_ip.hosts() if lower <= addr <= upper]
def used_ips(self):
    """Return the sorted list of IP addresses in use on this network."""
    addresses = [interface.address for interface in self.interfaces]
    addresses.sort()
    return addresses
def available_ips(self):
    """Return the list of IP addresses available"""
    # PERF FIX: used_ips() was called once per candidate address (re-reading
    # self.interfaces each time) and list membership is O(n); compute the
    # used addresses once and test against a set.
    used = set(self.used_ips())
    return [addr for addr in self.ip_range() if addr not in used]
@validates("first_ip")
def validate_first_ip(self, key, ip):
    """Ensure the first IP is in the network

    :returns: the validated ip (stored by SQLAlchemy)
    """
    # BUG FIX: restored the @validates registration and the return value --
    # a SQLAlchemy validator's return value is what gets assigned, so
    # returning None would null the column.
    utils.ip_in_network(ip, self.address)
    return ip
@validates("last_ip")
def validate_last_ip(self, key, ip):
    """Ensure the last IP is in the network and greater than first_ip

    :returns: the validated ip (stored by SQLAlchemy)
    """
    addr, net = utils.ip_in_network(ip, self.address)
    # NOTE(review): the comparison guarding this raise had been lost;
    # reconstructed from the error message -- confirm against VCS.
    if addr < self.first:
        raise ValidationError(
            f"Last IP address {ip} is less than the first address {self.first}"
        )
    return ip
@validates("interfaces")
def validate_interfaces(self, key, interface):
    """Ensure the interface IP is in the network range

    :returns: the validated interface (stored by SQLAlchemy)
    """
    # BUG FIX: restored the @validates registration and the return value --
    # collection validators must return the item to append.
    utils.validate_ip(interface.ip, self)
    return interface
@validates("vlan_name")
def validate_vlan_name(self, key, string):
    """Ensure the name matches the required format

    :raises ValidationError: when the value doesn't match VLAN_NAME_RE
    :returns: the validated string (stored by SQLAlchemy)
    """
    if string is None:
        return None
    if VLAN_NAME_RE.fullmatch(string) is None:
        raise ValidationError(r"Vlan name shall match [A-Za-z0-9\-]{3,25}")
    # BUG FIX: a valid name fell through without a return, so the validator
    # returned None and the (nullable=False) column would be set to None.
    return string
@validates("vlan_id")
def validate_vlan_id(self, key, value):