diff --git a/app/defaults.py b/app/defaults.py index afc5efafd42bf646cd3fb01d4e7d9f110bdeb545..737f7d689004860a7391c24051620c5ad790623f 100644 --- a/app/defaults.py +++ b/app/defaults.py @@ -21,8 +21,8 @@ defaults = [ models.Action(name='Set as parent'), models.Action(name='Update'), - models.DeviceType(name='Physical Machine'), - models.DeviceType(name='Virtual Machine'), + models.DeviceType(name='PhysicalMachine'), + models.DeviceType(name='VirtualMachine'), models.DeviceType(name='Network'), models.DeviceType(name='MicroTCA'), models.DeviceType(name='VME'), diff --git a/app/models.py b/app/models.py index ceb3e7771cb5948aef11f1f70fdf6bb21e8ecddf..dc1cf05fae08b5e92bd30e5f21312640e55ee7ec 100644 --- a/app/models.py +++ b/app/models.py @@ -23,7 +23,8 @@ from flask_login import UserMixin from wtforms import ValidationError from .extensions import db, login_manager, ldap_manager, cache from .plugins import FlaskUserPlugin -from .validators import ICS_ID_RE, HOST_NAME_RE, VLAN_NAME_RE, MAC_ADDRESS_RE +from .validators import (ICS_ID_RE, HOST_NAME_RE, VLAN_NAME_RE, MAC_ADDRESS_RE, + DEVICE_TYPE_RE, TAG_RE) from . import utils @@ -594,6 +595,14 @@ interfacetags_table = db.Table( class Tag(QRCodeMixin, db.Model): admin_only = db.Column(db.Boolean, nullable=False, default=False) + @validates('name') + def validate_name(self, key, string): + """Ensure the name field matches the required format""" + if string is not None: + if TAG_RE.fullmatch(string) is None: + raise ValidationError(f"'{string}' is an invalid tag name") + return string + class DeviceType(db.Model): __tablename__ = 'device_type' @@ -602,6 +611,14 @@ class DeviceType(db.Model): hosts = db.relationship('Host', backref='device_type') + @validates('name') + def validate_name(self, key, string): + """Ensure the name field matches the required format""" + if string is not None: + if DEVICE_TYPE_RE.fullmatch(string) is None: + raise ValidationError(f"'{string}' is an invalid device type name") + return string + def __str__(self): return self.name diff --git a/app/validators.py b/app/validators.py index c73c3d9db91b8154871521f98bc8b47e00e0939c..1228f508e7b05df452063e0d6e39585d1f633164 100644 --- a/app/validators.py +++ b/app/validators.py @@ -18,6 +18,8 @@ ICS_ID_RE = re.compile('[A-Z]{3}[0-9]{3}') HOST_NAME_RE = re.compile('^[a-z0-9\-]{2,20}$') VLAN_NAME_RE = re.compile('^[A-Za-z0-9\-]{3,25}$') MAC_ADDRESS_RE = re.compile('^(?:[0-9a-fA-F]{2}[:-]?){5}[0-9a-fA-F]{2}$') +DEVICE_TYPE_RE = re.compile('^[A-Za-z0-9]{3,25}$') +TAG_RE = DEVICE_TYPE_RE class NoValidateSelectField(SelectField): diff --git a/migrations/versions/f5a605c0c835_remove_spaces_from_device_type.py b/migrations/versions/f5a605c0c835_remove_spaces_from_device_type.py new file mode 100644 index 0000000000000000000000000000000000000000..6373968adcb544d4439d1e40961a5e02b2a27015 --- /dev/null +++ b/migrations/versions/f5a605c0c835_remove_spaces_from_device_type.py @@ -0,0 +1,28 @@ +"""remove spaces from device_type + +Revision ID: f5a605c0c835 +Revises: ea606be23b95 +Create Date: 2018-05-22 13:41:28.137611 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'f5a605c0c835' +down_revision = 'ea606be23b95' +branch_labels = None +depends_on = None + + +def upgrade(): + device_type = sa.sql.table('device_type', sa.sql.column('id'), sa.sql.column('name')) + op.execute(device_type.update().where(device_type.c.name == 'Physical Machine').values(name='PhysicalMachine')) + op.execute(device_type.update().where(device_type.c.name == 'Virtual Machine').values(name='VirtualMachine')) + + +def downgrade(): + device_type = sa.sql.table('device_type', sa.sql.column('id'), sa.sql.column('name')) + op.execute(device_type.update().where(device_type.c.name == 'PhysicalMachine').values(name='Physical Machine')) + op.execute(device_type.update().where(device_type.c.name == 'VirtualMachine').values(name='Virtual Machine')) diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py index 3b69217cbe699993ac4976fb2f22b8f47129099c..c0d6e2c6408e3fda9119f75e082c9dea70e27702 100644 --- a/tests/functional/conftest.py +++ b/tests/functional/conftest.py @@ -32,6 +32,7 @@ register(factories.HostFactory) register(factories.MacFactory) register(factories.DomainFactory) register(factories.CnameFactory) +register(factories.TagFactory) @pytest.fixture(scope='session') diff --git a/tests/functional/factories.py b/tests/functional/factories.py index 787fbdc866abd70d1265f0a957e11f928f1e83bf..50de310644c874e705cef68d788eae7fe9e3a499 100644 --- a/tests/functional/factories.py +++ b/tests/functional/factories.py @@ -189,3 +189,12 @@ class CnameFactory(factory.alchemy.SQLAlchemyModelFactory): name = factory.Sequence(lambda n: f'host{n}') interface = factory.SubFactory(InterfaceFactory) user = factory.SubFactory(UserFactory) + + +class TagFactory(factory.alchemy.SQLAlchemyModelFactory): + class Meta: + model = models.Tag + sqlalchemy_session = common.Session + sqlalchemy_session_persistence = 'commit' + + name = factory.Sequence(lambda n: f'Tag{n}') diff --git a/tests/functional/test_models.py b/tests/functional/test_models.py index bb84d7c1f01dd69cceec499a46270de061b7e04e..fd5063d27d0ebf29bc37fc5bd06a6ffa1ea99789 100644 --- a/tests/functional/test_models.py +++ b/tests/functional/test_models.py @@ -121,3 +121,19 @@ def test_manufacturer_favorite_users(user_factory, manufacturer_factory): assert user2 in manufacturer2.favorite_users assert user2 not in manufacturer1.favorite_users assert user3 in manufacturer1.favorite_users + + +def test_device_type_validation(device_type_factory): + device_type = device_type_factory(name='PhysicalMachine') + assert device_type.name == 'PhysicalMachine' + with pytest.raises(ValidationError) as excinfo: + device_type = device_type_factory(name='Physical Machine') + assert "'Physical Machine' is an invalid device type name" in str(excinfo.value) + + +def test_tag_validation(tag_factory): + tag = tag_factory(name='IOC') + assert tag.name == 'IOC' + with pytest.raises(ValidationError) as excinfo: + tag = tag_factory(name='My tag') + assert "'My tag' is an invalid tag name" in str(excinfo.value)