Skip to content
Snippets Groups Projects
Commit cd81f112 authored by Benjamin Bertrand's avatar Benjamin Bertrand
Browse files

Prevent recursive dependency loop in Ansible groups

Check that a group child is not:
- the group itself
- the "all" group
- in one of the parent groups (recursively)

JIRA INFRA-1622 #action In Progress
parent d600145d
No related branches found
No related tags found
No related merge requests found
...@@ -1056,6 +1056,29 @@ class AnsibleGroup(CreatedMixin, db.Model): ...@@ -1056,6 +1056,29 @@ class AnsibleGroup(CreatedMixin, db.Model):
def __str__(self): def __str__(self):
return str(self.name) return str(self.name)
@validates("children")
def validate_children(self, key, child):
"""Ensure the child is not in the group parents to avoid circular references"""
if child == self:
raise ValidationError(f"Group '{self.name}' can't be a child of itself.")
# "all" is special for Ansible. Any group is automatically a child of "all".
if child.name == "all":
raise ValidationError(
f"Adding group 'all' as child to '{self.name}' creates a recursive dependency loop."
)
def check_parents(group):
"""Recursively check all parents"""
if child in group.parents:
raise ValidationError(
f"Adding group '{child}' as child to '{self.name}' creates a recursive dependency loop."
)
for parent in group.parents:
check_parents(parent)
check_parents(self)
return child
@property @property
def is_dynamic(self): def is_dynamic(self):
return self.type != AnsibleGroupType.STATIC return self.type != AnsibleGroupType.STATIC
......
...@@ -265,9 +265,13 @@ class AnsibleGroupForm(CSEntryForm): ...@@ -265,9 +265,13 @@ class AnsibleGroupForm(CSEntryForm):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.children.choices = utils.get_model_choices( self.children.choices = [
models.AnsibleGroup, attr="name" (str(group.id), group.name)
) for group in models.AnsibleGroup.query.order_by(
models.AnsibleGroup.name
).all()
if group.name != "all"
]
self.hosts.choices = utils.get_model_choices( self.hosts.choices = utils.get_model_choices(
models.Host, attr="fqdn", order_by="name" models.Host, attr="fqdn", order_by="name"
) )
...@@ -460,6 +460,18 @@ def delete_ansible_group(): ...@@ -460,6 +460,18 @@ def delete_ansible_group():
def edit_ansible_group(name): def edit_ansible_group(name):
group = models.AnsibleGroup.query.filter_by(name=name).first_or_404() group = models.AnsibleGroup.query.filter_by(name=name).first_or_404()
form = AnsibleGroupForm(request.form, obj=group) form = AnsibleGroupForm(request.form, obj=group)
# Restrict the children that can be added
# We don't check parents of parents, but that will be catched by the validate_children
# and raise a ValidationError
form.children.choices = [
(str(ansible_group.id), ansible_group.name)
for ansible_group in models.AnsibleGroup.query.order_by(
models.AnsibleGroup.name
).all()
if (ansible_group not in group.parents)
and (ansible_group.name != "all")
and (ansible_group != group)
]
# Passing hosts as kwarg to the AnsibleGroupForm doesn't work because # Passing hosts as kwarg to the AnsibleGroupForm doesn't work because
# obj takes precedence (but group.hosts contain Host instances and not id) # obj takes precedence (but group.hosts contain Host instances and not id)
# We need to update the default values. Calling process is required. # We need to update the default values. Calling process is required.
...@@ -470,13 +482,19 @@ def edit_ansible_group(name): ...@@ -470,13 +482,19 @@ def edit_ansible_group(name):
form.children.default = [child.id for child in group.children] form.children.default = [child.id for child in group.children]
form.children.process(request.form) form.children.process(request.form)
if form.validate_on_submit(): if form.validate_on_submit():
group.name = form.name.data try:
group.vars = form.vars.data or None group.name = form.name.data
group.type = form.type.data group.vars = form.vars.data or None
group.hosts = [models.Host.query.get(id_) for id_ in form.hosts.data] group.type = form.type.data
group.children = [ group.hosts = [models.Host.query.get(id_) for id_ in form.hosts.data]
models.AnsibleGroup.query.get(id_) for id_ in form.children.data group.children = [
] models.AnsibleGroup.query.get(id_) for id_ in form.children.data
]
except ValidationError as e:
# Check for error raised by model validation (not implemented in form vaildation)
current_app.logger.warning(f"{e}")
flash(f"{e}", "error")
return render_template("network/edit_group.html", form=form)
current_app.logger.debug(f"Trying to update: {group!r}") current_app.logger.debug(f"Trying to update: {group!r}")
try: try:
db.session.commit() db.session.commit()
...@@ -497,13 +515,19 @@ def create_ansible_group(): ...@@ -497,13 +515,19 @@ def create_ansible_group():
if form.validate_on_submit(): if form.validate_on_submit():
hosts = [models.Host.query.get(id_) for id_ in form.hosts.data] hosts = [models.Host.query.get(id_) for id_ in form.hosts.data]
children = [models.AnsibleGroup.query.get(id_) for id_ in form.children.data] children = [models.AnsibleGroup.query.get(id_) for id_ in form.children.data]
group = models.AnsibleGroup( try:
name=form.name.data, group = models.AnsibleGroup(
vars=form.vars.data or None, name=form.name.data,
type=form.type.data, vars=form.vars.data or None,
hosts=hosts, type=form.type.data,
children=children, hosts=hosts,
) children=children,
)
except ValidationError as e:
# Check for error raised by model validation (not implemented in form vaildation)
current_app.logger.warning(f"{e}")
flash(f"{e}", "error")
return render_template("network/create_group.html", form=form)
current_app.logger.debug(f"Trying to create: {group!r}") current_app.logger.debug(f"Trying to create: {group!r}")
db.session.add(group) db.session.add(group)
try: try:
......
...@@ -389,6 +389,43 @@ def test_ansible_groups_children(ansible_group_factory, host_factory): ...@@ -389,6 +389,43 @@ def test_ansible_groups_children(ansible_group_factory, host_factory):
assert group1.children == [group2, group3, group4] assert group1.children == [group2, group3, group4]
def test_ansible_groups_children_all_forbidden(ansible_group_factory):
all = ansible_group_factory(name="all")
group1 = ansible_group_factory()
group2 = ansible_group_factory()
with pytest.raises(ValidationError) as excinfo:
group1.children.append(all)
assert (
f"Adding group 'all' as child to '{group1.name}' creates a recursive dependency loop"
in str(excinfo.value)
)
with pytest.raises(ValidationError) as excinfo:
ansible_group_factory(children=[all])
assert "creates a recursive dependency loop" in str(excinfo.value)
with pytest.raises(ValidationError) as excinfo:
ansible_group_factory(children=[group2, all])
assert "creates a recursive dependency loop" in str(excinfo.value)
def test_ansible_groups_no_recursive_dependency(ansible_group_factory):
group3 = ansible_group_factory()
group2 = ansible_group_factory(children=[group3])
group1 = ansible_group_factory(children=[group2])
with pytest.raises(ValidationError) as excinfo:
group3.children.append(group1)
assert (
f"Adding group '{group1.name}' as child to '{group3.name}' creates a recursive dependency loop"
in str(excinfo.value)
)
def test_ansible_groups_no_child_of_itself(ansible_group_factory):
group1 = ansible_group_factory()
with pytest.raises(ValidationError) as excinfo:
group1.children.append(group1)
assert f"Group '{group1.name}' can't be a child of itself" in str(excinfo.value)
def test_host_model(model_factory, item_factory, host_factory): def test_host_model(model_factory, item_factory, host_factory):
host1 = host_factory() host1 = host_factory()
model1 = model_factory(name="EX3400") model1 = model_factory(name="EX3400")
......
...@@ -748,3 +748,17 @@ def test_create_item_with_host_and_no_stack_member( ...@@ -748,3 +748,17 @@ def test_create_item_with_host_and_no_stack_member(
item = models.Item.query.filter_by(ics_id=ics_id).first() item = models.Item.query.filter_by(ics_id=ics_id).first()
assert item.host == host assert item.host == host
assert item.stack_member is None assert item.stack_member is None
def test_ansible_groups_no_recursive_dependency(
ansible_group_factory, logged_admin_client
):
group3 = ansible_group_factory()
group2 = ansible_group_factory(children=[group3])
group1 = ansible_group_factory(children=[group2])
form = {"name": group3.name, "type": group3.type, "children": [group1.id]}
response = logged_admin_client.post(
f"/network/groups/edit/{group3.name}", data=form
)
assert response.status_code == 200
assert b"creates a recursive dependency loop" in response.data
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment