From 738fd37180be50392bac14a508f97c52f2b8cac2 Mon Sep 17 00:00:00 2001
From: Benjamin Bertrand <benjamin.bertrand@esss.se>
Date: Wed, 6 Sep 2017 16:24:34 +0200
Subject: [PATCH] Improve API (patch and get item)

- add more tests
- patching an item update children location and status
- setting the parent on an item updates the location and status with the
  parent ones
---
 app/api/main.py              |  72 +++++++++++-
 app/models.py                |   1 +
 tests/functional/test_api.py | 210 ++++++++++++++++++++++++++++++++++-
 3 files changed, 274 insertions(+), 9 deletions(-)

diff --git a/app/api/main.py b/app/api/main.py
index 9bdf0b5..316a011 100644
--- a/app/api/main.py
+++ b/app/api/main.py
@@ -21,8 +21,21 @@ from ..decorators import jwt_groups_accepted
 bp = Blueprint('api', __name__)
 
 
+def get_item_by_id_or_ics_id(id_):
+    """Retrieve item by id or ICS id"""
+    try:
+        item_id = int(id_)
+    except ValueError:
+        # Assume id_ is an ics_id
+        item = Item.query.filter_by(ics_id=id_).first()
+    else:
+        item = Item.query.get(item_id)
+    if item is None:
+        raise utils.InventoryError(f"Item id '{id_}' not found", status_code=404)
+    return item
+
+
 def get_generic_model(model):
-    # TODO: add pagination
     items = model.query.order_by(model.name)
     data = [item.to_dict() for item in items]
     return jsonify(data)
@@ -81,27 +94,74 @@ def get_items():
     return jsonify(data)
 
 
+@bp.route('/items/<id_>')
+@jwt_required
+def get_item(id_):
+    """Retrieve item by id or ICS id"""
+    item = get_item_by_id_or_ics_id(id_)
+    return jsonify(item.to_dict())
+
+
 @bp.route('/items', methods=['POST'])
 @jwt_required
 @jwt_groups_accepted('admin', 'create')
 def create_item():
+    """Register a new item"""
+    # People should assign an ICS id to a serial number when creating
+    # an item so ics_id should also be a mandatory field.
+    # But there are existing items (in confluence and JIRA) that we want to
+    # import and associate after they have been created.
     return create_generic_model(Item, mandatory_field='serial_number')
 
 
-@bp.route('/items/<item_id>', methods=['PATCH'])
+@bp.route('/items/<id_>', methods=['PATCH'])
 @jwt_required
 @jwt_groups_accepted('admin', 'create')
-def patch_item(item_id):
+def patch_item(id_):
+    """Patch an existing item
+
+    id_ can be the primary key or the ics_id field
+    Fields allowed to update are:
+        - ics_id ONLY if currently null (422 returned otherwise)
+        - manufacturer
+        - model
+        - location
+        - status
+        - parent
+
+    422 is returned if other fields are given.
+    """
     data = request.get_json()
     if data is None:
         raise utils.InventoryError('Body should be a JSON object')
-    item = Item.query.get(item_id)
-    if item is None:
-        raise utils.InventoryError(f'Unknown item id: {item_id}', status_code=422)
+    if not data:
+        raise utils.InventoryError('At least one field is required', status_code=422)
+    for key in data.keys():
+        if key not in ('ics_id', 'manufacturer', 'model',
+                       'location', 'status', 'parent'):
+            raise utils.InventoryError(f"Invalid field '{key}'", status_code=422)
+    item = get_item_by_id_or_ics_id(id_)
+    # Only allow to set ICS id if it's null
+    if item.ics_id is None:
+        item.ics_id = data.get('ics_id')
+    elif 'ics_id' in data:
+        raise utils.InventoryError("'ics_id' can't be changed", status_code=422)
     item.manufacturer = utils.convert_to_model(data.get('manufacturer', item.manufacturer), Manufacturer)
     item.model = utils.convert_to_model(data.get('model', item.model), Model)
     item.location = utils.convert_to_model(data.get('location', item.location), Location)
     item.status = utils.convert_to_model(data.get('status', item.status), Status)
