From e6494e471a8791d037edc2dc7faa8592a84d7b82 Mon Sep 17 00:00:00 2001 From: Benjamin Bertrand <benjamin.bertrand@esss.se> Date: Wed, 28 Nov 2018 14:38:11 +0100 Subject: [PATCH] Add extra fields to the API - add fqdn to host - add gateway to network - add netmask to interface - allow to pass recursive via the API (to expand interfaces in host) - add is_main property to interface (True if the interface is the main interface of the host) JIRA INFRA-640 --- app/api/utils.py | 7 ++- app/models.py | 33 +++++++----- tests/functional/test_api.py | 91 +++++++++++++++++++++++++++------ tests/functional/test_models.py | 46 +++++++++++++++++ tests/functional/test_web.py | 4 +- 5 files changed, 150 insertions(+), 31 deletions(-) diff --git a/app/api/utils.py b/app/api/utils.py index cee5266..db8c004 100644 --- a/app/api/utils.py +++ b/app/api/utils.py @@ -70,13 +70,18 @@ def get_generic_model(model, order_by=None, query=None): kwargs = request.args.to_dict() page = int(kwargs.pop("page", 1)) per_page = int(kwargs.pop("per_page", 20)) + # Remove recursive from kwargs so that it doesn't get passed + # to query.filter_by in get_query + recursive = kwargs.pop("recursive", "false").lower() == "true" if query is None: query = utils.get_query(model.query, **kwargs) if order_by is None: order_by = getattr(model, "name") query = query.order_by(order_by) pagination = query.paginate(page, per_page) - data = [item.to_dict() for item in pagination.items] + data = [item.to_dict(recursive=recursive) for item in pagination.items] + # Re-add recursive to kwargs so that it's part of the pagination url + kwargs["recursive"] = recursive header = build_pagination_header(pagination, request.base_url, **kwargs) return jsonify(data), 200, header diff --git a/app/models.py b/app/models.py index 62de3ca..8da58c0 100644 --- a/app/models.py +++ b/app/models.py @@ -349,7 +349,7 @@ class User(db.Model, UserMixin): def __str__(self): return self.username - def to_dict(self): + def to_dict(self, recursive=False): return { "id": self.id, "username": self.username, @@ -499,7 +499,7 @@ class QRCodeMixin: # See https://flask-caching.readthedocs.io/en/latest/#memoization return f"{self.__class__.__name__}(id={self.id}, name={self.name})" - def to_dict(self): + def to_dict(self, recursive=False): return { "id": self.id, "name": self.name, @@ -701,7 +701,7 @@ class ItemComment(CreatedMixin, db.Model): def __str__(self): return self.body - def to_dict(self): + def to_dict(self, recursive=False): d = super().to_dict() d.update({"body": self.body, "item": str(self.item)}) return d @@ -841,7 +841,7 @@ class Network(CreatedMixin, db.Model): raise ValidationError("Vlan name shall match [A-Za-z0-9\-]{3,25}") return string - def to_dict(self): + def to_dict(self, recursive=False): d = super().to_dict() d.update( { @@ -851,6 +851,7 @@ class Network(CreatedMixin, db.Model): "netmask": str(self.netmask), "first_ip": self.first_ip, "last_ip": self.last_ip, + "gateway": str(self.gateway), "description": self.description, "admin_only": self.admin_only, "scope": utils.format_field(self.scope), @@ -903,7 +904,7 @@ class DeviceType(db.Model): def __str__(self): return self.name - def to_dict(self): + def to_dict(self, recursive=False): return { "id": self.id, "name": self.name, @@ -1024,7 +1025,7 @@ class AnsibleGroup(CreatedMixin, db.Model): raise AttributeError("can't set dynamic hosts") self._hosts = value - def to_dict(self): + def to_dict(self, recursive=False): d = super().to_dict() d.update( { @@ -1179,6 +1180,8 @@ class Host(CreatedMixin, SearchableMixin, db.Model): d.update( { "name": self.name, + "fqdn": self.fqdn, + "is_ioc": self.is_ioc, "device_type": str(self.device_type), "model": self.model, "description": self.description, @@ -1303,18 +1306,24 @@ class Interface(CreatedMixin, db.Model): return True return False + @property + def is_main(self): + return self.name == self.host.main_interface.name + def __str__(self): return str(self.name) def __repr__(self): return f"Interface(id={self.id}, network_id={self.network_id}, ip={self.ip}, name={self.name}, mac={self.mac})" - def to_dict(self): + def to_dict(self, recursive=False): d = super().to_dict() d.update( { + "is_main": self.is_main, "network": str(self.network), "ip": self.ip, + "netmask": str(self.network.netmask), "name": self.name, "mac": utils.format_field(self.mac), "host": utils.format_field(self.host), @@ -1349,7 +1358,7 @@ class Mac(db.Model): raise ValidationError(f"'{string}' does not appear to be a MAC address") return string - def to_dict(self): + def to_dict(self, recursive=False): return { "id": self.id, "address": self.address, @@ -1398,7 +1407,7 @@ class Cname(CreatedMixin, db.Model): raise ValidationError(f"cname matches an existing host") return lower_string - def to_dict(self): + def to_dict(self, recursive=False): d = super().to_dict() d.update({"name": self.name, "interface": str(self.interface)}) return d @@ -1417,7 +1426,7 @@ class Domain(CreatedMixin, db.Model): def __str__(self): return str(self.name) - def to_dict(self): + def to_dict(self, recursive=False): d = super().to_dict() d.update( { @@ -1492,7 +1501,7 @@ class NetworkScope(CreatedMixin, db.Model): if subnet not in self.used_subnets() ] - def to_dict(self): + def to_dict(self, recursive=False): d = super().to_dict() d.update( { @@ -1548,7 +1557,7 @@ class Task(db.Model): def __str__(self): return str(self.id) - def to_dict(self): + def to_dict(self, recursive=False): return { "id": self.id, "name": self.name, diff --git a/tests/functional/test_api.py b/tests/functional/test_api.py index 12609a4..acecb7f 100644 --- a/tests/functional/test_api.py +++ b/tests/functional/test_api.py @@ -45,6 +45,40 @@ GENERIC_CREATE_ENDPOINTS = [ CREATE_AUTH_ENDPOINTS = [ key for key in ENDPOINT_MODEL.keys() if key != "inventory/actions" ] +HOST_KEYS = { + "id", + "name", + "fqdn", + "is_ioc", + "device_type", + "model", + "description", + "items", + "interfaces", + "ansible_vars", + "ansible_groups", + "created_at", + "updated_at", + "user", +} +INTERFACE_KEYS = { + "id", + "is_main", + "network", + "ip", + "netmask", + "name", + "mac", + "domain", + "host", + "device_type", + "model", + "cnames", + "tags", + "created_at", + "updated_at", + "user", +} def get(client, url, token=None): @@ -677,6 +711,7 @@ def test_create_network(client, admin_token, network_scope_factory): "netmask", "first_ip", "last_ip", + "gateway", "description", "admin_only", "scope", @@ -1026,25 +1061,12 @@ def test_create_interface(client, host, network_factory, no_login_check_token): client, f"{API_URL}/network/interfaces", data=data, token=no_login_check_token ) assert response.status_code == 201 - assert { - "id", - "network", - "ip", - "name", - "mac", - "domain", - "host", - "device_type", - "model", - "cnames", - "tags", - "created_at", - "updated_at", - "user", - } == set(response.get_json().keys()) + assert INTERFACE_KEYS == set(response.get_json().keys()) assert response.get_json()["network"] == network.vlan_name assert response.get_json()["ip"] == "192.168.1.20" assert response.get_json()["name"] == host.name + # This is the main interface + assert response.get_json()["is_main"] # Check that all parameters can be passed data2 = { @@ -1058,6 +1080,8 @@ def test_create_interface(client, host, network_factory, no_login_check_token): client, f"{API_URL}/network/interfaces", data=data2, token=no_login_check_token ) assert response.status_code == 201 + # This is not the main interface + assert not response.get_json()["is_main"] # check all items that were created assert models.Interface.query.count() == 2 @@ -1291,6 +1315,7 @@ def test_get_hosts(client, host_factory, readonly_token): response = get(client, f"{API_URL}/network/hosts", token=readonly_token) assert response.status_code == 200 assert len(response.get_json()) == 2 + assert HOST_KEYS == set(response.get_json()[0].keys()) check_input_is_subset_of_response(response, (host1.to_dict(), host2.to_dict())) @@ -1320,6 +1345,38 @@ def test_get_hosts_with_no_model(client, host_factory, readonly_token): assert response.get_json()[0]["model"] is None +def test_get_hosts_recursive(client, host_factory, interface_factory, readonly_token): + # Create some hosts with interfaces + host1 = host_factory() + interface11 = interface_factory(name=host1.name, host=host1) + interface12 = interface_factory(host=host1) + host2 = host_factory() + interface21 = interface_factory(host=host2) + # Without recursive, we only get the name of the interfaces + response = get(client, f"{API_URL}/network/hosts", token=readonly_token) + assert response.status_code == 200 + assert len(response.get_json()) == 2 + rhost1, rhost2 = response.get_json() + # We can't be sure in which order the interfaces are returned + assert set(rhost1["interfaces"]) == {interface11.name, interface12.name} + assert rhost2["interfaces"] == [interface21.name] + # With recursive, interfaces are expanded + response = get( + client, f"{API_URL}/network/hosts?recursive=true", token=readonly_token + ) + assert response.status_code == 200 + assert len(response.get_json()) == 2 + rhost1, rhost2 = response.get_json() + assert len(rhost1["interfaces"]) == 2 + rinterface11, rinterface12 = rhost1["interfaces"] + assert INTERFACE_KEYS == set(rinterface11.keys()) + assert INTERFACE_KEYS == set(rinterface12.keys()) + assert len(rhost2["interfaces"]) == 1 + rinterface21 = rhost2["interfaces"][0] + assert INTERFACE_KEYS == set(rinterface21.keys()) + assert rinterface21["network"] == interface21.network.vlan_name + + def test_create_host(client, device_type_factory, user_token): device_type = device_type_factory(name="Virtual") # check that name and device_type are mandatory @@ -1343,6 +1400,8 @@ def test_create_host(client, device_type_factory, user_token): assert { "id", "name", + "fqdn", + "is_ioc", "device_type", "model", "description", diff --git a/tests/functional/test_models.py b/tests/functional/test_models.py index 0f2186c..8493080 100644 --- a/tests/functional/test_models.py +++ b/tests/functional/test_models.py @@ -434,6 +434,30 @@ def test_interface_name_existing_cname( assert "Interface name matches an existing cname" in str(excinfo.value) +def test_interface_is_main(host_factory, interface_factory): + # The interface with the same name as the host is the main one + host1 = host_factory(name="myhost") + interface11 = interface_factory(name=host1.name, host=host1) + interface12 = interface_factory(name=host1.name + "-2", host=host1) + interface13 = interface_factory(name=host1.name + "-3", host=host1) + assert interface11.is_main + assert not interface12.is_main + assert not interface13.is_main + host2 = host_factory(name="anotherhost") + interface21 = interface_factory(name=host2.name + "-1", host=host2) + # If no interface has the same name as the host, the first one is the main + assert interface21.is_main + interface22 = interface_factory(name=host2.name + "-2", host=host2) + # The first interface in the list is the main one + assert host2.interfaces[0].is_main + assert not host2.interfaces[1].is_main + interface23 = interface_factory(name=host2.name, host=host2) + # The new interface has the same name as the host, so this is the main one + assert not interface21.is_main + assert not interface22.is_main + assert interface23.is_main + + def test_host_existing_interface(db, host_factory, interface): with pytest.raises(ValidationError) as excinfo: host_factory(name=interface.name) @@ -446,6 +470,28 @@ def test_host_existing_cname(db, host_factory, cname): assert "Host name matches an existing cname" in str(excinfo.value) +def test_host_fqdn(host_factory, interface_factory): + host1 = host_factory(name="myhost") + interface1 = interface_factory(name=host1.name, host=host1) + interface2 = interface_factory(name=host1.name + "-2", host=host1) + assert interface1.network.domain != interface2.network.domain + # The domain is the one from the main interface + assert interface1.is_main + assert host1.fqdn == f"{host1.name}.{interface1.network.domain.name}" + + +def test_host_is_ioc(host_factory, interface_factory, tag_factory): + ioc_tag = tag_factory(name="IOC") + another_tag = tag_factory(name="foo") + host1 = host_factory() + interface_factory(name=host1.name, host=host1, tags=[ioc_tag]) + interface_factory(host=host1) + assert host1.is_ioc + host2 = host_factory() + interface_factory(name=host2.name, host=host2, tags=[another_tag]) + assert not host2.is_ioc + + def test_cname_existing_host(db, host_factory, cname_factory): host_factory(name="myhost") with pytest.raises(ValidationError) as excinfo: diff --git a/tests/functional/test_web.py b/tests/functional/test_web.py index 2c39473..81bbb94 100644 --- a/tests/functional/test_web.py +++ b/tests/functional/test_web.py @@ -194,8 +194,8 @@ def test_retrieve_hosts(logged_client, interface_factory, host_factory): response = logged_client.post("/network/_retrieve_hosts") hosts = response.get_json()["data"] assert {host1.name, host2.name} == set(host["name"] for host in hosts) - assert len(hosts[0]) == 12 - assert len(hosts[0]["interfaces"][0]) == 14 + assert len(hosts[0]) == 14 + assert len(hosts[0]["interfaces"][0]) == 16 def test_retrieve_hosts_by_ip(logged_client, interface_factory): -- GitLab