diff --git a/app/fields.py b/app/fields.py index bc35c41688b7033285d31b8feef20c25d11da80c..07fb4d1e6afb04bba5899b867caf1c8c5237ab02 100644 --- a/app/fields.py +++ b/app/fields.py @@ -32,11 +32,13 @@ class YAMLField(TextAreaField): return yaml.safe_dump(self.data, default_flow_style=False) if self.data else "" def process_formdata(self, valuelist): - if valuelist: + if valuelist and valuelist[0].strip() != "": try: self.data = yaml.safe_load(valuelist[0]) except yaml.YAMLError: raise ValueError("This field contains invalid YAML") + if not isinstance(self.data, dict): + raise ValueError("This field shall only contain key-value-pairs") else: self.data = None diff --git a/tests/functional/test_web.py b/tests/functional/test_web.py index 4204ac909c5a35fb206e99b9e57c4d540ed93a34..53c1c55c4428ef0a3898affb90156106710a5dbe 100644 --- a/tests/functional/test_web.py +++ b/tests/functional/test_web.py @@ -318,7 +318,7 @@ def test_create_host(client, network_scope_factory, network_factory, device_type "ip": ip, "mac": mac, "description": "test", - "ansible_vars": "", + "ansible_vars": "foo: hello", "ansible_groups": [], "random_mac": False, "cnames_string": "", diff --git a/tests/unit/test_fields.py b/tests/unit/test_fields.py index 99e2a9be433eea1cb33a3ced998fad2e0cb04e6d..70b51790fce9b88c7bea7accb38564c8eade5123 100644 --- a/tests/unit/test_fields.py +++ b/tests/unit/test_fields.py @@ -9,7 +9,13 @@ This module defines fields tests. :license: BSD 2-Clause, see LICENSE for more details. """ -from app.fields import yaml +import pytest +from wtforms.form import Form +from app.fields import yaml, YAMLField + + +class MyForm(Form): + vars = YAMLField("Ansible vars") def test_vault_yaml_tag_load(): @@ -33,3 +39,33 @@ def test_vault_yaml_tag_load(): """ } } + + +@pytest.mark.parametrize( + "text_input,expected", + [ + ("foo: hello", {"foo": "hello"}), + ("foo:\n - a\n - b", {"foo": ["a", "b"]}), + ("", None), + (" ", None), + ], +) +def test_yamlfield_process_formdata(text_input, expected): + form = MyForm() + YAMLField.process_formdata(form.vars, [text_input]) + assert form.vars.data == expected + + +def test_yamlfield_process_formdata_invalid_yaml(): + form = MyForm() + with pytest.raises(ValueError, match="This field contains invalid YAML"): + YAMLField.process_formdata(form.vars, ["foo: hello: world"]) + + +@pytest.mark.parametrize("text_input", ("foo", "- a\n- b")) +def test_yamlfield_process_formdata_non_dict(text_input): + form = MyForm() + with pytest.raises( + ValueError, match="This field shall only contain key-value-pairs" + ): + YAMLField.process_formdata(form.vars, [text_input])