From e6494e471a8791d037edc2dc7faa8592a84d7b82 Mon Sep 17 00:00:00 2001
From: Benjamin Bertrand <benjamin.bertrand@esss.se>
Date: Wed, 28 Nov 2018 14:38:11 +0100
Subject: [PATCH] Add extra fields to the API

- add fqdn to host
- add gateway to network
- add netmask to interface
- allow to pass recursive via the API (to expand interfaces in host)
- add is_main property to interface (True if the interface is the main
  interface of the host)

JIRA INFRA-640
---
 app/api/utils.py                |  7 ++-
 app/models.py                   | 33 +++++++-----
 tests/functional/test_api.py    | 91 +++++++++++++++++++++++++++------
 tests/functional/test_models.py | 46 +++++++++++++++++
 tests/functional/test_web.py    |  4 +-
 5 files changed, 150 insertions(+), 31 deletions(-)

diff --git a/app/api/utils.py b/app/api/utils.py
index cee5266..db8c004 100644
--- a/app/api/utils.py
+++ b/app/api/utils.py
@@ -70,13 +70,18 @@ def get_generic_model(model, order_by=None, query=None):
     kwargs = request.args.to_dict()
     page = int(kwargs.pop("page", 1))
     per_page = int(kwargs.pop("per_page", 20))
+    # Remove recursive from kwargs so that it doesn't get passed
+    # to query.filter_by in get_query
+    recursive = kwargs.pop("recursive", "false").lower() == "true"
     if query is None:
         query = utils.get_query(model.query, **kwargs)
         if order_by is None:
             order_by = getattr(model, "name")
         query = query.order_by(order_by)
     pagination = query.paginate(page, per_page)
-    data = [item.to_dict() for item in pagination.items]
+    data = [item.to_dict(recursive=recursive) for item in pagination.items]
+    # Re-add recursive to kwargs so that it's part of the pagination url
+    kwargs["recursive"] = recursive
     header = build_pagination_header(pagination, request.base_url, **kwargs)
     return jsonify(data), 200, header
 
diff --git a/app/models.py b/app/models.py
index 62de3ca..8da58c0 100644
--- a/app/models.py
+++ b/app/models.py
@@ -349,7 +349,7 @@ class User(db.Model, UserMixin):
     def __str__(self):
         return self.username
 
