Skip to content
Snippets Groups Projects
  • Benjamin Bertrand's avatar
    29c556bf
    Add Unique validator · 29c556bf
    Benjamin Bertrand authored
    To check if the unique object is the one being edited, we need to store
    the object used to create the form on the form itself (as _obj).
    This is inspired by flask-admin.
    29c556bf
    History
    Add Unique validator
    Benjamin Bertrand authored
    To check if the unique object is the one being edited, we need to store
    the object used to create the form on the form itself (as _obj).
    This is inspired by flask-admin.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
tokens.py 2.94 KiB
# -*- coding: utf-8 -*-
"""
app.api.tokens
~~~~~~~~~~~~~~

This module implements helper functions to manipulate JWT.

:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.

"""
import sqlalchemy as sa
from datetime import datetime
from flask import current_app
from flask_jwt_extended import decode_token, create_access_token
from .extensions import db, jwt
from . import models, utils


@jwt.user_loader_callback_loader
def user_loader_callback(identity):
    """Resolve the identity stored in a JWT to a User instance.

    Registered with flask-jwt-extended as its user loader callback.

    :param identity: identity claim taken from the token (the user id;
        it is converted with ``int()`` before the lookup)
    :returns: the matching User, or None when no user has that id
    """
    user_id = int(identity)
    return models.User.query.get(user_id)


@jwt.token_in_blacklist_loader
def is_token_in_blacklist(decoded_token):
    """Tell flask-jwt-extended whether a token has been revoked.

    Every issued token is recorded in the database, so a token whose
    ``jti`` has no matching row is treated as blacklisted / revoked.

    :param dict decoded_token: decoded JWT payload (must contain 'jti')
    :returns: True if no stored token matches the jti, False otherwise
    """
    # one_or_none() returns None for zero rows and, like one(), raises
    # MultipleResultsFound if the jti somehow matches several rows.
    stored = models.Token.query.filter_by(jti=decoded_token['jti']).one_or_none()
    return stored is None


def generate_access_token(identity, fresh=False, expires_delta=None, description=None):
    """Create an access token for *identity* and record it in the database.

    :param identity: identity to embed in the token
    :param bool fresh: whether the token is marked as fresh
    :param expires_delta: custom lifetime forwarded to create_access_token
    :param str description: optional description stored with the token
    :returns: the encoded access token
    """
    encoded = create_access_token(
        identity,
        fresh=fresh,
        expires_delta=expires_delta,
    )
    save_token(encoded, description=description)
    return encoded


def save_token(encoded_token, description=None):
    """Decode *encoded_token* and persist its metadata as a Token row.

    The claim named by the JWT_IDENTITY_CLAIM config value is stored as the
    integer user id. Timestamps are converted with ``datetime.fromtimestamp``,
    i.e. as naive datetimes in the server's local time. A token without an
    'exp' claim is stored with ``expires=None``. The session is committed.

    :param str encoded_token: encoded JWT to record
    :param str description: optional free-text description
    """
    claim_name = current_app.config['JWT_IDENTITY_CLAIM']
    payload = decode_token(encoded_token)
    # Dict literal entries are evaluated top to bottom, so any missing
    # claim raises KeyError in the same order as before.
    fields = {
        'jti': payload['jti'],
        'token_type': payload['type'],
        'user_id': int(payload[claim_name]),
        'issued_at': datetime.fromtimestamp(payload['iat']),
        # Tokens created without an expiry carry no 'exp' claim.
        'expires': (
            datetime.fromtimestamp(payload['exp']) if 'exp' in payload else None
        ),
        'description': description,
    }
    db.session.add(models.Token(**fields))
    db.session.commit()


def revoke_token(token_id, user_id):
    """Delete the token *token_id* if it belongs to *user_id*.

    :param token_id: primary key of the Token row to revoke
    :param user_id: id of the user requesting the revocation
    :raises utils.CSEntryError: with status 404 when the token does not
        exist, or 401 when it belongs to a different user
    """
    db_token = models.Token.query.get(token_id)
    if db_token is None:
        message, status = f'Could not find the token {token_id}', 404
    elif db_token.user_id != user_id:
        message, status = f"Token {token_id} doesn't belong to user {user_id}", 401
    else:
        db.session.delete(db_token)
        db.session.commit()
        return
    raise utils.CSEntryError(message, status_code=status)


def prune_database():
    """Remove every stored token whose expiry time is already in the past.

    Expiry is compared against ``datetime.now()`` (naive local time).
    Tokens with no expiry (NULL ``expires``) do not match ``expires < now``
    in SQL and are therefore kept. Rows are deleted one by one through the
    session, then a single commit is issued.
    """
    cutoff = datetime.now()
    expired_query = models.Token.query.filter(models.Token.expires < cutoff)
    for stale_token in expired_query.all():
        db.session.delete(stale_token)
    db.session.commit()