From 738fd37180be50392bac14a508f97c52f2b8cac2 Mon Sep 17 00:00:00 2001 From: Benjamin Bertrand <benjamin.bertrand@esss.se> Date: Wed, 6 Sep 2017 16:24:34 +0200 Subject: [PATCH] Improve API (patch and get item) - add more tests - patching an item update children location and status - setting the parent on an item updates the location and status with the parent ones --- app/api/main.py | 72 +++++++++++- app/models.py | 1 + tests/functional/test_api.py | 210 ++++++++++++++++++++++++++++++++++- 3 files changed, 274 insertions(+), 9 deletions(-) diff --git a/app/api/main.py b/app/api/main.py index 9bdf0b5..316a011 100644 --- a/app/api/main.py +++ b/app/api/main.py @@ -21,8 +21,21 @@ from ..decorators import jwt_groups_accepted bp = Blueprint('api', __name__) +def get_item_by_id_or_ics_id(id_): + """Retrieve item by id or ICS id""" + try: + item_id = int(id_) + except ValueError: + # Assume id_ is an ics_id + item = Item.query.filter_by(ics_id=id_).first() + else: + item = Item.query.get(item_id) + if item is None: + raise utils.InventoryError(f"Item id '{id_}' not found", status_code=404) + return item + + def get_generic_model(model): - # TODO: add pagination items = model.query.order_by(model.name) data = [item.to_dict() for item in items] return jsonify(data) @@ -81,27 +94,74 @@ def get_items(): return jsonify(data) +@bp.route('/items/<id_>') +@jwt_required +def get_item(id_): + """Retrieve item by id or ICS id""" + item = get_item_by_id_or_ics_id(id_) + return jsonify(item.to_dict()) + + @bp.route('/items', methods=['POST']) @jwt_required @jwt_groups_accepted('admin', 'create') def create_item(): + """Register a new item""" + # People should assign an ICS id to a serial number when creating + # an item so ics_id should also be a mandatory field. + # But there are existing items (in confluence and JIRA) that we want to + # import and associate after they have been created. return create_generic_model(Item, mandatory_field='serial_number') -@bp.route('/items/<item_id>', methods=['PATCH']) +@bp.route('/items/<id_>', methods=['PATCH']) @jwt_required @jwt_groups_accepted('admin', 'create') -def patch_item(item_id): +def patch_item(id_): + """Patch an existing item + + id_ can be the primary key or the ics_id field + Fields allowed to update are: + - ics_id ONLY if currently null (422 returned otherwise) + - manufacturer + - model + - location + - status + - parent + + 422 is returned if other fields are given. + """ data = request.get_json() if data is None: raise utils.InventoryError('Body should be a JSON object') - item = Item.query.get(item_id) - if item is None: - raise utils.InventoryError(f'Unknown item id: {item_id}', status_code=422) + if not data: + raise utils.InventoryError('At least one field is required', status_code=422) + for key in data.keys(): + if key not in ('ics_id', 'manufacturer', 'model', + 'location', 'status', 'parent'): + raise utils.InventoryError(f"Invalid field '{key}'", status_code=422) + item = get_item_by_id_or_ics_id(id_) + # Only allow to set ICS id if it's null + if item.ics_id is None: + item.ics_id = data.get('ics_id') + elif 'ics_id' in data: + raise utils.InventoryError("'ics_id' can't be changed", status_code=422) item.manufacturer = utils.convert_to_model(data.get('manufacturer', item.manufacturer), Manufacturer) item.model = utils.convert_to_model(data.get('model', item.model), Model) item.location = utils.convert_to_model(data.get('location', item.location), Location) item.status = utils.convert_to_model(data.get('status', item.status), Status) + parent_ics_id = data.get('parent') + if parent_ics_id is not None: + parent = Item.query.filter_by(ics_id=parent_ics_id).first() + if parent is not None: + item.parent_id = parent.id + # Update location and status with those from parent + item.location = parent.location + item.status = parent.status + # Update all children status and location + for child in item.children: + child.location = item.location + child.status = item.status db.session.commit() return jsonify(item.to_dict()) diff --git a/app/models.py b/app/models.py index 8e74079..86a6916 100644 --- a/app/models.py +++ b/app/models.py @@ -238,6 +238,7 @@ class Item(db.Model): 'status': utils.format_field(self.status), 'updated': utils.format_field(self._updated), 'created': utils.format_field(self._created), + 'parent': utils.format_field(self.parent), } diff --git a/tests/functional/test_api.py b/tests/functional/test_api.py index db395da..088eece 100644 --- a/tests/functional/test_api.py +++ b/tests/functional/test_api.py @@ -19,7 +19,9 @@ ENDPOINT_MODEL = { 'models': models.Model, 'locations': models.Location, 'status': models.Status, + 'items': models.Item, } +GENERIC_ENDPOINTS = [key for key in ENDPOINT_MODEL.keys() if key != 'items'] ENDPOINTS = list(ENDPOINT_MODEL.keys()) @@ -44,6 +46,16 @@ def post(client, url, data, token=None): return response +def patch(client, url, data, token=None): + headers = {'Content-Type': 'application/json'} + if token is not None: + headers['Authorization'] = f'Bearer {token}' + response = client.patch(url, data=json.dumps(data), headers=headers) + if response.headers['Content-Type'] == 'application/json': + response.json = json.loads(response.data) + return response + + def login(client, username, password): data = { 'username': username, @@ -92,6 +104,13 @@ def check_names(response, names): assert set(names) == response_names +def check_items(response, inputs): + # Sort the response by id to match the inputs order + response_items = sorted(response.json, key=lambda d: d['id']) + for d1, d2 in zip(inputs, response_items): + assert set(d1.items()).issubset(set(d2.items())) + + def test_login(client): response = client.post('/api/login') check_response_message(response, 'Body should be a JSON object') @@ -104,7 +123,7 @@ def test_login(client): assert 'access_token' in response.json -@pytest.mark.parametrize('endpoint', ENDPOINTS) +@pytest.mark.parametrize('endpoint', GENERIC_ENDPOINTS) def test_get_generic_model(endpoint, session, client, readonly_token): model = ENDPOINT_MODEL[endpoint] names = ('Foo', 'Bar', 'Alice') @@ -120,7 +139,7 @@ def test_get_generic_model(endpoint, session, client, readonly_token): @pytest.mark.parametrize('endpoint', ENDPOINTS) -def test_create_generic_model_fail(endpoint, client, readonly_token): +def test_create_model_auth_fail(endpoint, client, readonly_token): response = client.post(f'/api/{endpoint}') check_response_message(response, 'Missing Authorization Header', 401) response = post(client, f'/api/{endpoint}', data={}, token='xxxxxxxxx') @@ -131,7 +150,7 @@ def test_create_generic_model_fail(endpoint, client, readonly_token): assert model.query.count() == 0 -@pytest.mark.parametrize('endpoint', ENDPOINTS) +@pytest.mark.parametrize('endpoint', GENERIC_ENDPOINTS) def test_create_generic_model(endpoint, client, user_token): response = post(client, f'/api/{endpoint}', data={}, token=user_token) check_response_message(response, "Missing mandatory field 'name'", 422) @@ -153,3 +172,188 @@ def test_create_generic_model(endpoint, client, user_token): assert model.query.count() == 2 response = get(client, f'/api/{endpoint}', user_token) check_names(response, ('Foo', 'Bar')) + + +def test_create_item(client, user_token): + # check that serial_number is mandatory + response = post(client, '/api/items', data={}, token=user_token) + check_response_message(response, "Missing mandatory field 'serial_number'", 422) + + # check create with only serial_number + data = {'serial_number': '123456'} + response = post(client, '/api/items', data=data, token=user_token) + assert response.status_code == 201 + assert {'id', 'ics_id', 'serial_number', 'manufacturer', 'model', + 'location', 'status', 'updated', 'created', 'parent'} == set(response.json.keys()) + assert response.json['serial_number'] == '123456' + + # Check that serial_number doesn't have to be unique + response = post(client, '/api/items', data=data, token=user_token) + assert response.status_code == 201 + + # check that ics_id shall be unique + data2 = {'serial_number': '456789', 'ics_id': 'AAA001'} + response = post(client, '/api/items', data=data2, token=user_token) + assert response.status_code == 201 + response = post(client, '/api/items', data=data2, token=user_token) + check_response_message(response, 'IntegrityError', 409) + + # check all items that were created + assert models.Item.query.count() == 3 + response = get(client, '/api/items', user_token) + check_items(response, (data, data, data2)) + + +def test_get_item_fail(client, session, readonly_token): + response = get(client, '/api/items/50', token=readonly_token) + check_response_message(response, "Item id '50' not found", 404) + response = get(client, '/api/items/bar', token=readonly_token) + check_response_message(response, "Item id 'bar' not found", 404) + + +def test_get_item(client, session, readonly_token): + # Create some items + session.add(models.Status(name='Stock')) + item1 = models.Item(serial_number='123456') + item2 = models.Item(serial_number='234567', ics_id='AAA001', status='Stock') + for item in (item1, item2): + session.add(item) + session.commit() + + # we can get items by id... + response = get(client, f'/api/items/{item1.id}', token=readonly_token) + assert response.status_code == 200 + assert response.json['id'] == item1.id + assert response.json['serial_number'] == item1.serial_number + # ...or ics_id + response = get(client, f'/api/items/{item2.ics_id}', token=readonly_token) + assert response.status_code == 200 + assert response.json['id'] == item2.id + assert response.json['ics_id'] == item2.ics_id + assert response.json['serial_number'] == item2.serial_number + assert response.json['status'] == str(item2.status) + + +def test_patch_item_auth_fail(client, session, readonly_token): + response = client.patch('/api/items/50') + check_response_message(response, 'Missing Authorization Header', 401) + response = patch(client, '/api/items/50', data={}, token='xxxxxxxxx') + check_response_message(response, 'Not enough segments', 422) + response = patch(client, '/api/items/50', data={}, token=readonly_token) + check_response_message(response, "User doesn't have the required group", 403) + + +def test_patch_item_fail(client, session, user_token): + response = patch(client, '/api/items/50', data={}, token=user_token) + check_response_message(response, 'At least one field is required', 422) + data = {'location': 'ESS', 'foo': 'bar'} + response = patch(client, '/api/items/50', data=data, token=user_token) + check_response_message(response, "Invalid field 'foo'", 422) + data = {'location': 'ESS'} + response = patch(client, '/api/items/50', data=data, token=user_token) + check_response_message(response, "Item id '50' not found", 404) + response = patch(client, '/api/items/bar', data=data, token=user_token) + check_response_message(response, "Item id 'bar' not found", 404) + + # Create an item + item1 = models.Item(serial_number='234567', ics_id='AAA001') + session.add(item1) + session.commit() + + # check that we can't change the serial_number or ics_id + response = patch(client, f'/api/items/{item1.id}', data={'serial_number': '12345'}, token=user_token) + check_response_message(response, "Invalid field 'serial_number'", 422) + response = patch(client, f'/api/items/{item1.id}', data={'ics_id': 'AAA002'}, token=user_token) + check_response_message(response, "'ics_id' can't be changed", 422) + + +def test_patch_item(client, session, user_token): + # Create some items + session.add(models.Status(name='Stock')) + session.add(models.Status(name='In service')) + item1 = models.Item(serial_number='123456') + item2 = models.Item(serial_number='234567', ics_id='AAA001', status='Stock') + for item in (item1, item2): + session.add(item) + session.commit() + + # we can patch items by id... + data = {'ics_id': 'AAA004'} + response = patch(client, f'/api/items/{item1.id}', data=data, token=user_token) + assert response.status_code == 200 + assert response.json['id'] == item1.id + assert response.json['serial_number'] == item1.serial_number + assert response.json['ics_id'] == data['ics_id'] + # ...or ics_id + data = {'status': 'In service'} + response = patch(client, f'/api/items/{item2.ics_id}', data=data, token=user_token) + assert response.status_code == 200 + assert response.json['id'] == item2.id + assert response.json['ics_id'] == item2.ics_id + assert response.json['serial_number'] == item2.serial_number + assert response.json['status'] == data['status'] + + +def test_patch_item_parent(client, session, user_token): + # Create some items + session.add(models.Manufacturer(name='HP')) + session.add(models.Manufacturer(name='Dell')) + session.add(models.Status(name='Stock')) + session.add(models.Status(name='In service')) + session.add(models.Location(name='ESS')) + session.add(models.Location(name='ICS lab')) + item1 = models.Item(serial_number='123456', ics_id='AAA001', location='ICS lab', status='In service', + manufacturer='Dell') + item2 = models.Item(serial_number='234567', ics_id='AAA002', status='Stock') + item3 = models.Item(serial_number='345678', ics_id='AAA003', status='Stock', manufacturer='Dell') + for item in (item1, item2, item3): + session.add(item) + session.flush() + item3.parent_id = item1.id + session.commit() + + # set parent changes the status and location + data1 = {'parent': item1.ics_id} + response = patch(client, f'/api/items/{item2.ics_id}', data=data1, token=user_token) + assert response.status_code == 200 + assert response.json['id'] == item2.id + assert response.json['ics_id'] == item2.ics_id + assert response.json['serial_number'] == item2.serial_number + assert response.json['parent'] == item1.ics_id + assert response.json['status'] == str(item1.status) + assert response.json['location'] == str(item1.location) + + # updating a parent, modifies the status and location of all children + # check location + data2 = {'location': 'ESS'} + response = patch(client, f'/api/items/{item1.ics_id}', data=data2, token=user_token) + assert response.status_code == 200 + assert response.json['id'] == item1.id + assert response.json['ics_id'] == item1.ics_id + assert response.json['serial_number'] == item1.serial_number + assert response.json['status'] == str(item1.status) + assert response.json['location'] == data2['location'] + for ics_id in ('AAA002', 'AAA003'): + response = get(client, f'/api/items/{ics_id}', token=user_token) + assert response.json['location'] == data2['location'] + assert response.json['status'] == 'In service' + # check status + data3 = {'status': 'Stock'} + response = patch(client, f'/api/items/{item1.ics_id}', data=data3, token=user_token) + assert response.status_code == 200 + assert response.json['status'] == data3['status'] + for ics_id in ('AAA002', 'AAA003'): + response = get(client, f'/api/items/{ics_id}', token=user_token) + assert response.json['location'] == data2['location'] + assert response.json['status'] == data3['status'] + + # manufacturer has no impact on children + data4 = {'manufacturer': 'HP'} + response = patch(client, f'/api/items/{item1.ics_id}', data=data4, token=user_token) + assert response.status_code == 200 + assert response.json['manufacturer'] == 'HP' + # Manufacturer didn't change on children + response = get(client, f'/api/items/{item2.ics_id}', token=user_token) + assert response.json['manufacturer'] is None + response = get(client, f'/api/items/{item3.ics_id}', token=user_token) + assert response.json['manufacturer'] == 'Dell' -- GitLab