From 9295df9b3503b3bd46921bb15e68f70038656c3a Mon Sep 17 00:00:00 2001
From: Benjamin Bertrand <benjamin.bertrand@esss.se>
Date: Wed, 20 Dec 2017 09:16:57 +0100
Subject: [PATCH] Add created_at/updated_at/user_id columns to tables

- This will help tracking who and what was created / changed
- Move fetch_current_user_id to utils.py
- Make ItemComment inherit from CreatedMixin
---
 app/inventory/views.py                 |  9 ++---
 app/models.py                          | 53 +++++++++++++++-----------
 app/plugins.py                         | 24 ++----------
 app/templates/inventory/view_item.html |  8 ++--
 app/utils.py                           | 16 ++++++++
 5 files changed, 59 insertions(+), 51 deletions(-)

diff --git a/app/inventory/views.py b/app/inventory/views.py
index 8f90f63..fd4b466 100644
--- a/app/inventory/views.py
+++ b/app/inventory/views.py
@@ -12,7 +12,7 @@ This module implements the inventory blueprint.
 import sqlalchemy as sa
 from flask import (Blueprint, render_template, jsonify, session,
                    request, redirect, url_for, flash, current_app)
-from flask_login import login_required, current_user
+from flask_login import login_required
 from .forms import AttributeForm, ItemForm, CommentForm
 from ..extensions import db
 from ..decorators import login_groups_accepted
@@ -24,11 +24,11 @@ bp = Blueprint('inventory', __name__)
 @bp.route('/_retrieve_items')
 @login_required
 def retrieve_items():
