Skip to content
Snippets Groups Projects
Commit 82542c83 authored by Benjamin Bertrand's avatar Benjamin Bertrand
Browse files

Explicitely update relationships

This is to ensure that when editing an object it is properly
re-indexed.

JIRA INFRA-575
parent 224db0f0
No related branches found
No related tags found
No related merge requests found
...@@ -45,3 +45,9 @@ def associate_mac_to_interface(address, interface): ...@@ -45,3 +45,9 @@ def associate_mac_to_interface(address, interface):
mac = models.Mac(address=address) mac = models.Mac(address=address)
db.session.add(mac) db.session.add(mac)
mac.interfaces.append(interface) mac.interfaces.append(interface)
def get_model(class_, id_):
if id_ is None:
return None
return class_.query.get(id_)
...@@ -25,7 +25,7 @@ from flask_login import login_required, current_user ...@@ -25,7 +25,7 @@ from flask_login import login_required, current_user
from .forms import AttributeForm, ItemForm, CommentForm from .forms import AttributeForm, ItemForm, CommentForm
from ..extensions import db from ..extensions import db
from ..decorators import login_groups_accepted from ..decorators import login_groups_accepted
from .. import utils, models from .. import utils, models, helpers
bp = Blueprint("inventory", __name__) bp = Blueprint("inventory", __name__)
...@@ -67,12 +67,14 @@ def create_item(): ...@@ -67,12 +67,14 @@ def create_item():
ics_id=form.ics_id.data, ics_id=form.ics_id.data,
serial_number=form.serial_number.data, serial_number=form.serial_number.data,
quantity=form.quantity.data, quantity=form.quantity.data,
manufacturer_id=form.manufacturer_id.data, manufacturer=helpers.get_model(
model_id=form.model_id.data, models.Manufacturer, form.manufacturer_id.data
location_id=form.location_id.data, ),
status_id=form.status_id.data, model=helpers.get_model(models.Model, form.model_id.data),
parent_id=form.parent_id.data, location=helpers.get_model(models.Location, form.location_id.data),
host_id=form.host_id.data, status=helpers.get_model(models.Status, form.status_id.data),
parent=helpers.get_model(models.Item, form.parent_id.data),
host=helpers.get_model(models.Host, form.host_id.data),
stack_member=form.stack_member.data, stack_member=form.stack_member.data,
) )
item.macs = [ item.macs = [
...@@ -105,7 +107,8 @@ def comment_item(ics_id): ...@@ -105,7 +107,8 @@ def comment_item(ics_id):
item = models.Item.query.filter_by(ics_id=ics_id).first_or_404() item = models.Item.query.filter_by(ics_id=ics_id).first_or_404()
form = CommentForm() form = CommentForm()
if form.validate_on_submit(): if form.validate_on_submit():
comment = models.ItemComment(body=form.body.data, item_id=item.id) comment = models.ItemComment(body=form.body.data)
item.comments.append(comment)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
return redirect(url_for("inventory.view_item", ics_id=ics_id)) return redirect(url_for("inventory.view_item", ics_id=ics_id))
...@@ -132,15 +135,14 @@ def edit_item(ics_id): ...@@ -132,15 +135,14 @@ def edit_item(ics_id):
else: else:
# Field is disabled, force it to None # Field is disabled, force it to None
item.stack_member = None item.stack_member = None
for key in ( item.manufacturer = helpers.get_model(
"manufacturer_id", models.Manufacturer, form.manufacturer_id.data
"model_id", )
"location_id", item.model = helpers.get_model(models.Model, form.model_id.data)
"status_id", item.location = helpers.get_model(models.Location, form.location_id.data)
"parent_id", item.status = helpers.get_model(models.Status, form.status_id.data)
"host_id", item.parent = helpers.get_model(models.Item, form.parent_id.data)
): item.host = helpers.get_model(models.Host, form.host_id.data)
setattr(item, key, getattr(form, key).data)
new_addresses = form.mac_addresses.data.split() new_addresses = form.mac_addresses.data.split()
# Delete the MAC addresses that have been removed # Delete the MAC addresses that have been removed
for (index, mac) in enumerate(item.macs): for (index, mac) in enumerate(item.macs):
......
...@@ -70,7 +70,7 @@ def create_host(): ...@@ -70,7 +70,7 @@ def create_host():
] ]
host = models.Host( host = models.Host(
name=form.name.data, name=form.name.data,
device_type_id=form.device_type_id.data, device_type=models.DeviceType.query.get(form.device_type_id.data),
description=form.description.data or None, description=form.description.data or None,
ansible_vars=form.ansible_vars.data or None, ansible_vars=form.ansible_vars.data or None,
ansible_groups=ansible_groups, ansible_groups=ansible_groups,
...@@ -83,7 +83,7 @@ def create_host(): ...@@ -83,7 +83,7 @@ def create_host():
interface = models.Interface( interface = models.Interface(
name=form.interface_name.data, name=form.interface_name.data,
ip=form.ip.data, ip=form.ip.data,
network_id=network_id, network=models.Network.query.get(network_id),
tags=tags, tags=tags,
) )
interface.cnames = [ interface.cnames = [
...@@ -188,7 +188,7 @@ def edit_host(name): ...@@ -188,7 +188,7 @@ def edit_host(name):
form.ansible_groups.process(request.form) form.ansible_groups.process(request.form)
if form.validate_on_submit(): if form.validate_on_submit():
host.name = form.name.data host.name = form.name.data
host.device_type_id = form.device_type_id.data host.device_type = models.DeviceType.query.get(form.device_type_id.data)
host.description = form.description.data or None host.description = form.description.data or None
host.ansible_vars = form.ansible_vars.data or None host.ansible_vars = form.ansible_vars.data or None
host.ansible_groups = [ host.ansible_groups = [
...@@ -222,16 +222,18 @@ def create_interface(hostname): ...@@ -222,16 +222,18 @@ def create_interface(hostname):
all_tags = models.Tag.query.all() all_tags = models.Tag.query.all()
tags = [tag for tag in all_tags if str(tag.id) in form.tags.data] tags = [tag for tag in all_tags if str(tag.id) in form.tags.data]
interface = models.Interface( interface = models.Interface(
host_id=host.id,
name=form.interface_name.data, name=form.interface_name.data,
ip=form.ip.data, ip=form.ip.data,
network_id=form.network_id.data, network=models.Network.query.get(form.network_id.data),
tags=tags, tags=tags,
) )
interface.cnames = [ interface.cnames = [
models.Cname(name=name) for name in form.cnames_string.data.split() models.Cname(name=name) for name in form.cnames_string.data.split()
] ]
helpers.associate_mac_to_interface(form.mac_address.data, interface) helpers.associate_mac_to_interface(form.mac_address.data, interface)
# Make sure to update host.interfaces instead of using interface.host_id
# to force the host to be added to the session for indexing
host.interfaces.append(interface)
current_app.logger.debug(f"Trying to create: {interface!r}") current_app.logger.debug(f"Trying to create: {interface!r}")
db.session.add(interface) db.session.add(interface)
try: try:
...@@ -241,7 +243,7 @@ def create_interface(hostname): ...@@ -241,7 +243,7 @@ def create_interface(hostname):
current_app.logger.warning(f"{e}") current_app.logger.warning(f"{e}")
flash(f"{e}", "error") flash(f"{e}", "error")
else: else:
flash(f"Host {interface} created!", "success") flash(f"Interface {interface} created!", "success")
return redirect(url_for("network.create_interface", hostname=hostname)) return redirect(url_for("network.create_interface", hostname=hostname))
return render_template( return render_template(
"network/create_interface.html", form=form, hostname=hostname "network/create_interface.html", form=form, hostname=hostname
...@@ -287,7 +289,7 @@ def edit_interface(name): ...@@ -287,7 +289,7 @@ def edit_interface(name):
# else: nothing to do (address didn't change) # else: nothing to do (address didn't change)
else: else:
# No MAC associated # No MAC associated
interface.mac_id = None interface.mac = None
# Delete the cnames that have been removed # Delete the cnames that have been removed
new_cnames_string = form.cnames_string.data.split() new_cnames_string = form.cnames_string.data.split()
for (index, cname) in enumerate(interface.cnames): for (index, cname) in enumerate(interface.cnames):
...@@ -306,6 +308,9 @@ def edit_interface(name): ...@@ -306,6 +308,9 @@ def edit_interface(name):
all_tags = models.Tag.query.all() all_tags = models.Tag.query.all()
tags = [tag for tag in all_tags if str(tag.id) in form.tags.data] tags = [tag for tag in all_tags if str(tag.id) in form.tags.data]
interface.tags = tags interface.tags = tags
# Mark the host as "dirty" to add it to the session so that it will
# be re-indexed
sa.orm.attributes.flag_modified(interface.host, "interfaces")
current_app.logger.debug(f"Trying to update: {interface!r}") current_app.logger.debug(f"Trying to update: {interface!r}")
try: try:
db.session.commit() db.session.commit()
...@@ -326,6 +331,9 @@ def edit_interface(name): ...@@ -326,6 +331,9 @@ def edit_interface(name):
def delete_interface(): def delete_interface():
interface = models.Interface.query.get_or_404(request.form["interface_id"]) interface = models.Interface.query.get_or_404(request.form["interface_id"])
hostname = interface.host.name hostname = interface.host.name
# Explicitely remove the interface from the host to make sure
# it will be re-indexed
interface.host.interfaces.remove(interface)
# Deleting the interface will also delete all # Deleting the interface will also delete all
# associated cnames due to the cascade delete option # associated cnames due to the cascade delete option
# defined on the model # defined on the model
...@@ -454,7 +462,7 @@ def create_scope(): ...@@ -454,7 +462,7 @@ def create_scope():
first_vlan=form.first_vlan.data, first_vlan=form.first_vlan.data,
last_vlan=form.last_vlan.data, last_vlan=form.last_vlan.data,
supernet=form.supernet.data, supernet=form.supernet.data,
domain_id=form.domain_id.data, domain=models.Domain.query.get(form.domain_id.data),
) )
current_app.logger.debug(f"Trying to create: {scope!r}") current_app.logger.debug(f"Trying to create: {scope!r}")
db.session.add(scope) db.session.add(scope)
...@@ -548,14 +556,14 @@ def create_network(): ...@@ -548,14 +556,14 @@ def create_network():
if form.validate_on_submit(): if form.validate_on_submit():
scope_id = form.scope_id.data scope_id = form.scope_id.data
network = models.Network( network = models.Network(
scope_id=scope_id, scope=models.NetworkScope.query.get(scope_id),
vlan_name=form.vlan_name.data, vlan_name=form.vlan_name.data,
vlan_id=form.vlan_id.data, vlan_id=form.vlan_id.data,
description=form.description.data or None, description=form.description.data or None,
address=form.address.data, address=form.address.data,
first_ip=form.first_ip.data, first_ip=form.first_ip.data,
last_ip=form.last_ip.data, last_ip=form.last_ip.data,
domain_id=form.domain_id.data, domain=models.Domain.query.get(form.domain_id.data),
admin_only=form.admin_only.data, admin_only=form.admin_only.data,
) )
current_app.logger.debug(f"Trying to create: {network!r}") current_app.logger.debug(f"Trying to create: {network!r}")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment