From aeaab0a70bcc9bd41b6a3ec17360e09ef0c3cc1a Mon Sep 17 00:00:00 2001
From: Benjamin Bertrand <benjamin.bertrand@esss.se>
Date: Wed, 19 Jul 2017 14:13:28 +0200
Subject: [PATCH] Fix admin view

- restrict admin view to admin users only
- the models shall only take optional parameters for flask-admin to work
- the models shall implement the __str__ method to be displayed properly
  in the admin view
---
 app/admin/__init__.py |  0
 app/admin/views.py    | 27 +++++++++++++++++++++++++++
 app/factory.py        | 19 ++++++++++---------
 app/models.py         | 12 +++++++++---
 4 files changed, 46 insertions(+), 12 deletions(-)
 create mode 100644 app/admin/__init__.py
 create mode 100644 app/admin/views.py

diff --git a/app/admin/__init__.py b/app/admin/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/admin/views.py b/app/admin/views.py
new file mode 100644
index 0000000..465f8d6
--- /dev/null
+++ b/app/admin/views.py
@@ -0,0 +1,27 @@
+# -*- coding: utf-8 -*-
+"""
+app.admin.views
+~~~~~~~~~~~~~~~
+
+This module customizes the admin views.
+
+:copyright: (c) 2017 European Spallation Source ERIC
+:license: BSD 2-Clause, see LICENSE for more details.
+
+"""
+from flask_admin.contrib import sqla
+from flask_login import current_user
+
+
+class AdminModelView(sqla.ModelView):
+
+    def is_accessible(self):
+        return current_user.is_authenticated and current_user.is_admin
+
+
+# Here is an example to customize an admin view
+# class ItemAdmin(AdminModelView):
+#     form_columns = ['serial_number', 'vendor', 'model']
+#
+#     def __init__(self, session):
+#         super().__init__(Item, session)
diff --git a/app/factory.py b/app/factory.py
index 941db72..b2218cd 100644
--- a/app/factory.py
+++ b/app/factory.py
@@ -11,10 +11,10 @@ Create the WSGI application.
 """
 import sqlalchemy as sa
 from flask import Flask
-from flask_admin.contrib.sqla import ModelView
 from . import settings
 from .extensions import db, migrate, login_manager, ldap_manager, bootstrap, admin, mail, jwt
 from .models import User, Role, Action, Vendor, Model, Location, Status, Item
+from .admin.views import AdminModelView
 from .main.views import bp as main
 from .users.views import bp as users
 from .api.items import bp as api
@@ -88,14 +88,15 @@ def create_app():
     jwt.init_app(app)
 
     admin.init_app(app)
-    admin.add_view(ModelView(Role, db.session))
-    admin.add_view(ModelView(User, db.session))
-    admin.add_view(ModelView(Action, db.session))
-    admin.add_view(ModelView(Vendor, db.session))
-    admin.add_view(ModelView(Model, db.session))
-    admin.add_view(ModelView(Location, db.session))
-    admin.add_view(ModelView(Status, db.session))
-    admin.add_view(ModelView(Item, db.session))
+    admin.add_view(AdminModelView(Role, db.session))
+    admin.add_view(AdminModelView(User, db.session))
+    admin.add_view(AdminModelView(Action, db.session))
+    admin.add_view(AdminModelView(Vendor, db.session))
+    admin.add_view(AdminModelView(Model, db.session))
+    admin.add_view(AdminModelView(Location, db.session))
+    admin.add_view(AdminModelView(Status, db.session))
+    admin.add_view(AdminModelView(Item, db.session))
+    # admin.add_view(ItemAdmin(db.session))
 
     app.register_blueprint(main)
     app.register_blueprint(users)
diff --git a/app/models.py b/app/models.py
index 075ffdb..ece4168 100644
--- a/app/models.py
+++ b/app/models.py
@@ -132,6 +132,9 @@ class QRCodeMixin:
         data = ','.join([self.code, str(self.id), self.name])
         return qrcode.make(data, version=1, box_size=5)
 
+    def __str__(self):
+        return self.name
+
 
 class Action(QRCodeMixin, db.Model):
     code = 'AC'
@@ -161,8 +164,8 @@ class Item(db.Model):
     id = db.Column(db.Integer, primary_key=True)
     _created = db.Column(db.DateTime, default=db.func.now())
     _updated = db.Column(db.DateTime, default=db.func.now(), onupdate=db.func.now())
-    name = db.Column(db.String(100))
     serial_number = db.Column(db.String(100), nullable=False)
+    name = db.Column(db.String(100))
     hash = db.Column(GUID, unique=True)
     vendor_id = db.Column(db.Integer, db.ForeignKey('vendor.id'))
     model_id = db.Column(db.Integer, db.ForeignKey('model.id'))
@@ -176,11 +179,14 @@ class Item(db.Model):
     status = db.relationship('Status', back_populates='items')
     children = db.relationship('Item', backref=db.backref('parent', remote_side=[id]))
 
-    def __init__(self, name, serial_number, vendor, model, location, status):
-        self.name = name
+    def __init__(self, serial_number='', name=None, vendor=None, model=None, location=None, status=None):
         self.serial_number = serial_number
+        self.name = name
         self.vendor = vendor
         self.model = model
         self.location = location
         self.status = status
         self.hash = utils.compute_hash(serial_number)
+
+    def __str__(self):
+        return self.serial_number
-- 
GitLab