From cd81f112582bf2155e20f988f7029a284a94f820 Mon Sep 17 00:00:00 2001
From: Benjamin Bertrand <benjamin.bertrand@esss.se>
Date: Wed, 18 Dec 2019 16:43:18 +0100
Subject: [PATCH] Prevent recursive dependency loop in Ansible groups

Check that a group child is not:
- the group itself
- the "all" group
- in one of the parent groups (recursively)

JIRA INFRA-1622 #action In Progress
---
 app/models.py                   | 23 +++++++++++++++
 app/network/forms.py            | 10 +++++--
 app/network/views.py            | 52 ++++++++++++++++++++++++---------
 tests/functional/test_models.py | 37 +++++++++++++++++++++++
 tests/functional/test_web.py    | 14 +++++++++
 5 files changed, 119 insertions(+), 17 deletions(-)

diff --git a/app/models.py b/app/models.py
index 2cd7965..fb3468f 100644
--- a/app/models.py
+++ b/app/models.py
@@ -1056,6 +1056,29 @@ class AnsibleGroup(CreatedMixin, db.Model):
     def __str__(self):
         return str(self.name)
 
+    @validates("children")
+    def validate_children(self, key, child):
+        """Ensure the child is not in the group parents to avoid circular references"""
+        if child == self:
+            raise ValidationError(f"Group '{self.name}' can't be a child of itself.")
+        # "all" is special for Ansible. Any group is automatically a child of "all".
+        if child.name == "all":
+            raise ValidationError(
+                f"Adding group 'all' as child to '{self.name}' creates a recursive dependency loop."
+            )
+
+        def check_parents(group):
+            """Recursively check all parents"""
+            if child in group.parents:
+                raise ValidationError(
+                    f"Adding group '{child}' as child to '{self.name}' creates a recursive dependency loop."
+                )
+            for parent in group.parents:
+                check_parents(parent)
+
+        check_parents(self)
+        return child
+
     @property
     def is_dynamic(self):
         return self.type != AnsibleGroupType.STATIC
diff --git a/app/network/forms.py b/app/network/forms.py
index d1d7a1e..26aba27 100644
--- a/app/network/forms.py
+++ b/app/network/forms.py
@@ -265,9 +265,13 @@ class AnsibleGroupForm(CSEntryForm):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.children.choices = utils.get_model_choices(
-            models.AnsibleGroup, attr="name"
-        )
+        self.children.choices = [
+            (str(group.id), group.name)
+            for group in models.AnsibleGroup.query.order_by(
+                models.AnsibleGroup.name
+            ).all()
+            if group.name != "all"
+        ]
         self.hosts.choices = utils.get_model_choices(
             models.Host, attr="fqdn", order_by="name"
         )
diff --git a/app/network/views.py b/app/network/views.py
index a849720..672bdbf 100644
--- a/app/network/views.py
+++ b/app/network/views.py
@@ -460,6 +460,18 @@ def delete_ansible_group():
 def edit_ansible_group(name):
     group = models.AnsibleGroup.query.filter_by(name=name).first_or_404()
     form = AnsibleGroupForm(request.form, obj=group)
+    # Restrict the children that can be added
+    # We don't check parents of parents, but that will be catched by the validate_children
+    # and raise a ValidationError
+    form.children.choices = [
+        (str(ansible_group.id), ansible_group.name)
+        for ansible_group in models.AnsibleGroup.query.order_by(
+            models.AnsibleGroup.name
+        ).all()
+        if (ansible_group not in group.parents)
+        and (ansible_group.name != "all")
+        and (ansible_group != group)
+    ]
     # Passing hosts as kwarg to the AnsibleGroupForm doesn't work because
     # obj takes precedence (but group.hosts contain Host instances and not id)
     # We need to update the default values. Calling process is required.
@@ -470,13 +482,19 @@ def edit_ansible_group(name):
     form.children.default = [child.id for child in group.children]
     form.children.process(request.form)
     if form.validate_on_submit():
-        group.name = form.name.data
-        group.vars = form.vars.data or None
-        group.type = form.type.data
-        group.hosts = [models.Host.query.get(id_) for id_ in form.hosts.data]
-        group.children = [
-            models.AnsibleGroup.query.get(id_) for id_ in form.children.data
-        ]
+        try:
+            group.name = form.name.data
+            group.vars = form.vars.data or None
+            group.type = form.type.data
+            group.hosts = [models.Host.query.get(id_) for id_ in form.hosts.data]
+            group.children = [
+                models.AnsibleGroup.query.get(id_) for id_ in form.children.data
+            ]
+        except ValidationError as e:
+            # Check for error raised by model validation (not implemented in form vaildation)
+            current_app.logger.warning(f"{e}")
+            flash(f"{e}", "error")
+            return render_template("network/edit_group.html", form=form)
         current_app.logger.debug(f"Trying to update: {group!r}")
         try:
             db.session.commit()
