From 29c556bf4d3bae704635cb68711c76866be328d3 Mon Sep 17 00:00:00 2001
From: Benjamin Bertrand <benjamin.bertrand@esss.se>
Date: Tue, 19 Dec 2017 11:50:09 +0100
Subject: [PATCH] Add Unique validator

To check if the unique object is the one being edited, we need to store
the object used to create the form in the form (_obj).
This is inspired from flask-admin.
---
 app/admin/validators.py                | 32 --------------
 app/admin/views.py                     |  2 +-
 app/helpers.py                         | 16 +++++++
 app/inventory/forms.py                 | 12 ++---
 app/network/forms.py                   | 17 ++++---
 app/templates/inventory/edit_item.html |  2 +-
 app/tokens.py                          |  2 +-
 app/validators.py                      | 61 ++++++++++++++++++++++++++
 8 files changed, 98 insertions(+), 46 deletions(-)
 delete mode 100644 app/admin/validators.py
 create mode 100644 app/validators.py

diff --git a/app/admin/validators.py b/app/admin/validators.py
deleted file mode 100644
index 187b134..0000000
--- a/app/admin/validators.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-app.admin.validators
-~~~~~~~~~~~~~~~~~~~~
-
-This module defines field validators
-
-:copyright: (c) 2017 European Spallation Source ERIC
-:license: BSD 2-Clause, see LICENSE for more details.
-
-"""
-import ipaddress
-from wtforms import ValidationError
-
-
-class IPNetwork:
-    """
-    Validates an IP network.
-
-    :param message:
-        Error message to raise in case of a validation error.
-    """
-    def __init__(self, message=None):
-        self.message = message
-
-    def __call__(self, form, field):
-        try:
-            ipaddress.ip_network(field.data, strict=True)
-        except (ipaddress.AddressValueError, ipaddress.NetmaskValueError, ValueError):
-            if self.message is None:
-                self.message = field.gettext('Invalid IP network.')
-            raise ValidationError(self.message)
diff --git a/app/admin/views.py b/app/admin/views.py
index ba0d6aa..8ab0a16 100644
--- a/app/admin/views.py
+++ b/app/admin/views.py
@@ -13,7 +13,7 @@ from wtforms import validators, fields
 from flask_admin.contrib import sqla
 from flask_admin.model.form import converts
 from flask_login import current_user
-from .validators import IPNetwork
+from ..validators import IPNetwork
 from ..models import ICS_ID_RE
 
 
diff --git a/app/helpers.py b/app/helpers.py
index cd8bcce..3c4b20c 100644
--- a/app/helpers.py
+++ b/app/helpers.py
@@ -11,9 +11,25 @@ This module implements helpers functions for the models.
 """
 import string
 from flask import current_app
+from flask_wtf import FlaskForm
 from . import models
 
 
+class CSEntryForm(FlaskForm):
+
+    def __init__(self, formdata=None, obj=None, **kwargs):
+        # Store obj for Unique validator to check if the unique object
+        # is identical to the one being edited
+        self._obj = obj
+        # formdata is often given as first argument (not keyword argument)
+        # It's initialized as a singleton in flask-wtf so we should only pass
+        # it if not None
+        if formdata is None:
+            super().__init__(obj=obj, **kwargs)
+        else:
+            super().__init__(formdata=formdata, obj=obj, **kwargs)
+
+
 def temporary_ics_ids():
     """Generator that returns the full list of temporary ICS ids"""
     return (f'{current_app.config["TEMPORARY_ICS_ID"]}{letter}{number:0=3d}'
diff --git a/app/inventory/forms.py b/app/inventory/forms.py
index 286a519..67773e0 100644
--- a/app/inventory/forms.py
+++ b/app/inventory/forms.py
@@ -10,8 +10,9 @@ This module defines the inventory blueprint forms.
 
 """
 import re
-from flask_wtf import FlaskForm
 from wtforms import SelectField, StringField, TextAreaField, validators
+from ..helpers import CSEntryForm
+from ..validators import Unique
 from .. import utils, models
 
 MAC_ADDRESS_RE = re.compile('^(?:[0-9a-fA-F]{2}[:-]){5}[0-9a-fA-F]{2}$')
@@ -23,14 +24,15 @@ def check_mac_addresses_list(form, field):
             raise validators.ValidationError('Invalid Mac address')
 
 
-class AttributeForm(FlaskForm):
+class AttributeForm(CSEntryForm):
     name = StringField('name', validators=[validators.DataRequired()])
 
 
-class ItemForm(FlaskForm):
+class ItemForm(CSEntryForm):
     ics_id = StringField('ICS id',
                          validators=[validators.InputRequired(),
-                                     validators.Regexp(models.ICS_ID_RE)])
+                                     validators.Regexp(models.ICS_ID_RE),
+                                     Unique(models.Item, 'ics_id')])
     serial_number = StringField('Serial number',
                                 validators=[validators.InputRequired()])
     manufacturer_id = SelectField('Manufacturer', coerce=utils.coerce_to_str_or_none)
@@ -52,6 +54,6 @@ class ItemForm(FlaskForm):
         self.parent_id.choices = utils.get_model_choices(models.Item, allow_none=True, attr='ics_id')
 
 
-class CommentForm(FlaskForm):
+class CommentForm(CSEntryForm):
     text = TextAreaField('Enter your comment:',
                          validators=[validators.DataRequired()])
diff --git a/app/network/forms.py b/app/network/forms.py
index f3ba1e8..c78c52d 100644
--- a/app/network/forms.py
+++ b/app/network/forms.py
@@ -10,9 +10,10 @@ This module defines the network blueprint forms.
 
 """
 from flask_login import current_user
-from flask_wtf import FlaskForm
 from wtforms import (SelectField, StringField, TextAreaField,
                      SelectMultipleField, BooleanField, validators)
+from ..helpers import CSEntryForm
+from ..validators import Unique
 from .. import utils, models
 
 
@@ -39,12 +40,13 @@ class NoValidateSelectField(SelectField):
         pass
 
 
-class NetworkForm(FlaskForm):
+class NetworkForm(CSEntryForm):
     scope_id = SelectField('Network Scope')
     vlan_name = StringField('Vlan name',
                             description='hostname must be 3-25 characters long and contain only letters, numbers and dash',
                             validators=[validators.InputRequired(),
-                                        validators.Regexp(models.VLAN_NAME_RE)])
+                                        validators.Regexp(models.VLAN_NAME_RE),
+                                        Unique(models.Network, column='vlan_name')])
     vlan_id = NoValidateSelectField('Vlan id', choices=[])
     description = TextAreaField('Description')
     prefix = NoValidateSelectField('Prefix', choices=[])