-    items = models.Item.query.order_by(models.Item._created)
+    items = models.Item.query.order_by(models.Item.created_at)
     data = [[item.id,
              item.ics_id,
-             utils.format_field(item._created),
-             utils.format_field(item._updated),
+             utils.format_field(item.created_at),
+             utils.format_field(item.updated_at),
              item.serial_number,
              utils.format_field(item.manufacturer),
              utils.format_field(item.model),
@@ -92,7 +92,6 @@ def comment_item(ics_id):
     form = CommentForm()
     if form.validate_on_submit():
         comment = models.ItemComment(text=form.text.data,
-                                     user_id=current_user.id,
                                      item_id=item.id)
         db.session.add(comment)
         db.session.commit()
diff --git a/app/models.py b/app/models.py
index 3747083..39dc16f 100644
--- a/app/models.py
+++ b/app/models.py
@@ -12,6 +12,7 @@ This module implements the models used in the app.
 import ipaddress
 import qrcode
 import sqlalchemy as sa
+from sqlalchemy.ext.declarative import declared_attr
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.orm import validates
 from sqlalchemy.ext.associationproxy import association_proxy
@@ -106,7 +107,6 @@ class User(db.Model, UserMixin):
     groups = association_proxy('grp', 'name',
                                creator=find_or_create_group)
     tokens = db.relationship("Token", backref="user")
-    comments = db.relationship('ItemComment', backref='user')
 
     def get_id(self):
         """Return the user id as unicode
@@ -221,15 +221,32 @@ class Status(QRCodeMixin, db.Model):
     items = db.relationship('Item', back_populates='status')
 
 
-class Item(db.Model):
+class CreatedMixin:
+    id = db.Column(db.Integer, primary_key=True)
+    created_at = db.Column(db.DateTime, default=utcnow())
+    updated_at = db.Column(db.DateTime, default=utcnow(), onupdate=utcnow())
+
+    # Using ForeignKey and relationship in mixin requires the @declared_attr decorator
+    # See http://docs.sqlalchemy.org/en/latest/orm/extensions/declarative/mixins.html
+    @declared_attr
+    def user_id(cls):
+        return db.Column(db.Integer, db.ForeignKey('user_account.id'),
+                         nullable=False, default=utils.fetch_current_user_id)
+
+    @declared_attr
+    def user(cls):
+        return db.relationship('User')
+
+
+class Item(CreatedMixin, db.Model):
     __versioned__ = {
-        'exclude': ['_created', 'ics_id', 'serial_number',
+        'exclude': ['created_at', 'ics_id', 'serial_number',
                     'manufacturer_id', 'model_id']
     }
 
+    # WARNING! Inheriting id from CreatedMixin doesn't play well with
+    # SQLAlchemy-Continuum. It has to be defined here.
     id = db.Column(db.Integer, primary_key=True)
-    _created = db.Column(db.DateTime, default=utcnow())
-    _updated = db.Column(db.DateTime, default=utcnow(), onupdate=utcnow())
     ics_id = db.Column(db.Text, unique=True, nullable=False, index=True)
     serial_number = db.Column(db.Text, nullable=False)
     manufacturer_id = db.Column(db.Integer, db.ForeignKey('manufacturer.id'))
@@ -278,8 +295,8 @@ class Item(db.Model):
             'model': utils.format_field(self.model),
             'location': utils.format_field(self.location),
             'status': utils.format_field(self.status),
-            'updated': utils.format_field(self._updated),
-            'created': utils.format_field(self._created),
+            'updated_at': utils.format_field(self.updated_at),
+            'created_at': utils.format_field(self.created_at),
             'parent': utils.format_field(self.parent),
         }
         if long:
@@ -300,7 +317,7 @@ class Item(db.Model):
             else:
                 parent = Item.query.get(version.parent_id)
             versions.append({
-                'updated': utils.format_field(version._updated),
+                'updated_at': utils.format_field(version.updated_at),
                 'location': utils.format_field(version.location),
                 'status': utils.format_field(version.status),
                 'parent': utils.format_field(parent),
@@ -308,16 +325,12 @@ class Item(db.Model):
         return versions
 
 
-class ItemComment(db.Model):
-    id = db.Column(db.Integer, primary_key=True)
-    timestamp = db.Column(db.DateTime, default=utcnow(), index=True)
+class ItemComment(CreatedMixin, db.Model):
     text = db.Column(db.Text, nullable=False)
-    user_id = db.Column(db.Integer, db.ForeignKey('user_account.id'), nullable=False)
     item_id = db.Column(db.Integer, db.ForeignKey('item.id'), nullable=False)
 
 
-class Network(db.Model):
-    id = db.Column(db.Integer, primary_key=True)
+class Network(CreatedMixin, db.Model):
     vlan_name = db.Column(CIText, nullable=False, unique=True)
     vlan_id = db.Column(db.Integer, nullable=False, unique=True)
     address = db.Column(postgresql.CIDR, nullable=False, unique=True)
@@ -455,8 +468,7 @@ class Tag(QRCodeMixin, db.Model):
     pass
 
 
-class Host(db.Model):
-    id = db.Column(db.Integer, primary_key=True)
+class Host(CreatedMixin, db.Model):
     name = db.Column(db.Text, nullable=False, unique=True)
     type = db.Column(db.Text)
     description = db.Column(db.Text)
@@ -479,8 +491,7 @@ class Host(db.Model):
         return lower_string
 
 
-class Interface(db.Model):
-    id = db.Column(db.Integer, primary_key=True)
+class Interface(CreatedMixin, db.Model):
     network_id = db.Column(db.Integer, db.ForeignKey('network.id'), nullable=False)
     ip = db.Column(postgresql.INET, nullable=False, unique=True)
     name = db.Column(db.Text, nullable=False, unique=True)
@@ -565,8 +576,7 @@ class Mac(db.Model):
         return d
 
 
-class Cname(db.Model):
-    id = db.Column(db.Integer, primary_key=True)
+class Cname(CreatedMixin, db.Model):
     name = db.Column(db.Text, nullable=False, unique=True)
     interface_id = db.Column(db.Integer, db.ForeignKey('interface.id'), nullable=False)
 
@@ -581,9 +591,8 @@ class Cname(db.Model):
         }
 
 
-class NetworkScope(db.Model):
+class NetworkScope(CreatedMixin, db.Model):
     __tablename__ = 'network_scope'
-    id = db.Column(db.Integer, primary_key=True)
     name = db.Column(CIText, nullable=False, unique=True)
     first_vlan = db.Column(db.Integer, nullable=False, unique=True)
     last_vlan = db.Column(db.Integer, nullable=False, unique=True)
diff --git a/app/plugins.py b/app/plugins.py
index d51cd96..fb09e7e 100644
--- a/app/plugins.py
+++ b/app/plugins.py
@@ -20,31 +20,15 @@ The `user_id` column is automatically populated when the transaction object is c
 
 """
 
-from flask.globals import _app_ctx_stack, _request_ctx_stack
 from sqlalchemy_continuum.plugins import Plugin
-from flask_login import current_user
-from flask_jwt_extended import get_current_user
-
-
-def fetch_current_user_id():
-    # Return None if we are outside of request context.
-    if _app_ctx_stack.top is None or _request_ctx_stack.top is None:
-        return None
-    # Try to get the user from both flask_jwt_extended and flask_login
-    user = get_current_user() or current_user
-    try:
-        return user.id
-    except AttributeError:
-        return None
+from . import utils
 
 
 class FlaskUserPlugin(Plugin):
-    def __init__(
-        self,
-        current_user_id_factory=None,
-    ):
+
+    def __init__(self, current_user_id_factory=None):
         self.current_user_id_factory = (
-            fetch_current_user_id if current_user_id_factory is None
+            utils.fetch_current_user_id if current_user_id_factory is None
             else current_user_id_factory
         )
 
diff --git a/app/templates/inventory/view_item.html b/app/templates/inventory/view_item.html
index e489b41..6c1eb58 100644
--- a/app/templates/inventory/view_item.html
+++ b/app/templates/inventory/view_item.html
@@ -17,9 +17,9 @@
     <dt class="col-sm-3">ICS id</dt>
     <dd class="col-sm-9">{{ item.ics_id }}</dd>
     <dt class="col-sm-3">Created</dt>
-    <dd class="col-sm-9">{{ format_datetime(item._created) }}</dd>
+    <dd class="col-sm-9">{{ format_datetime(item.created_at) }}</dd>
     <dt class="col-sm-3">Updated</dt>
-    <dd class="col-sm-9">{{ format_datetime(item._updated) }}</dd>
+    <dd class="col-sm-9">{{ format_datetime(item.updated_at) }}</dd>
     <dt class="col-sm-3">Serial number</dt>
     <dd class="col-sm-9">{{ item.serial_number }}</dd>
     <dt class="col-sm-3">Manufacturer</dt>
@@ -49,7 +49,7 @@
   {% for comment in item.comments %}
   <div class="card border-light mb-3">
     <div class="card-header">
-      {{ comment.user }} commented on {{ format_datetime(comment.timestamp) }}
+      {{ comment.user }} commented on {{ format_datetime(comment.created_at) }}
     </div>
     <div class="card-body item-comment">{{ comment.text }}</div>
   </div>
@@ -73,7 +73,7 @@
     <tbody>
       {% for version in item.history() %}
       <tr>
-        <td>{{ version['updated'] }}</td>
+        <td>{{ version['updated_at'] }}</td>
         <td>{{ version['location'] }}</td>
         <td>{{ version['status'] }}</td>
         <td>{{ link_to_item(version['parent']) }}</td>
diff --git a/app/utils.py b/app/utils.py
index 67128fa..7f9c8b8 100644
--- a/app/utils.py
+++ b/app/utils.py
@@ -13,6 +13,22 @@ import base64
 import datetime
 import io
 import sqlalchemy as sa
+from flask.globals import _app_ctx_stack, _request_ctx_stack
+from flask_login import current_user
+from flask_jwt_extended import get_current_user
+
+
+def fetch_current_user_id():
+    """Retrieve the user_id from flask_jwt_extended (API) or flask_login (web UI)"""
+    # Return None if we are outside of request context.
+    if _app_ctx_stack.top is None or _request_ctx_stack.top is None:
+        return None
+    # Try to get the user from both flask_jwt_extended and flask_login
+    user = get_current_user() or current_user
+    try:
+        return user.id
+    except AttributeError:
+        return None
 
 
 class CSEntryError(Exception):
-- 
GitLab