Skip to content
Snippets Groups Projects
utils.py 3.74 KiB
Newer Older
Benjamin Bertrand's avatar
Benjamin Bertrand committed
# -*- coding: utf-8 -*-
"""
app.api.utils
~~~~~~~~~~~~~

This module implements useful functions for the API.

:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.

"""
import urllib.parse
Benjamin Bertrand's avatar
Benjamin Bertrand committed
import sqlalchemy as sa
from flask import current_app, jsonify, request
from ..extensions import db
from .. import utils


def commit():
    """Commit the current database session.

    On IntegrityError or DataError the session is rolled back and the error
    is re-raised as a CSEntryError with status_code 422. The original
    SQLAlchemy exception is kept as the cause.

    :raises CSEntryError: if the commit raises IntegrityError or DataError
    """
    try:
        db.session.commit()
    except (sa.exc.IntegrityError, sa.exc.DataError) as e:
        # A failed commit leaves the session unusable until it is rolled back
        db.session.rollback()
        raise utils.CSEntryError(str(e), status_code=422) from e


def build_pagination_header(pagination, base_url, **kwargs):
    """Compute the pagination headers for a paginated API response.

    ``X-Total-Count`` is always present. A ``Link`` header is added when there
    is at least one other page to point to. It lists the applicable ``first``,
    ``prev``, ``next`` and ``last`` relations, in that order.

    :param pagination: flask_sqlalchemy Pagination class instance
    :param base_url: request base_url
    :param kwargs: extra query string parameters (without page and per_page)
    :returns: dict with X-Total-Count and Link keys
    """
    def page_link(target_page, rel):
        # per_page and page go first; the extra parameters are appended after
        query_string = urllib.parse.urlencode(
            {'per_page': pagination.per_page, 'page': target_page, **kwargs})
        return f'<{base_url}?{query_string}>; rel="{rel}"'

    targets = []
    if pagination.page > 1:
        targets.append((1, 'first'))
    if pagination.has_prev:
        targets.append((pagination.prev_num, 'prev'))
    if pagination.has_next:
        targets.append((pagination.next_num, 'next'))
    if pagination.pages > pagination.page:
        targets.append((pagination.pages, 'last'))

    header = {'X-Total-Count': pagination.total}
    if targets:
        header['Link'] = ', '.join(page_link(num, rel) for num, rel in targets)
    return header


def _pop_int_arg(args, key, default):
    """Remove ``key`` from ``args`` and return its value converted to int.

    :param args: dict of query string arguments (modified in place)
    :param key: name of the argument to pop
    :param default: value used when ``key`` is absent
    :returns: the integer value
    :raises CSEntryError: (422) if the value is not a valid integer
    """
    value = args.pop(key, default)
    try:
        return int(value)
    except (TypeError, ValueError):
        # Report a client error instead of letting int() fail with an
        # unhandled ValueError (e.g. for ?page=abc)
        raise utils.CSEntryError(
            f"Invalid value for '{key}': {value!r} is not an integer",
            status_code=422) from None


def get_generic_model(model, order_by=None, query=None):
    """Return data from model as json

    The ``page`` and ``per_page`` query string arguments (default 1 and 20)
    select the page to return. All other query string arguments are passed
    to ``utils.get_query`` and repeated in the pagination Link header.

    :param model: model class
    :param order_by: column to order the result by (defaults to model.name);
        ignored when query is given
    :param query: optional query to use (for more complex queries); used
        as-is, without applying the query string arguments or order_by
    :returns: tuple (data as json, 200, pagination headers)
    :raises CSEntryError: (422) if page or per_page is not an integer
    """
    kwargs = request.args.to_dict()
    page = _pop_int_arg(kwargs, 'page', 1)
    # NOTE(review): per_page is not capped here, so a client can request
    # arbitrarily large pages; consider paginate(max_per_page=...)
    per_page = _pop_int_arg(kwargs, 'per_page', 20)
    if query is None:
        # Remaining query string arguments go to utils.get_query
        # (presumably as column filters; its implementation is not visible here)
        query = utils.get_query(model.query, **kwargs)
        if order_by is None:
            order_by = model.name
        query = query.order_by(order_by)
    # Keyword arguments: Flask-SQLAlchemy 3.x makes these keyword-only
    pagination = query.paginate(page=page, per_page=per_page)
    data = [item.to_dict() for item in pagination.items]
    header = build_pagination_header(pagination, request.base_url, **kwargs)
    return jsonify(data), 200, header
def create_generic_model(model, mandatory_fields=('name',), **kwargs):
    """Create and save a new instance of model from the request JSON body

    :param model: model class
    :param mandatory_fields: fields that must be present in the data
    :param kwargs: extra attributes merged into the JSON body; they override
        keys with the same name
    :returns: tuple (created instance as json, 201)
    :raises CSEntryError: if the body is not a JSON object, if a mandatory
        field is missing (422) or if the model rejects the data (422)
    """
    data = request.get_json()
    # get_json() may return None or any JSON value (list, string, number...);
    # only an object can be merged with kwargs and passed as model(**data)
    if not isinstance(data, dict):
        raise utils.CSEntryError('Body should be a JSON object')
    current_app.logger.debug(f'Received: {data}')
    data.update(kwargs)
    for mandatory_field in mandatory_fields:
        if mandatory_field not in data:
            raise utils.CSEntryError(f"Missing mandatory field '{mandatory_field}'", status_code=422)
    try:
        instance = model(**data)
    except TypeError as e:
        # Drop the "[Class.]__init__() got an " prefix (Python >= 3.10 adds the
        # class qualname) so the client sees e.g. "unexpected keyword argument 'x'"
        message = str(e)
        _, marker, rest = message.partition('__init__() got an ')
        if marker:
            message = rest
        raise utils.CSEntryError(message, status_code=422) from e
    except ValueError as e:
        raise utils.CSEntryError(str(e), status_code=422) from e
    db.session.add(instance)
    commit()
    return jsonify(instance.to_dict()), 201