@@ -58,11 +60,12 @@ class NetworkForm(FlaskForm):
         self.scope_id.choices = utils.get_model_choices(models.NetworkScope, attr='name')
 
 
-class HostForm(FlaskForm):
+class HostForm(CSEntryForm):
     name = StringField('Hostname',
                        description='hostname must be 2-20 characters long and contain only letters, numbers and dash',
                        validators=[validators.InputRequired(),
-                                   validators.Regexp(models.HOST_NAME_RE)],
+                                   validators.Regexp(models.HOST_NAME_RE),
+                                   Unique(models.Host)],
                        filters=[utils.lowercase_field])
     type = SelectField('Type', choices=utils.get_choices(('Virtual', 'Physical')))
     description = TextAreaField('Description')
@@ -74,7 +77,9 @@ class HostForm(FlaskForm):
     interface_name = StringField(
         'Interface name',
         description='name must be 2-20 characters long and contain only letters, numbers and dash',
-        validators=[validators.InputRequired(), validators.Regexp(models.HOST_NAME_RE)],
+        validators=[validators.InputRequired(),
+                    validators.Regexp(models.HOST_NAME_RE),
+                    Unique(models.Interface)],
         filters=[utils.lowercase_field])
     mac_id = SelectField('MAC', coerce=utils.coerce_to_str_or_none)
     tags = SelectMultipleField('Tags', coerce=utils.coerce_to_str_or_none,
diff --git a/app/templates/inventory/edit_item.html b/app/templates/inventory/edit_item.html
index 2e177fd..1070849 100644
--- a/app/templates/inventory/edit_item.html
+++ b/app/templates/inventory/edit_item.html
@@ -16,7 +16,7 @@
 {% block items_main %}
   <form id="itemForm" method="POST">
     {{ form.hidden_tag() }}
-    {% if form.ics_id.data.startswith(config['TEMPORARY_ICS_ID']) %}
+    {% if form.ics_id.data.startswith(config['TEMPORARY_ICS_ID'])  or form.ics_id.errors %}
       {{ render_field(form.ics_id) }}
     {% else %}
       {{ render_field(form.ics_id, readonly=True) }}
diff --git a/app/tokens.py b/app/tokens.py
index ef1650f..8dbb2d2 100644
--- a/app/tokens.py
+++ b/app/tokens.py
@@ -37,7 +37,7 @@ def is_token_in_blacklist(decoded_token):
     jti = decoded_token['jti']
     try:
         models.Token.query.filter_by(jti=jti).one()
-    except sa.exc.NoResultFound:
+    except sa.orm.exc.NoResultFound:
         return True
     return False
 
diff --git a/app/validators.py b/app/validators.py
new file mode 100644
index 0000000..3beb4db
--- /dev/null
+++ b/app/validators.py
@@ -0,0 +1,61 @@
+# -*- coding: utf-8 -*-
+"""
+app.validators
+~~~~~~~~~~~~~~
+
+This module defines extra field validators
+
+:copyright: (c) 2017 European Spallation Source ERIC
+:license: BSD 2-Clause, see LICENSE for more details.
+
+"""
+import ipaddress
+import sqlalchemy as sa
+from wtforms import ValidationError
+
+
+class IPNetwork:
+    """Validates an IP network.
+
+    :param message: the error message to raise in case of a validation error
+    """
+    def __init__(self, message=None):
+        self.message = message
+
+    def __call__(self, form, field):
+        try:
+            ipaddress.ip_network(field.data, strict=True)
+        except (ipaddress.AddressValueError, ipaddress.NetmaskValueError, ValueError):
+            if self.message is None:
+                self.message = field.gettext('Invalid IP network.')
+            raise ValidationError(self.message)
+
+
+# Inspired by flask-admin Unique validator
+# Modified to use flask-sqlalchemy query on Model
+class Unique(object):
+    """Checks field value unicity against specified table field
+
+    :param model: the model to check unicity against
+    :param column: the unique column
+    :param message: the error message
+    """
+
+    def __init__(self, model, column='name', message=None):
+        self.model = model
+        self.column = column
+        self.message = message
+
+    def __call__(self, form, field):
+        # databases allow multiple NULL values for unique columns
+        if field.data is None:
+            return
+        try:
+            kwargs = {self.column: field.data}
+            obj = self.model.query.filter_by(**kwargs).one()
+            if not hasattr(form, '_obj') or not form._obj == obj:
+                if self.message is None:
+                    self.message = field.gettext('Already exists.')
+                raise ValidationError(self.message)
+        except sa.orm.exc.NoResultFound:
+            pass
-- 
GitLab