diff --git a/app/models.py b/app/models.py index 51de7bdc017ff6dd838ee8d6a4cdb4daedf19257..0710b0d6a5874e481a37a0618027ec97fc7e6014 100644 --- a/app/models.py +++ b/app/models.py @@ -59,7 +59,7 @@ def save_user(dn, username, data, memberships): """ user = User.query.filter_by(username=username).first() if user is None: - user = User(username, + user = User(username=username, name=utils.attribute_to_string(data['cn']), email=utils.attribute_to_string(data['mail'])) # Always update the user groups to keep them up-to-date @@ -81,9 +81,6 @@ class Group(db.Model): id = db.Column(db.Integer, primary_key=True) name = db.Column(db.Text, nullable=False, unique=True) - def __init__(self, name): - self.name = name - def __str__(self): return self.name @@ -110,11 +107,6 @@ class User(db.Model, UserMixin): groups = association_proxy('grp', 'name', creator=find_or_create_group) - def __init__(self, username, name, email): - self.username = username - self.name = name - self.email = email - def get_id(self): """Return the user id as unicode @@ -144,9 +136,6 @@ class QRCodeMixin: id = db.Column(db.Integer, primary_key=True) name = db.Column(CIText, nullable=False, unique=True) - def __init__(self, name=None): - self.name = name - def image(self): """Return a QRCode image to identify a record @@ -219,14 +208,16 @@ class Item(db.Model): children = db.relationship('Item', backref=db.backref('parent', remote_side=[id])) macs = db.relationship('Mac', backref='item') - def __init__(self, ics_id=None, serial_number=None, manufacturer=None, model=None, location=None, status=None): - # All arguments must be optional for this class to work with flask-admin! - self.ics_id = ics_id - self.serial_number = serial_number - self.manufacturer = utils.convert_to_model(manufacturer, Manufacturer) - self.model = utils.convert_to_model(model, Model) - self.location = utils.convert_to_model(location, Location) - self.status = utils.convert_to_model(status, Status) + def __init__(self, **kwargs): + # Automatically convert manufacturer/model/location/status to an + # instance of their class if passed as a string + for key, cls in [('manufacturer', Manufacturer), + ('model', Model), + ('location', Location), + ('status', Status)]: + if key in kwargs: + kwargs[key] = utils.convert_to_model(kwargs[key], cls) + super().__init__(**kwargs) def __str__(self): return str(self.ics_id) diff --git a/tests/functional/test_api.py b/tests/functional/test_api.py index d033b9c2437925c780be74750bcd40564334394f..fb755d29314d2f55a57d1c0405d6545b64f7c627 100644 --- a/tests/functional/test_api.py +++ b/tests/functional/test_api.py @@ -188,8 +188,9 @@ def test_create_generic_model(endpoint, client, user_token): @pytest.mark.parametrize('endpoint', GENERIC_CREATE_ENDPOINTS) def test_create_generic_model_invalid_param(endpoint, client, user_token): + model = ENDPOINT_MODEL[endpoint] response = post(client, f'/api/{endpoint}', data={'name': 'foo', 'hello': 'world'}, token=user_token) - check_response_message(response, "unexpected keyword argument 'hello'", 422) + check_response_message(response, f"'hello' is an invalid keyword argument for {model.__name__}", 422) def test_create_item(client, user_token):