+    parent_ics_id = data.get('parent')
+    if parent_ics_id is not None:
+        parent = Item.query.filter_by(ics_id=parent_ics_id).first()
+        if parent is not None:
+            item.parent_id = parent.id
+            # Update location and status with those from parent
+            item.location = parent.location
+            item.status = parent.status
+    # Update all children status and location
+    for child in item.children:
+        child.location = item.location
+        child.status = item.status
     db.session.commit()
     return jsonify(item.to_dict())
 
diff --git a/app/models.py b/app/models.py
index 8e74079..86a6916 100644
--- a/app/models.py
+++ b/app/models.py
@@ -238,6 +238,7 @@ class Item(db.Model):
             'status': utils.format_field(self.status),
             'updated': utils.format_field(self._updated),
             'created': utils.format_field(self._created),
+            'parent': utils.format_field(self.parent),
         }
 
 
diff --git a/tests/functional/test_api.py b/tests/functional/test_api.py
index db395da..088eece 100644
--- a/tests/functional/test_api.py
+++ b/tests/functional/test_api.py
@@ -19,7 +19,9 @@ ENDPOINT_MODEL = {
     'models': models.Model,
     'locations': models.Location,
     'status': models.Status,
+    'items': models.Item,
 }
+GENERIC_ENDPOINTS = [key for key in ENDPOINT_MODEL.keys() if key != 'items']
 ENDPOINTS = list(ENDPOINT_MODEL.keys())
 
 
@@ -44,6 +46,16 @@ def post(client, url, data, token=None):
     return response
 
 