-    def to_dict(self):
+    def to_dict(self, recursive=False):
         return {
             "id": self.id,
             "username": self.username,
@@ -499,7 +499,7 @@ class QRCodeMixin:
         # See https://flask-caching.readthedocs.io/en/latest/#memoization
         return f"{self.__class__.__name__}(id={self.id}, name={self.name})"
 
-    def to_dict(self):
+    def to_dict(self, recursive=False):
         return {
             "id": self.id,
             "name": self.name,
@@ -701,7 +701,7 @@ class ItemComment(CreatedMixin, db.Model):
     def __str__(self):
         return self.body
 
-    def to_dict(self):
+    def to_dict(self, recursive=False):
         d = super().to_dict()
         d.update({"body": self.body, "item": str(self.item)})
         return d
@@ -841,7 +841,7 @@ class Network(CreatedMixin, db.Model):
             raise ValidationError("Vlan name shall match [A-Za-z0-9\-]{3,25}")
         return string
 
-    def to_dict(self):
+    def to_dict(self, recursive=False):
         d = super().to_dict()
         d.update(
             {
@@ -851,6 +851,7 @@ class Network(CreatedMixin, db.Model):
                 "netmask": str(self.netmask),
                 "first_ip": self.first_ip,
                 "last_ip": self.last_ip,
+                "gateway": str(self.gateway),
                 "description": self.description,
                 "admin_only": self.admin_only,
                 "scope": utils.format_field(self.scope),
@@ -903,7 +904,7 @@ class DeviceType(db.Model):
     def __str__(self):
         return self.name
 
-    def to_dict(self):
+    def to_dict(self, recursive=False):
         return {
             "id": self.id,
             "name": self.name,
@@ -1024,7 +1025,7 @@ class AnsibleGroup(CreatedMixin, db.Model):
             raise AttributeError("can't set dynamic hosts")
         self._hosts = value
 
-    def to_dict(self):
+    def to_dict(self, recursive=False):
         d = super().to_dict()
         d.update(
             {
@@ -1179,6 +1180,8 @@ class Host(CreatedMixin, SearchableMixin, db.Model):
         d.update(
             {
                 "name": self.name,
+                "fqdn": self.fqdn,
+                "is_ioc": self.is_ioc,
                 "device_type": str(self.device_type),
                 "model": self.model,
                 "description": self.description,
@@ -1303,18 +1306,24 @@ class Interface(CreatedMixin, db.Model):
                 return True
         return False
 
+    @property
+    def is_main(self):
+        return self.name == self.host.main_interface.name
+
     def __str__(self):
         return str(self.name)
 
     def __repr__(self):
         return f"Interface(id={self.id}, network_id={self.network_id}, ip={self.ip}, name={self.name}, mac={self.mac})"
 
-    def to_dict(self):
+    def to_dict(self, recursive=False):
         d = super().to_dict()
         d.update(
             {
+                "is_main": self.is_main,
                 "network": str(self.network),
                 "ip": self.ip,
+                "netmask": str(self.network.netmask),
                 "name": self.name,
                 "mac": utils.format_field(self.mac),
                 "host": utils.format_field(self.host),
@@ -1349,7 +1358,7 @@ class Mac(db.Model):
             raise ValidationError(f"'{string}' does not appear to be a MAC address")
         return string
 
-    def to_dict(self):
+    def to_dict(self, recursive=False):
         return {
             "id": self.id,
             "address": self.address,
@@ -1398,7 +1407,7 @@ class Cname(CreatedMixin, db.Model):
             raise ValidationError(f"cname matches an existing host")
         return lower_string
 
-    def to_dict(self):
+    def to_dict(self, recursive=False):
         d = super().to_dict()
         d.update({"name": self.name, "interface": str(self.interface)})
         return d
@@ -1417,7 +1426,7 @@ class Domain(CreatedMixin, db.Model):
     def __str__(self):
         return str(self.name)
 
-    def to_dict(self):
+    def to_dict(self, recursive=False):
         d = super().to_dict()
         d.update(
             {
@@ -1492,7 +1501,7 @@ class NetworkScope(CreatedMixin, db.Model):
             if subnet not in self.used_subnets()
         ]
 
-    def to_dict(self):
+    def to_dict(self, recursive=False):
         d = super().to_dict()
         d.update(
             {
@@ -1548,7 +1557,7 @@ class Task(db.Model):
     def __str__(self):
         return str(self.id)
 
-    def to_dict(self):
+    def to_dict(self, recursive=False):
         return {
             "id": self.id,
             "name": self.name,
diff --git a/tests/functional/test_api.py b/tests/functional/test_api.py
index 12609a4..acecb7f 100644
--- a/tests/functional/test_api.py
+++ b/tests/functional/test_api.py
@@ -45,6 +45,40 @@ GENERIC_CREATE_ENDPOINTS = [
 CREATE_AUTH_ENDPOINTS = [
     key for key in ENDPOINT_MODEL.keys() if key != "inventory/actions"
 ]
+HOST_KEYS = {
+    "id",
+    "name",
+    "fqdn",
+    "is_ioc",
+    "device_type",
+    "model",
+    "description",
+    "items",
+    "interfaces",
+    "ansible_vars",
+    "ansible_groups",
+    "created_at",
+    "updated_at",
+    "user",
+}
+INTERFACE_KEYS = {
+    "id",
+    "is_main",
+    "network",
+    "ip",
+    "netmask",
+    "name",
+    "mac",
+    "domain",
+    "host",
+    "device_type",
+    "model",
+    "cnames",
+    "tags",
+    "created_at",
+    "updated_at",
+    "user",
+}
 
 
 def get(client, url, token=None):
@@ -677,6 +711,7 @@ def test_create_network(client, admin_token, network_scope_factory):
         "netmask",
         "first_ip",
         "last_ip",
+        "gateway",
         "description",
         "admin_only",
         "scope",
@@ -1026,25 +1061,12 @@ def test_create_interface(client, host, network_factory, no_login_check_token):
         client, f"{API_URL}/network/interfaces", data=data, token=no_login_check_token
     )
     assert response.status_code == 201
-    assert {
-        "id",
-        "network",
-        "ip",
-        "name",
-        "mac",
-        "domain",
-        "host",
-        "device_type",
-        "model",
-        "cnames",
-        "tags",
-        "created_at",
-        "updated_at",
-        "user",
-    } == set(response.get_json().keys())
+    assert INTERFACE_KEYS == set(response.get_json().keys())
     assert response.get_json()["network"] == network.vlan_name
     assert response.get_json()["ip"] == "192.168.1.20"
     assert response.get_json()["name"] == host.name
+    # This is the main interface
+    assert response.get_json()["is_main"]
 
     # Check that all parameters can be passed
     data2 = {
@@ -1058,6 +1080,8 @@ def test_create_interface(client, host, network_factory, no_login_check_token):
         client, f"{API_URL}/network/interfaces", data=data2, token=no_login_check_token
     )
     assert response.status_code == 201
+    # This is not the main interface
+    assert not response.get_json()["is_main"]
 
     # check all items that were created
     assert models.Interface.query.count() == 2
@@ -1291,6 +1315,7 @@ def test_get_hosts(client, host_factory, readonly_token):
     response = get(client, f"{API_URL}/network/hosts", token=readonly_token)
     assert response.status_code == 200
     assert len(response.get_json()) == 2
+    assert HOST_KEYS == set(response.get_json()[0].keys())
     check_input_is_subset_of_response(response, (host1.to_dict(), host2.to_dict()))
 
 
@@ -1320,6 +1345,38 @@ def test_get_hosts_with_no_model(client, host_factory, readonly_token):
     assert response.get_json()[0]["model"] is None
 
 
+def test_get_hosts_recursive(client, host_factory, interface_factory, readonly_token):
+    # Create some hosts with interfaces
+    host1 = host_factory()
+    interface11 = interface_factory(name=host1.name, host=host1)
+    interface12 = interface_factory(host=host1)
+    host2 = host_factory()
+    interface21 = interface_factory(host=host2)
+    # Without recursive, we only get the name of the interfaces
+    response = get(client, f"{API_URL}/network/hosts", token=readonly_token)
+    assert response.status_code == 200
+    assert len(response.get_json()) == 2
+    rhost1, rhost2 = response.get_json()
+    # We can't be sure in which order the interfaces are returned
+    assert set(rhost1["interfaces"]) == {interface11.name, interface12.name}
+    assert rhost2["interfaces"] == [interface21.name]
+    # With recursive, interfaces are expanded
+    response = get(
+        client, f"{API_URL}/network/hosts?recursive=true", token=readonly_token
+    )
+    assert response.status_code == 200
+    assert len(response.get_json()) == 2
+    rhost1, rhost2 = response.get_json()
+    assert len(rhost1["interfaces"]) == 2
+    rinterface11, rinterface12 = rhost1["interfaces"]
+    assert INTERFACE_KEYS == set(rinterface11.keys())
+    assert INTERFACE_KEYS == set(rinterface12.keys())
+    assert len(rhost2["interfaces"]) == 1
+    rinterface21 = rhost2["interfaces"][0]
+    assert INTERFACE_KEYS == set(rinterface21.keys())
+    assert rinterface21["network"] == interface21.network.vlan_name
+
+
 def test_create_host(client, device_type_factory, user_token):
     device_type = device_type_factory(name="Virtual")
     # check that name and device_type are  mandatory
@@ -1343,6 +1400,8 @@ def test_create_host(client, device_type_factory, user_token):
     assert {
         "id",
         "name",
+        "fqdn",
+        "is_ioc",
         "device_type",
         "model",
         "description",
diff --git a/tests/functional/test_models.py b/tests/functional/test_models.py
index 0f2186c..8493080 100644
--- a/tests/functional/test_models.py
+++ b/tests/functional/test_models.py
@@ -434,6 +434,30 @@ def test_interface_name_existing_cname(
     assert "Interface name matches an existing cname" in str(excinfo.value)
 
 
+def test_interface_is_main(host_factory, interface_factory):
+    # The interface with the same name as the host is the main one
+    host1 = host_factory(name="myhost")
+    interface11 = interface_factory(name=host1.name, host=host1)
+    interface12 = interface_factory(name=host1.name + "-2", host=host1)
+    interface13 = interface_factory(name=host1.name + "-3", host=host1)
+    assert interface11.is_main
+    assert not interface12.is_main
+    assert not interface13.is_main
+    host2 = host_factory(name="anotherhost")
+    interface21 = interface_factory(name=host2.name + "-1", host=host2)
+    # If no interface has the same name as the host, the first one is the main
+    assert interface21.is_main
+    interface22 = interface_factory(name=host2.name + "-2", host=host2)
+    # The first interface in the list is the main one
+    assert host2.interfaces[0].is_main
+    assert not host2.interfaces[1].is_main
+    interface23 = interface_factory(name=host2.name, host=host2)
+    # The new interface has the same name as the host, so this is the main one
+    assert not interface21.is_main
+    assert not interface22.is_main
+    assert interface23.is_main
+
+
 def test_host_existing_interface(db, host_factory, interface):
     with pytest.raises(ValidationError) as excinfo:
         host_factory(name=interface.name)
@@ -446,6 +470,28 @@ def test_host_existing_cname(db, host_factory, cname):
     assert "Host name matches an existing cname" in str(excinfo.value)
 
 
+def test_host_fqdn(host_factory, interface_factory):
+    host1 = host_factory(name="myhost")
+    interface1 = interface_factory(name=host1.name, host=host1)
+    interface2 = interface_factory(name=host1.name + "-2", host=host1)
+    assert interface1.network.domain != interface2.network.domain
+    # The domain is the one from the main interface
+    assert interface1.is_main
+    assert host1.fqdn == f"{host1.name}.{interface1.network.domain.name}"
+
+
+def test_host_is_ioc(host_factory, interface_factory, tag_factory):
+    ioc_tag = tag_factory(name="IOC")
+    another_tag = tag_factory(name="foo")
+    host1 = host_factory()
+    interface_factory(name=host1.name, host=host1, tags=[ioc_tag])
+    interface_factory(host=host1)
+    assert host1.is_ioc
+    host2 = host_factory()
+    interface_factory(name=host2.name, host=host2, tags=[another_tag])
+    assert not host2.is_ioc
+
+
 def test_cname_existing_host(db, host_factory, cname_factory):
     host_factory(name="myhost")
     with pytest.raises(ValidationError) as excinfo:
diff --git a/tests/functional/test_web.py b/tests/functional/test_web.py
index 2c39473..81bbb94 100644
--- a/tests/functional/test_web.py
+++ b/tests/functional/test_web.py
@@ -194,8 +194,8 @@ def test_retrieve_hosts(logged_client, interface_factory, host_factory):
     response = logged_client.post("/network/_retrieve_hosts")
     hosts = response.get_json()["data"]
     assert {host1.name, host2.name} == set(host["name"] for host in hosts)
-    assert len(hosts[0]) == 12
-    assert len(hosts[0]["interfaces"][0]) == 14
+    assert len(hosts[0]) == 14
+    assert len(hosts[0]["interfaces"][0]) == 16
 
 
 def test_retrieve_hosts_by_ip(logged_client, interface_factory):
-- 
GitLab