diff --git a/app/api/inventory.py b/app/api/inventory.py index 1f86909f6f678c996e4c7214ff6a92e431e5732a..1b89b5070bb3757a3b6e200fe1cd5b3d9dafa450 100644 --- a/app/api/inventory.py +++ b/app/api/inventory.py @@ -10,9 +10,9 @@ This module implements the inventory API. """ from flask import Blueprint, jsonify, request, current_app -from flask_jwt_extended import jwt_required +from flask_login import login_required from .. import utils, models -from ..decorators import jwt_groups_accepted +from ..decorators import login_groups_accepted from .utils import commit, create_generic_model, get_generic_model bp = Blueprint("inventory_api", __name__) @@ -33,7 +33,7 @@ def get_item_by_id_or_ics_id(id_): @bp.route("/items") -@jwt_required +@login_required def get_items(): """Return items @@ -43,7 +43,7 @@ def get_items(): @bp.route("/items/<id_>") -@jwt_required +@login_required def get_item(id_): """Retrieve item by id or ICS id @@ -56,8 +56,7 @@ def get_item(id_): @bp.route("/items", methods=["POST"]) -@jwt_required -@jwt_groups_accepted("admin", "create") +@login_groups_accepted("admin", "create") def create_item(): """Register a new item @@ -82,8 +81,7 @@ def create_item(): @bp.route("/items/<id_>", methods=["PATCH"]) -@jwt_required -@jwt_groups_accepted("admin", "create") +@login_groups_accepted("admin", "create") def patch_item(id_): """Patch an existing item @@ -143,7 +141,7 @@ def patch_item(id_): @bp.route("/items/<id_>/comments") -@jwt_required +@login_required def get_item_comments(id_): """Get item comments @@ -156,8 +154,7 @@ def get_item_comments(id_): @bp.route("/items/<id_>/comments", methods=["POST"]) -@jwt_required -@jwt_groups_accepted("admin", "create") +@login_groups_accepted("admin", "create") def create_item_comment(id_): """Create a comment on item @@ -173,7 +170,7 @@ def create_item_comment(id_): @bp.route("/actions") -@jwt_required +@login_required def get_actions(): """Get actions @@ -183,7 +180,7 @@ def get_actions(): @bp.route("/manufacturers") -@jwt_required +@login_required def get_manufacturers(): """Get manufacturers @@ -193,8 +190,7 @@ def get_manufacturers(): @bp.route("/manufacturers", methods=["POST"]) -@jwt_required -@jwt_groups_accepted("admin", "create") +@login_groups_accepted("admin", "create") def create_manufacturer(): """Create a new manufacturer @@ -207,7 +203,7 @@ def create_manufacturer(): @bp.route("/models") -@jwt_required +@login_required def get_models(): """Get models @@ -217,8 +213,7 @@ def get_models(): @bp.route("/models", methods=["POST"]) -@jwt_required -@jwt_groups_accepted("admin", "create") +@login_groups_accepted("admin", "create") def create_model(): """Create a new model @@ -231,7 +226,7 @@ def create_model(): @bp.route("/locations") -@jwt_required +@login_required def get_locations(): """Get locations @@ -241,8 +236,7 @@ def get_locations(): @bp.route("/locations", methods=["POST"]) -@jwt_required -@jwt_groups_accepted("admin", "create") +@login_groups_accepted("admin", "create") def create_locations(): """Create a new location @@ -255,7 +249,7 @@ def create_locations(): @bp.route("/statuses") -@jwt_required +@login_required def get_status(): """Get statuses @@ -265,8 +259,7 @@ def get_status(): @bp.route("/statuses", methods=["POST"]) -@jwt_required -@jwt_groups_accepted("admin", "create") +@login_groups_accepted("admin", "create") def create_status(): """Create a new status diff --git a/app/api/network.py b/app/api/network.py index b89b1882abd97cd3a2917a9523208a11fc8546a0..2f728af0d31dfd8f2849115a5fbb4260a7e83d35 100644 --- a/app/api/network.py +++ b/app/api/network.py @@ -10,16 +10,16 @@ This module implements the network API. """ from flask import Blueprint, request -from flask_jwt_extended import jwt_required +from flask_login import login_required from .. import models -from ..decorators import jwt_groups_accepted +from ..decorators import login_groups_accepted from .utils import get_generic_model, create_generic_model bp = Blueprint("network_api", __name__) @bp.route("/scopes") -@jwt_required +@login_required def get_scopes(): """Return network scopes @@ -29,8 +29,7 @@ def get_scopes(): @bp.route("/scopes", methods=["POST"]) -@jwt_required -@jwt_groups_accepted("admin") +@login_groups_accepted("admin") def create_scope(): """Create a new network scope @@ -50,7 +49,7 @@ def create_scope(): @bp.route("/networks") -@jwt_required +@login_required def get_networks(): """Return networks @@ -60,8 +59,7 @@ def get_networks(): @bp.route("/networks", methods=["POST"]) -@jwt_required -@jwt_groups_accepted("admin") +@login_groups_accepted("admin") def create_network(): """Create a new network @@ -92,7 +90,7 @@ def create_network(): @bp.route("/interfaces") -@jwt_required +@login_required def get_interfaces(): """Return interfaces @@ -120,8 +118,7 @@ def get_interfaces(): @bp.route("/interfaces", methods=["POST"]) -@jwt_required -@jwt_groups_accepted("admin", "create") +@login_groups_accepted("admin", "create") def create_interface(): """Create a new interface @@ -142,7 +139,7 @@ def create_interface(): @bp.route("/groups") -@jwt_required +@login_required def get_ansible_groups(): """Return ansible groups @@ -152,8 +149,7 @@ def get_ansible_groups(): @bp.route("/groups", methods=["POST"]) -@jwt_required -@jwt_groups_accepted("admin") +@login_groups_accepted("admin") def create_ansible_groups(): """Create a new Ansible group @@ -166,7 +162,7 @@ def create_ansible_groups(): @bp.route("/hosts") -@jwt_required +@login_required def get_hosts(): """Return hosts @@ -176,8 +172,7 @@ def get_hosts(): @bp.route("/hosts", methods=["POST"]) -@jwt_required -@jwt_groups_accepted("admin", "create") +@login_groups_accepted("admin", "create") def create_host(): """Create a new host @@ -194,7 +189,7 @@ def create_host(): @bp.route("/macs") -@jwt_required +@login_required def get_macs(): """Return mac addresses @@ -204,8 +199,7 @@ def get_macs(): @bp.route("/macs", methods=["POST"]) -@jwt_required -@jwt_groups_accepted("admin", "create") +@login_groups_accepted("admin", "create") def create_macs(): """Create a new mac address @@ -218,7 +212,7 @@ def create_macs(): @bp.route("/domains") -@jwt_required +@login_required def get_domains(): """Return domains @@ -228,8 +222,7 @@ def get_domains(): @bp.route("/domains", methods=["POST"]) -@jwt_required -@jwt_groups_accepted("admin") +@login_groups_accepted("admin") def create_domain(): """Create a new domain @@ -241,7 +234,7 @@ def create_domain(): @bp.route("/cnames") -@jwt_required +@login_required def get_cnames(): """Return cnames @@ -262,8 +255,7 @@ def get_cnames(): @bp.route("/cnames", methods=["POST"]) -@jwt_required -@jwt_groups_accepted("admin") +@login_groups_accepted("admin") def create_cname(): """Create a new cname diff --git a/app/api/user.py b/app/api/user.py index 3d8c56f058caf355fd14b598eb46b8cf7158e2e9..516d9fb78cae943391f061a0aad28fda51d9ecd1 100644 --- a/app/api/user.py +++ b/app/api/user.py @@ -11,9 +11,9 @@ This module implements the user API. """ from flask import current_app, Blueprint, jsonify, request from flask_ldap3_login import AuthenticationResponseStatus -from flask_jwt_extended import jwt_required, get_current_user +from flask_login import login_required, current_user from ..extensions import ldap_manager -from ..decorators import jwt_groups_accepted +from ..decorators import login_groups_accepted from .. import utils, tokens, models from .utils import get_generic_model, create_generic_model @@ -21,7 +21,7 @@ bp = Blueprint("user_api", __name__) @bp.route("/users") -@jwt_required +@login_required def get_users(): """Return users information @@ -31,19 +31,17 @@ def get_users(): @bp.route("/profile") -@jwt_required +@login_required def get_user_profile(): """Return the current user profile .. :quickref: User; Get current user profile """ - user = get_current_user() - return jsonify(user.to_dict()), 200 + return jsonify(current_user.to_dict()), 200 @bp.route("/users", methods=["POST"]) -@jwt_required -@jwt_groups_accepted("admin") +@login_groups_accepted("admin") def create_user(): """Create a new user diff --git a/app/decorators.py b/app/decorators.py index 634cffefcc04c12c31a69fd5f997b847c7cc5936..26669608f529f834d5e4becb0db25aa34e90528a 100644 --- a/app/decorators.py +++ b/app/decorators.py @@ -10,51 +10,15 @@ This module defines some useful decorators. """ from functools import wraps -from flask import current_app, abort +from flask import current_app, abort, g from flask_login import current_user -from flask_jwt_extended import get_current_user from .utils import CSEntryError -def jwt_groups_accepted(*groups): - """Decorator which specifies that a user must have at least one of the specified groups. - - This shall be used for users logged in using a JWT (API). - - Example:: - @bp.route('/models', methods=['POST']) - @jwt_required - @jwt_groups_accepted('admin', 'create') - def create_model(): - return create() - - The current user must be in either 'admin' or 'create' group - to access this route. - - :param groups: accepted groups - """ - - def wrapper(fn): - @wraps(fn) - def decorated_view(*args, **kwargs): - user = get_current_user() - if user is None: - raise CSEntryError("Invalid indentity", status_code=403) - if not user.is_member_of_one_group(groups): - raise CSEntryError( - "User doesn't have the required group", status_code=403 - ) - return fn(*args, **kwargs) - - return decorated_view - - return wrapper - - def login_groups_accepted(*groups): """Decorator which specifies that a user must have at least one of the specified groups. - This shall be used for users logged in using a cookie (web UI). + This can be used for users logged in using a cookie (web UI) or JWT (API). Example:: @bp.route('/models', methods=['POST']) @@ -62,8 +26,9 @@ def login_groups_accepted(*groups): def create_model(): return create() - The current user must be in either 'admin' or 'create' group - to access this route. + The current user must be in either 'admin' or 'create' group + to access the /models route. + This checks that the user is logged in. There is no need to use the @login_required decorator. @@ -76,7 +41,12 @@ def login_groups_accepted(*groups): if not current_user.is_authenticated: return current_app.login_manager.unauthorized() if not current_user.is_member_of_one_group(groups): - abort(403) + if g.get("login_via_request"): + raise CSEntryError( + "User doesn't have the required group", status_code=403 + ) + else: + abort(403) return fn(*args, **kwargs) return decorated_view diff --git a/app/models.py b/app/models.py index b757935c767fdc688addb470d5470062ba90e7bc..15cfce8054c5dbe58d88789bbab977220b6a756d 100644 --- a/app/models.py +++ b/app/models.py @@ -22,7 +22,7 @@ from sqlalchemy.orm import validates from sqlalchemy_continuum import make_versioned, version_class from citext import CIText from flask import current_app -from flask_login import UserMixin +from flask_login import UserMixin, current_user from wtforms import ValidationError from rq import Queue from .extensions import db, login_manager, ldap_manager, cache @@ -345,8 +345,7 @@ class QRCodeMixin: def is_user_favorite(self): """Return True if the attribute is part of the current user favorites""" - user = utils.cse_current_user() - return user in self.favorite_users + return current_user in self.favorite_users def __str__(self): return self.name @@ -659,8 +658,13 @@ class Network(CreatedMixin, db.Model): """Ensure the interface IP is in the network range""" addr, net = self.ip_in_network(interface.ip, self.address) # Admin user can create IP outside the defined range - user = utils.cse_current_user() - if user is None or not user.is_admin: + try: + # current_user is a local proxy and is not + # valid outside of a request context. + is_admin = current_user.is_admin + except AttributeError: + is_admin = False + if not is_admin: if addr < self.first or addr > self.last: raise ValidationError( f"IP address {interface.ip} is not in range {self.first} - {self.last}" diff --git a/app/tasks.py b/app/tasks.py index 3015ff2caf52c7f7e7e6cb6df88656c5d256ea0c..3f5fca6d151ccb6b852ac3785e46b2ada9d409d3 100644 --- a/app/tasks.py +++ b/app/tasks.py @@ -13,9 +13,10 @@ import time import traceback import tower_cli from flask import current_app +from flask_login import current_user from rq import Worker, get_current_job from .extensions import db -from . import utils, models +from . import models class TaskWorker(Worker): @@ -107,8 +108,7 @@ def trigger_vm_creation(name, interface, memory, cores): current_app.logger.info( f"Launch new job to create the {name} VM: {job_template} with {extra_vars}" ) - user = utils.cse_current_user() - task = user.launch_task( + task = current_user.launch_task( task_name, func="launch_job_template", job_template=job_template, @@ -130,8 +130,7 @@ def trigger_ztp_configuration(host): current_app.logger.info( f"Launch new job to generate ZTP configuration for {host.name} device: {job_template} with {extra_vars}" ) - user = utils.cse_current_user() - task = user.launch_task( + task = current_user.launch_task( "trigger_ztp_configuration", func="launch_job_template", job_template=job_template, diff --git a/app/tokens.py b/app/tokens.py index 282cd8f6e8f2459ae6b8bfb4597133ca1971a71c..7bbbfe422e445e98e4ec9d020f9ff027fe58603c 100644 --- a/app/tokens.py +++ b/app/tokens.py @@ -11,12 +11,96 @@ This module implements helper functions to manipulate JWT. """ import sqlalchemy as sa from datetime import datetime -from flask import current_app +from flask import current_app, g, flash, redirect, request +from flask_login import user_loaded_from_request +from flask_login import login_url from flask_jwt_extended import decode_token, create_access_token -from .extensions import db, jwt +from flask_jwt_extended.exceptions import ( + NoAuthorizationError, + InvalidHeaderError, + WrongTokenError, + RevokedTokenError, + JWTExtendedException, +) +from .extensions import db, jwt, login_manager from . import models, utils +@user_loaded_from_request.connect +def user_loaded_from_request(app, user=None): + g.login_via_request = True + + +@login_manager.unauthorized_handler +def unauthorized(): + """Called when the user is required to log in.""" + if ( + request.path.startswith("/api") + or request.accept_mimetypes.best == "application/json" + ): + # API request + # If this method is called, load_user_from_request returned None + # Either an exception was raised by decode_jwt_from_request + # or None was returned by models.User.query.get(int(identity)) + # Let decode_jwt_from_request raise an exception again + # or raise CSEntryError + jwt_data = decode_jwt_from_request(request, request_type="access") + identity = jwt_data[current_app.config["JWT_IDENTITY_CLAIM"]] + raise utils.CSEntryError(f"Invalid indentity '{identity}'", status_code=403) + else: + # browser request + flash("Please log in to access this page.", "info") + redirect_url = login_url("user.login", next_url=request.url) + return redirect(redirect_url) + + +def decode_jwt_from_headers(request): + header_name = current_app.config["JWT_HEADER_NAME"] + header_type = current_app.config["JWT_HEADER_TYPE"] + # Verify we have the auth header + jwt_header = request.headers.get(header_name, None) + if not jwt_header: + raise NoAuthorizationError(f"Missing {header_name} Header") + # Make sure the header is in a valid format that we are expecting, ie + # <HeaderName>: <HeaderType(optional)> <JWT> + parts = jwt_header.split() + if not header_type: + if len(parts) != 1: + msg = f"Bad {header_name} header. Expected value '<JWT>'" + raise InvalidHeaderError(msg) + encoded_token = parts[0] + else: + if parts[0] != header_type or len(parts) != 2: + msg = f"Bad {header_name} header. Expected value '{header_type} <JWT>'" + raise InvalidHeaderError(msg) + encoded_token = parts[1] + return decode_token(encoded_token) + + +def decode_jwt_from_request(request, request_type): + decoded_token = decode_jwt_from_headers(request) + # Make sure the type of token we received matches the request type we expect + if decoded_token["type"] != request_type: + raise WrongTokenError(f"Only {request_type} tokens can access this endpoint") + if is_token_in_blacklist(decoded_token): + raise RevokedTokenError("Token has been revoked") + return decoded_token + + +@login_manager.request_loader +def load_user_from_request(request): + """User loader callback using JWT from the headers + + Return a user object or None if the user doesn't exist. + """ + try: + jwt_data = decode_jwt_from_request(request, request_type="access") + except JWTExtendedException: + return None + identity = jwt_data[current_app.config["JWT_IDENTITY_CLAIM"]] + return models.User.query.get(int(identity)) + + @jwt.user_loader_callback_loader def user_loader_callback(identity): """User loader callback for flask-jwt-extended diff --git a/app/utils.py b/app/utils.py index 04a6d1791cf47b1e08d439572ac41aa7ce80377f..003a1cb1bf74abec59327a2d87e579c7f6b48a11 100644 --- a/app/utils.py +++ b/app/utils.py @@ -19,26 +19,15 @@ import yaml from flask import current_app from flask.globals import _app_ctx_stack, _request_ctx_stack from flask_login import current_user -from flask_jwt_extended import get_current_user - - -def cse_current_user(): - """Return the current_user from flask_jwt_extended (API) or flask_login (web UI)""" - # Return None if we are outside of request context. - if _app_ctx_stack.top is None or _request_ctx_stack.top is None: - return None - return get_current_user() or current_user def fetch_current_user_id(): - """Retrieve the user_id from flask_jwt_extended (API) or flask_login (web UI)""" + """Retrieve the current user id""" # Return None if we are outside of request context. if _app_ctx_stack.top is None or _request_ctx_stack.top is None: return None - # Try to get the user from both flask_jwt_extended and flask_login - user = cse_current_user() try: - return user.id + return current_user.id except AttributeError: return None @@ -246,17 +235,16 @@ def trigger_core_services_update(): Make sure that we don't have more than one in queue. """ job_template = current_app.config["AWX_CORE_SERVICES_UPDATE"] - user = cse_current_user() - if user.is_task_waiting("trigger_core_services_update"): + if current_user.is_task_waiting("trigger_core_services_update"): current_app.logger.info( 'Already one "trigger_core_services_update" task waiting. No need to trigger a new one.' ) return None kwargs = {"func": "launch_job_template", "job_template": job_template} - started = user.get_task_started("trigger_core_services_update") + started = current_user.get_task_started("trigger_core_services_update") if started: # There is already one running task. Trigger a new one when it's done. kwargs["depends_on"] = started.id current_app.logger.info(f"Launch new job to update core services: {job_template}") - task = user.launch_task("trigger_core_services_update", **kwargs) + task = current_user.launch_task("trigger_core_services_update", **kwargs) return task