@@ -497,13 +515,19 @@ def create_ansible_group():
     if form.validate_on_submit():
         hosts = [models.Host.query.get(id_) for id_ in form.hosts.data]
         children = [models.AnsibleGroup.query.get(id_) for id_ in form.children.data]
-        group = models.AnsibleGroup(
-            name=form.name.data,
-            vars=form.vars.data or None,
-            type=form.type.data,
-            hosts=hosts,
-            children=children,
-        )
+        try:
+            group = models.AnsibleGroup(
+                name=form.name.data,
+                vars=form.vars.data or None,
+                type=form.type.data,
+                hosts=hosts,
+                children=children,
+            )
+        except ValidationError as e:
+            # Check for error raised by model validation (not implemented in form vaildation)
+            current_app.logger.warning(f"{e}")
+            flash(f"{e}", "error")
+            return render_template("network/create_group.html", form=form)
         current_app.logger.debug(f"Trying to create: {group!r}")
         db.session.add(group)
         try:
diff --git a/tests/functional/test_models.py b/tests/functional/test_models.py
index 0d7e45e..6471a22 100644
--- a/tests/functional/test_models.py
+++ b/tests/functional/test_models.py
@@ -389,6 +389,43 @@ def test_ansible_groups_children(ansible_group_factory, host_factory):
     assert group1.children == [group2, group3, group4]
 
 
+def test_ansible_groups_children_all_forbidden(ansible_group_factory):
+    all = ansible_group_factory(name="all")
+    group1 = ansible_group_factory()
+    group2 = ansible_group_factory()
+    with pytest.raises(ValidationError) as excinfo:
+        group1.children.append(all)
+    assert (
+        f"Adding group 'all' as child to '{group1.name}' creates a recursive dependency loop"
+        in str(excinfo.value)
+    )
+    with pytest.raises(ValidationError) as excinfo:
+        ansible_group_factory(children=[all])
+    assert "creates a recursive dependency loop" in str(excinfo.value)
+    with pytest.raises(ValidationError) as excinfo:
+        ansible_group_factory(children=[group2, all])
+    assert "creates a recursive dependency loop" in str(excinfo.value)
+
+
+def test_ansible_groups_no_recursive_dependency(ansible_group_factory):
+    group3 = ansible_group_factory()
+    group2 = ansible_group_factory(children=[group3])
+    group1 = ansible_group_factory(children=[group2])
+    with pytest.raises(ValidationError) as excinfo:
+        group3.children.append(group1)
+    assert (
+        f"Adding group '{group1.name}' as child to '{group3.name}' creates a recursive dependency loop"
+        in str(excinfo.value)
+    )
+
+
+def test_ansible_groups_no_child_of_itself(ansible_group_factory):
+    group1 = ansible_group_factory()
+    with pytest.raises(ValidationError) as excinfo:
+        group1.children.append(group1)
+    assert f"Group '{group1.name}' can't be a child of itself" in str(excinfo.value)
+
+
 def test_host_model(model_factory, item_factory, host_factory):
     host1 = host_factory()
     model1 = model_factory(name="EX3400")
diff --git a/tests/functional/test_web.py b/tests/functional/test_web.py
index 8a7fd56..1999560 100644
--- a/tests/functional/test_web.py
+++ b/tests/functional/test_web.py
@@ -748,3 +748,17 @@ def test_create_item_with_host_and_no_stack_member(
     item = models.Item.query.filter_by(ics_id=ics_id).first()
     assert item.host == host
     assert item.stack_member is None
+
+
+def test_ansible_groups_no_recursive_dependency(
+    ansible_group_factory, logged_admin_client
+):
+    group3 = ansible_group_factory()
+    group2 = ansible_group_factory(children=[group3])
+    group1 = ansible_group_factory(children=[group2])
+    form = {"name": group3.name, "type": group3.type, "children": [group1.id]}
+    response = logged_admin_client.post(
+        f"/network/groups/edit/{group3.name}", data=form
+    )
+    assert response.status_code == 200
+    assert b"creates a recursive dependency loop" in response.data
-- 
GitLab