diff --git a/app/models.py b/app/models.py index 1e7f99084cbdac7060d8df160719e561defd6e09..0b564584a6d19eaa008df9ea28b411da19f35da1 100644 --- a/app/models.py +++ b/app/models.py @@ -17,6 +17,7 @@ import urllib.parse import elasticsearch import sqlalchemy as sa from enum import Enum +from operator import attrgetter from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.dialects import postgresql from sqlalchemy.orm import validates @@ -1131,18 +1132,18 @@ class AnsibleGroup(CreatedMixin, SearchableMixin, db.Model): nullable=False, ) - children = db.relationship( + _children = db.relationship( "AnsibleGroup", secondary=ansible_groups_parent_child_table, primaryjoin=id == ansible_groups_parent_child_table.c.parent_group_id, secondaryjoin=id == ansible_groups_parent_child_table.c.child_group_id, - backref=db.backref("parents"), + backref=db.backref("_parents"), ) def __str__(self): return str(self.name) - @validates("children") + @validates("_children") def validate_children(self, key, child): """Ensure the child is not in the group parents to avoid circular references""" if child == self: @@ -1213,6 +1214,65 @@ class AnsibleGroup(CreatedMixin, SearchableMixin, db.Model): raise AttributeError("can't set dynamic hosts") self._hosts = value + @property + def children(self): + if self.type == AnsibleGroupType.NETWORK_SCOPE: + # Return all existing network groups part of the scope + network_children = ( + AnsibleGroup.query.filter(AnsibleGroup.type == AnsibleGroupType.NETWORK) + .join(Network, AnsibleGroup.name == Network.vlan_name) + .join(NetworkScope) + .filter(NetworkScope.name == self.name) + .all() + ) + return sorted(self._children + network_children, key=attrgetter("name")) + return sorted(self._children, key=attrgetter("name")) + + @children.setter + def children(self, value): + if self.type == AnsibleGroupType.NETWORK_SCOPE: + # Forbid setting a NETWORK group as child + # Groups linked to networks part of the scope are added automatically + # Also forbid NETWORK_SCOPE group as child + for group in value: + if group.type in ( + AnsibleGroupType.NETWORK, + AnsibleGroupType.NETWORK_SCOPE, + ): + raise ValidationError( + f"can't set {str(group.type).lower()} group '{group}' as a network scope child" + ) + self._children = value + + @property + def parents(self): + if self.type == AnsibleGroupType.NETWORK: + # Add the group corresponding to the network scope if it exists + network = Network.query.filter_by(vlan_name=self.name).first() + if network is not None: + scope_group = AnsibleGroup.query.filter_by( + name=network.scope.name + ).first() + if scope_group is not None: + return sorted(self._parents + [scope_group], key=attrgetter("name")) + return sorted(self._parents, key=attrgetter("name")) + + @parents.setter + def parents(self, value): + if self.type == AnsibleGroupType.NETWORK: + # Forbid setting a NETWORK_SCOPE group as parent + # The group linked to the scope of the network is added automatically + # Also forbid setting a NETWORK group as it doesn't make sense + for group in value: + if group.type in ( + AnsibleGroupType.NETWORK, + AnsibleGroupType.NETWORK_SCOPE, + ): + raise ValidationError( + f"can't set {str(group.type).lower()} group '{group}' as a network parent" + ) + self._parents = value + def to_dict(self, recursive=False): d = super().to_dict() d.update( diff --git a/app/network/views.py b/app/network/views.py index 281c64a8459fd75d58132dc1c558573592aa90d9..38b5e3180198539ad609fe28fafe954bc7548b23 100644 --- a/app/network/views.py +++ b/app/network/views.py @@ -489,7 +489,8 @@ def edit_ansible_group(name): form.hosts.default = [host.id for host in group.hosts] form.hosts.process(request.form) # Same for AnsibleGroup children - form.children.default = [child.id for child in group.children] + # WARNING: use _children to not include groups automatically added to the children property + form.children.default = [child.id for child in group._children] form.children.process(request.form) if form.validate_on_submit(): try: diff --git a/tests/functional/test_models.py b/tests/functional/test_models.py index c0c85a567417d41af5806347f5b044d103e48fb3..5b8f0dc3dfd1c760f718d1bee79441344e39fb0d 100644 --- a/tests/functional/test_models.py +++ b/tests/functional/test_models.py @@ -484,13 +484,12 @@ def test_ansible_groups_children(ansible_group_factory, host_factory): group1 = ansible_group_factory() group2 = ansible_group_factory() group3 = ansible_group_factory() - group1.children.append(group2) - group1.children.append(group3) - assert group1.children == [group2, group3] + group1.children = [group2, group3] + assert group1.children == sorted([group2, group3], key=lambda grp: grp.name) assert group2.parents == [group1] assert group3.parents == [group1] group4 = ansible_group_factory(parents=[group1]) - assert group1.children == [group2, group3, group4] + assert group1.children == sorted([group2, group3, group4], key=lambda grp: grp.name) def test_ansible_groups_children_all_forbidden(ansible_group_factory): @@ -498,7 +497,7 @@ def test_ansible_groups_children_all_forbidden(ansible_group_factory): group1 = ansible_group_factory() group2 = ansible_group_factory() with pytest.raises(ValidationError) as excinfo: - group1.children.append(all) + group1.children = [all] assert ( f"Adding group 'all' as child to '{group1.name}' creates a recursive dependency loop" in str(excinfo.value) @@ -516,7 +515,7 @@ def test_ansible_groups_no_recursive_dependency(ansible_group_factory): group2 = ansible_group_factory(children=[group3]) group1 = ansible_group_factory(children=[group2]) with pytest.raises(ValidationError) as excinfo: - group3.children.append(group1) + group3.children = [group1] assert ( f"Adding group '{group1.name}' as child to '{group3.name}' creates a recursive dependency loop" in str(excinfo.value) @@ -526,10 +525,84 @@ def test_ansible_groups_no_recursive_dependency(ansible_group_factory): def test_ansible_groups_no_child_of_itself(ansible_group_factory): group1 = ansible_group_factory() with pytest.raises(ValidationError) as excinfo: - group1.children.append(group1) + group1.children = [group1] assert f"Group '{group1.name}' can't be a child of itself" in str(excinfo.value) +@pytest.mark.parametrize( + "grp_type", + [ + models.AnsibleGroupType.STATIC, + models.AnsibleGroupType.DEVICE_TYPE, + models.AnsibleGroupType.IOC, + models.AnsibleGroupType.HOSTNAME, + ], +) +def test_ansible_group_network_scope_children(ansible_group_factory, grp_type): + group = ansible_group_factory(type=grp_type) + scope_group = ansible_group_factory( + type=models.AnsibleGroupType.NETWORK_SCOPE, children=[group] + ) + assert scope_group.children == [group] + assert group.parents == [scope_group] + + +@pytest.mark.parametrize( + "grp_type", + [ + models.AnsibleGroupType.NETWORK, + models.AnsibleGroupType.NETWORK_SCOPE, + ], +) +def test_ansible_group_network_scope_children_forbidden( + ansible_group_factory, grp_type +): + child_group = ansible_group_factory(name="mygroup", type=grp_type) + with pytest.raises(ValidationError) as excinfo: + ansible_group_factory( + type=models.AnsibleGroupType.NETWORK_SCOPE, children=[child_group] + ) + assert ( + f"can't set {str(grp_type).lower()} group 'mygroup' as a network scope child" + in str(excinfo.value) + ) + + +@pytest.mark.parametrize( + "grp_type", + [ + models.AnsibleGroupType.STATIC, + models.AnsibleGroupType.DEVICE_TYPE, + models.AnsibleGroupType.IOC, + models.AnsibleGroupType.HOSTNAME, + ], +) +def test_ansible_group_network_parent(ansible_group_factory, grp_type): + group = ansible_group_factory(type=grp_type) + network_group = ansible_group_factory( + type=models.AnsibleGroupType.NETWORK, parents=[group] + ) + assert network_group.parents == [group] + assert group.children == [network_group] + + +@pytest.mark.parametrize( + "grp_type", + [ + models.AnsibleGroupType.NETWORK, + models.AnsibleGroupType.NETWORK_SCOPE, + ], +) +def test_ansible_group_network_parent_forbidden(ansible_group_factory, grp_type): + group = ansible_group_factory(name="mygroup", type=grp_type) + with pytest.raises(ValidationError) as excinfo: + ansible_group_factory(type=models.AnsibleGroupType.NETWORK, parents=[group]) + assert ( + f"can't set {str(grp_type).lower()} group 'mygroup' as a network parent" + in str(excinfo.value) + ) + + def test_host_model(model_factory, item_factory, host_factory): host1 = host_factory() model1 = model_factory(name="EX3400") @@ -585,7 +658,7 @@ def test_ansible_dynamic_network_group( assert group3.hosts == [] -def test_ansible_dynamic_network_scope_group( +def test_ansible_dynamic_network_scope_group_hosts( ansible_group_factory, network_scope_factory, network_factory, @@ -620,6 +693,68 @@ def test_ansible_dynamic_network_scope_group( assert group3.hosts == [] +@pytest.mark.parametrize( + "networks, groups, expected_names", + [ + ( + ["network1", "network2", "network3"], + [], + [], + ), + ( + ["network1", "network2", "network3"], + [("network2", models.AnsibleGroupType.NETWORK)], + ["network2"], + ), + ( + ["network1", "network2", "network3"], + [ + ("network2", models.AnsibleGroupType.NETWORK), + ("network3", models.AnsibleGroupType.NETWORK), + ], + ["network2", "network3"], + ), + ( + ["network1"], + [ + ("network2", models.AnsibleGroupType.NETWORK), + ], + [], + ), + ( + ["network1", "network2"], + [ + ("mygroup1", models.AnsibleGroupType.DEVICE_TYPE), + ("network2", models.AnsibleGroupType.NETWORK), + ("mygroup2", models.AnsibleGroupType.STATIC), + ], + ["mygroup1", "mygroup2", "network2"], + ), + ], +) +def test_ansible_dynamic_network_scope_group_children( + ansible_group_factory, + network_scope_factory, + network_factory, + networks, + groups, + expected_names, +): + name = "myscope" + scope = network_scope_factory(name=name) + group = ansible_group_factory(name=name, type=models.AnsibleGroupType.NETWORK_SCOPE) + for network in networks: + network_factory(vlan_name=network, scope=scope) + for grp, grp_type in groups: + ag = ansible_group_factory(name=grp, type=grp_type) + if grp_type != models.AnsibleGroupType.NETWORK: + ag.parents = [group] + expected = [ + grp for grp in models.AnsibleGroup.query.all() if grp.name in expected_names + ] + assert group.children == sorted(expected, key=lambda grp: grp.name) + + def test_ansible_dynamic_ioc_group(ansible_group_factory, host_factory): host1 = host_factory(name="host1", is_ioc=True) host2 = host_factory(name="host2", is_ioc=True)