+def patch(client, url, data, token=None):
+    headers = {'Content-Type': 'application/json'}
+    if token is not None:
+        headers['Authorization'] = f'Bearer {token}'
+    response = client.patch(url, data=json.dumps(data), headers=headers)
+    if response.headers['Content-Type'] == 'application/json':
+        response.json = json.loads(response.data)
+    return response
+
+
 def login(client, username, password):
     data = {
         'username': username,
@@ -92,6 +104,13 @@ def check_names(response, names):
     assert set(names) == response_names
 
 
+def check_items(response, inputs):
+    # Sort the response by id to match the inputs order
+    response_items = sorted(response.json, key=lambda d: d['id'])
+    for d1, d2 in zip(inputs, response_items):
+        assert set(d1.items()).issubset(set(d2.items()))
+
+
 def test_login(client):
     response = client.post('/api/login')
     check_response_message(response, 'Body should be a JSON object')
@@ -104,7 +123,7 @@ def test_login(client):
     assert 'access_token' in response.json
 
 
-@pytest.mark.parametrize('endpoint', ENDPOINTS)
+@pytest.mark.parametrize('endpoint', GENERIC_ENDPOINTS)
 def test_get_generic_model(endpoint, session, client, readonly_token):
     model = ENDPOINT_MODEL[endpoint]
     names = ('Foo', 'Bar', 'Alice')
@@ -120,7 +139,7 @@ def test_get_generic_model(endpoint, session, client, readonly_token):
 
 
 @pytest.mark.parametrize('endpoint', ENDPOINTS)
-def test_create_generic_model_fail(endpoint, client, readonly_token):
+def test_create_model_auth_fail(endpoint, client, readonly_token):
     response = client.post(f'/api/{endpoint}')
     check_response_message(response, 'Missing Authorization Header', 401)
     response = post(client, f'/api/{endpoint}', data={}, token='xxxxxxxxx')
@@ -131,7 +150,7 @@ def test_create_generic_model_fail(endpoint, client, readonly_token):
     assert model.query.count() == 0
 
 
-@pytest.mark.parametrize('endpoint', ENDPOINTS)
+@pytest.mark.parametrize('endpoint', GENERIC_ENDPOINTS)
 def test_create_generic_model(endpoint, client, user_token):
     response = post(client, f'/api/{endpoint}', data={}, token=user_token)
     check_response_message(response, "Missing mandatory field 'name'", 422)
@@ -153,3 +172,188 @@ def test_create_generic_model(endpoint, client, user_token):
     assert model.query.count() == 2
     response = get(client, f'/api/{endpoint}', user_token)
     check_names(response, ('Foo', 'Bar'))
+
+
+def test_create_item(client, user_token):
+    # check that serial_number is mandatory
+    response = post(client, '/api/items', data={}, token=user_token)
+    check_response_message(response, "Missing mandatory field 'serial_number'", 422)
+
+    # check create with only serial_number
+    data = {'serial_number': '123456'}
+    response = post(client, '/api/items', data=data, token=user_token)
+    assert response.status_code == 201
+    assert {'id', 'ics_id', 'serial_number', 'manufacturer', 'model',
+            'location', 'status', 'updated', 'created', 'parent'} == set(response.json.keys())
+    assert response.json['serial_number'] == '123456'
+
+    # Check that serial_number doesn't have to be unique
+    response = post(client, '/api/items', data=data, token=user_token)
+    assert response.status_code == 201
+
+    # check that ics_id shall be unique
+    data2 = {'serial_number': '456789', 'ics_id': 'AAA001'}
+    response = post(client, '/api/items', data=data2, token=user_token)
+    assert response.status_code == 201
+    response = post(client, '/api/items', data=data2, token=user_token)
+    check_response_message(response, 'IntegrityError', 409)
+
+    # check all items that were created
+    assert models.Item.query.count() == 3
+    response = get(client, '/api/items', user_token)
+    check_items(response, (data, data, data2))
+
+
+def test_get_item_fail(client, session, readonly_token):
+    response = get(client, '/api/items/50', token=readonly_token)
+    check_response_message(response, "Item id '50' not found", 404)
+    response = get(client, '/api/items/bar', token=readonly_token)
+    check_response_message(response, "Item id 'bar' not found", 404)
+
+
+def test_get_item(client, session, readonly_token):
+    # Create some items
+    session.add(models.Status(name='Stock'))
+    item1 = models.Item(serial_number='123456')
+    item2 = models.Item(serial_number='234567', ics_id='AAA001', status='Stock')
+    for item in (item1, item2):
+        session.add(item)
+    session.commit()
+
+    # we can get items by id...
+    response = get(client, f'/api/items/{item1.id}', token=readonly_token)
+    assert response.status_code == 200
+    assert response.json['id'] == item1.id
+    assert response.json['serial_number'] == item1.serial_number
+    # ...or ics_id
+    response = get(client, f'/api/items/{item2.ics_id}', token=readonly_token)
+    assert response.status_code == 200
+    assert response.json['id'] == item2.id
+    assert response.json['ics_id'] == item2.ics_id
+    assert response.json['serial_number'] == item2.serial_number
+    assert response.json['status'] == str(item2.status)
+
+
+def test_patch_item_auth_fail(client, session, readonly_token):
+    response = client.patch('/api/items/50')
+    check_response_message(response, 'Missing Authorization Header', 401)
+    response = patch(client, '/api/items/50', data={}, token='xxxxxxxxx')
+    check_response_message(response, 'Not enough segments', 422)
+    response = patch(client, '/api/items/50', data={}, token=readonly_token)
+    check_response_message(response, "User doesn't have the required group", 403)
+
+
+def test_patch_item_fail(client, session, user_token):
+    response = patch(client, '/api/items/50', data={}, token=user_token)
+    check_response_message(response, 'At least one field is required', 422)
+    data = {'location': 'ESS', 'foo': 'bar'}
+    response = patch(client, '/api/items/50', data=data, token=user_token)
+    check_response_message(response, "Invalid field 'foo'", 422)
+    data = {'location': 'ESS'}
+    response = patch(client, '/api/items/50', data=data, token=user_token)
+    check_response_message(response, "Item id '50' not found", 404)
+    response = patch(client, '/api/items/bar', data=data, token=user_token)
+    check_response_message(response, "Item id 'bar' not found", 404)
+
+    # Create an item
+    item1 = models.Item(serial_number='234567', ics_id='AAA001')
+    session.add(item1)
+    session.commit()
+
+    # check that we can't change the serial_number or ics_id
+    response = patch(client, f'/api/items/{item1.id}', data={'serial_number': '12345'}, token=user_token)
+    check_response_message(response, "Invalid field 'serial_number'", 422)
+    response = patch(client, f'/api/items/{item1.id}', data={'ics_id': 'AAA002'}, token=user_token)
+    check_response_message(response, "'ics_id' can't be changed", 422)
+
+
+def test_patch_item(client, session, user_token):
+    # Create some items
+    session.add(models.Status(name='Stock'))
+    session.add(models.Status(name='In service'))
+    item1 = models.Item(serial_number='123456')
+    item2 = models.Item(serial_number='234567', ics_id='AAA001', status='Stock')
+    for item in (item1, item2):
+        session.add(item)
+    session.commit()
+
+    # we can patch items by id...
+    data = {'ics_id': 'AAA004'}
+    response = patch(client, f'/api/items/{item1.id}', data=data, token=user_token)
+    assert response.status_code == 200
+    assert response.json['id'] == item1.id
+    assert response.json['serial_number'] == item1.serial_number
+    assert response.json['ics_id'] == data['ics_id']
+    # ...or ics_id
+    data = {'status': 'In service'}
+    response = patch(client, f'/api/items/{item2.ics_id}', data=data, token=user_token)
+    assert response.status_code == 200
+    assert response.json['id'] == item2.id
+    assert response.json['ics_id'] == item2.ics_id
+    assert response.json['serial_number'] == item2.serial_number
+    assert response.json['status'] == data['status']
+
+
+def test_patch_item_parent(client, session, user_token):
+    # Create some items
+    session.add(models.Manufacturer(name='HP'))
+    session.add(models.Manufacturer(name='Dell'))
+    session.add(models.Status(name='Stock'))
+    session.add(models.Status(name='In service'))
+    session.add(models.Location(name='ESS'))
+    session.add(models.Location(name='ICS lab'))
+    item1 = models.Item(serial_number='123456', ics_id='AAA001', location='ICS lab', status='In service',
+                        manufacturer='Dell')
+    item2 = models.Item(serial_number='234567', ics_id='AAA002', status='Stock')
+    item3 = models.Item(serial_number='345678', ics_id='AAA003', status='Stock', manufacturer='Dell')
+    for item in (item1, item2, item3):
+        session.add(item)
+    session.flush()
+    item3.parent_id = item1.id
+    session.commit()
+
+    # set parent changes the status and location
+    data1 = {'parent': item1.ics_id}
+    response = patch(client, f'/api/items/{item2.ics_id}', data=data1, token=user_token)
+    assert response.status_code == 200
+    assert response.json['id'] == item2.id
+    assert response.json['ics_id'] == item2.ics_id
+    assert response.json['serial_number'] == item2.serial_number
+    assert response.json['parent'] == item1.ics_id
+    assert response.json['status'] == str(item1.status)
+    assert response.json['location'] == str(item1.location)
+
+    # updating a parent, modifies the status and location of all children
+    # check location
+    data2 = {'location': 'ESS'}
+    response = patch(client, f'/api/items/{item1.ics_id}', data=data2, token=user_token)
+    assert response.status_code == 200
+    assert response.json['id'] == item1.id
+    assert response.json['ics_id'] == item1.ics_id
+    assert response.json['serial_number'] == item1.serial_number
+    assert response.json['status'] == str(item1.status)
+    assert response.json['location'] == data2['location']
+    for ics_id in ('AAA002', 'AAA003'):
+        response = get(client, f'/api/items/{ics_id}', token=user_token)
+        assert response.json['location'] == data2['location']
+        assert response.json['status'] == 'In service'
+    # check status
+    data3 = {'status': 'Stock'}
+    response = patch(client, f'/api/items/{item1.ics_id}', data=data3, token=user_token)
+    assert response.status_code == 200
+    assert response.json['status'] == data3['status']
+    for ics_id in ('AAA002', 'AAA003'):
+        response = get(client, f'/api/items/{ics_id}', token=user_token)
+        assert response.json['location'] == data2['location']
+        assert response.json['status'] == data3['status']
+
+    # manufacturer has no impact on children
+    data4 = {'manufacturer': 'HP'}
+    response = patch(client, f'/api/items/{item1.ics_id}', data=data4, token=user_token)
+    assert response.status_code == 200
+    assert response.json['manufacturer'] == 'HP'
+    # Manufacturer didn't change on children
+    response = get(client, f'/api/items/{item2.ics_id}', token=user_token)
+    assert response.json['manufacturer'] is None
+    response = get(client, f'/api/items/{item3.ics_id}', token=user_token)
+    assert response.json['manufacturer'] == 'Dell'
-- 
GitLab