Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • andersharrisson/csentry
  • ics-infrastructure/csentry
2 results
Show changes
Showing
with 4289 additions and 367 deletions
"""Rename MicroTCA to MTCA-AMC
Revision ID: 91b0093a5e13
Revises: b1eda5cb7d9d
Create Date: 2020-03-05 12:55:35.804867
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "91b0093a5e13"
down_revision = "b1eda5cb7d9d"
branch_labels = None
depends_on = None
def upgrade():
device_type = sa.sql.table(
"device_type", sa.sql.column("id"), sa.sql.column("name")
)
op.execute(
device_type.update()
.where(device_type.c.name == "MicroTCA")
.values(name="MTCA-AMC")
)
def downgrade():
device_type = sa.sql.table(
"device_type", sa.sql.column("id"), sa.sql.column("name")
)
op.execute(
device_type.update()
.where(device_type.c.name == "MTCA-AMC")
.values(name="MicroTCA")
)
"""Add is_ioc field
Revision ID: ac04850e5f68
Revises: f7d72e432f51
Create Date: 2019-02-28 11:28:36.993953
"""
from alembic import op
import sqlalchemy as sa
import citext
# revision identifiers, used by Alembic.
revision = "ac04850e5f68"
down_revision = "f7d72e432f51"
branch_labels = None
depends_on = None
def upgrade():
op.execute("COMMIT")
op.execute("ALTER TYPE ansible_group_type ADD VALUE 'IOC'")
op.add_column(
"host",
sa.Column("is_ioc", sa.Boolean(), nullable=False, server_default="False"),
)
op.add_column(
"host_version",
sa.Column("is_ioc", sa.Boolean(), autoincrement=False, nullable=True),
)
host = sa.sql.table("host", sa.sql.column("id"), sa.sql.column("is_ioc"))
conn = op.get_bind()
res = conn.execute("SELECT id FROM tag WHERE name = 'IOC';")
row = res.fetchone()
if row is not None:
ioc_tag_id = row[0]
res.close()
res = conn.execute(
f"""SELECT interface.host_id FROM interface
INNER JOIN interfacetags ON interface.id = interfacetags.interface_id
WHERE interfacetags.tag_id = {ioc_tag_id};
"""
)
results = res.fetchall()
for result in results:
op.execute(host.update().where(host.c.id == result[0]).values(is_ioc=True))
op.drop_table("interfacetags")
op.drop_table("tag")
def downgrade():
# WARNING! The downgrade doesn't recreate the IOC tag
op.drop_column("host_version", "is_ioc")
op.drop_column("host", "is_ioc")
op.create_table(
"tag",
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column("name", citext.CIText(), autoincrement=False, nullable=False),
sa.Column("description", sa.TEXT(), autoincrement=False, nullable=True),
sa.Column("admin_only", sa.BOOLEAN(), autoincrement=False, nullable=False),
sa.PrimaryKeyConstraint("id", name="pk_tag"),
sa.UniqueConstraint("name", name="uq_tag_name"),
)
op.create_table(
"interfacetags",
sa.Column("tag_id", sa.INTEGER(), autoincrement=False, nullable=False),
sa.Column("interface_id", sa.INTEGER(), autoincrement=False, nullable=False),
sa.ForeignKeyConstraint(
["interface_id"],
["interface.id"],
name="fk_interfacetags_interface_id_interface",
),
sa.ForeignKeyConstraint(
["tag_id"], ["tag.id"], name="fk_interfacetags_tag_id_tag"
),
sa.PrimaryKeyConstraint("tag_id", "interface_id", name="pk_interfacetags"),
)
"""Add sensitive column to Network
Revision ID: acd72492f46f
Revises: 33720bfb353a
Create Date: 2020-01-30 14:12:54.923114
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "acd72492f46f"
down_revision = "33720bfb353a"
branch_labels = None
depends_on = None
def upgrade():
op.add_column(
"network",
sa.Column("sensitive", sa.Boolean(), nullable=False, server_default="False"),
)
def downgrade():
op.drop_column("network", "sensitive")
"""Add HOSTNAME Ansible group type
Revision ID: b1eda5cb7d9d
Revises: acd72492f46f
Create Date: 2020-03-04 20:37:15.636489
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "b1eda5cb7d9d"
down_revision = "acd72492f46f"
branch_labels = None
depends_on = None
def upgrade():
op.execute("COMMIT")
op.execute("ALTER TYPE ansible_group_type ADD VALUE 'HOSTNAME'")
def downgrade():
# Removing an individual value from an enum type isn't supported
# https://www.postgresql.org/docs/current/datatype-enum.html
pass
"""Add depends_on field on Task
Revision ID: cb777d44627f
Revises: 166572b78449
Create Date: 2019-03-27 20:51:05.385857
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "cb777d44627f"
down_revision = "166572b78449"
branch_labels = None
depends_on = None
def upgrade():
op.add_column("task", sa.Column("depends_on_id", postgresql.UUID(), nullable=True))
op.create_foreign_key(
op.f("fk_task_depends_on_id_task"), "task", "task", ["depends_on_id"], ["id"]
)
def downgrade():
op.drop_constraint(op.f("fk_task_depends_on_id_task"), "task", type_="foreignkey")
op.drop_column("task", "depends_on_id")
"""Add gateway field
Revision ID: f7d72e432f51
Revises: 7c38e78b6de6
Create Date: 2019-02-27 17:35:22.535126
"""
import ipaddress
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "f7d72e432f51"
down_revision = "7c38e78b6de6"
branch_labels = None
depends_on = None
def upgrade():
op.add_column("network", sa.Column("gateway", postgresql.INET(), nullable=True))
network = sa.sql.table(
"network",
sa.sql.column("id"),
sa.sql.column("address"),
sa.sql.column("gateway"),
)
# Fill the gateway based on the network address
conn = op.get_bind()
res = conn.execute("SELECT id, address FROM network")
results = res.fetchall()
for result in results:
address = ipaddress.ip_network(result[1])
hosts = list(address.hosts())
# Use last IP by default
gateway = str(hosts[-1])
op.execute(
network.update().where(network.c.id == result[0]).values(gateway=gateway)
)
op.create_check_constraint(
op.f("ck_network_gateway_in_network"), "network", "gateway << address"
)
op.create_unique_constraint(op.f("uq_network_gateway"), "network", ["gateway"])
op.alter_column("network", "gateway", nullable=False)
def downgrade():
op.drop_constraint(op.f("uq_network_gateway"), "network", type_="unique")
op.drop_constraint(op.f("ck_network_gateway_in_network"), "network", type_="check")
op.drop_column("network", "gateway")
[tool.black]
py36 = true
target_version = ['py37']
[pytest]
testpaths = tests
sphinx
sphinx_rtd_theme
sphinxcontrib-httpdomain
coverage
factory_boy
......
elasticsearch
elasticsearch>=7.0.0,<8.0.0
flask>=1.0.0
flask-admin
flask-caching
......@@ -18,12 +18,14 @@ pyjwt
python-dateutil
pyyaml
qrcode
uwsgi
whitenoise
ansible-tower-cli
raven
ansible-tower-cli<3.3.9
rq
rq-dashboard
sentry-sdk
sqlalchemy<1.3
sqlalchemy-citext
sqlalchemy-continuum
openpyxl
uwsgi
WTForms>=2.1,<2.2
alembic==1.0.0
ansible-tower-cli==3.3.0
arrow==0.12.1
alembic==1.4.3
ansible-tower-cli==3.3.8
arrow==0.17.0
blinker==1.4
certifi==2018.4.16
cachelib==0.1.1
certifi==2020.6.20
chardet==3.0.4
click==6.7
colorama==0.3.9
elasticsearch==6.3.1
click==7.1.2
colorama==0.4.4
elasticsearch==7.9.1
et-xmlfile==1.0.1
Flask==1.0.2
Flask-Admin==1.5.1
Flask-Caching==1.4.0
Flask-DebugToolbar==0.10.1
Flask-JWT-Extended==3.12.0
flask-ldap3-login==0.9.14
Flask-Login==0.4.1
Flask==1.1.2
Flask-Admin==1.5.6
Flask-Caching==1.9.0
Flask-DebugToolbar==0.11.0
Flask-JWT-Extended==3.24.1
flask-ldap3-login==0.9.16
Flask-Login==0.5.0
Flask-Mail==0.9.1
Flask-Migrate==2.2.1
Flask-Redis==0.3.0
Flask-Session==0.3.1
Flask-SQLAlchemy==2.3.2
Flask-WTF==0.14.2
idna==2.7
itsdangerous==0.24
jdcal==1.4
Jinja2==2.10
ldap3==2.5.1
Mako==1.0.7
MarkupSafe==1.0
openpyxl==2.5.7
Pillow==5.2.0
psycopg2==2.7.5
pyasn1==0.4.4
PyJWT==1.6.4
python-dateutil==2.7.3
python-editor==1.0.3
PyYAML==3.13
qrcode==6.0
raven==6.9.0
redis==2.10.6
requests==2.19.1
rq==0.12.0
rq-dashboard==0.3.12
six==1.11.0
SQLAlchemy==1.2.10
sqlalchemy-citext==1.3.post0
SQLAlchemy-Continuum==1.3.6
SQLAlchemy-Utils==0.33.3
urllib3==1.23
uWSGI==2.0.17.1
Werkzeug==0.14.1
whitenoise==3.3.1
Flask-Migrate==2.5.3
flask-redis==0.4.0
Flask-Session==0.3.2
Flask-SQLAlchemy==2.4.4
Flask-WTF==0.14.3
idna==2.10
itsdangerous==1.1.0
jdcal==1.4.1
Jinja2==2.11.2
ldap3==2.8.1
Mako==1.1.3
MarkupSafe==1.1.1
openpyxl==3.0.5
Pillow==8.0.0
psycopg2==2.8.6
pyasn1==0.4.8
PyJWT==1.7.1
python-dateutil==2.8.1
python-editor==1.0.4
PyYAML==5.3.1
qrcode==6.1
redis==3.5.3
requests==2.24.0
rq==1.5.2
rq-dashboard==0.6.1
sentry-sdk==0.19.1
six==1.15.0
SQLAlchemy==1.2.19
sqlalchemy-citext==1.7.0
SQLAlchemy-Continuum==1.3.11
SQLAlchemy-Utils==0.36.8
urllib3==1.25.11
uWSGI==2.0.19.1
Werkzeug==1.0.1
whitenoise==5.2.0
WTForms==2.1
......@@ -9,12 +9,15 @@ Pytest fixtures common to all functional tests.
:license: BSD 2-Clause, see LICENSE for more details.
"""
import redis
import pytest
import sqlalchemy as sa
from pytest_factoryboy import register
from flask_ldap3_login import AuthenticationResponse, AuthenticationResponseStatus
from rq import push_connection, pop_connection
from app.factory import create_app
from app.extensions import db as _db
from app.models import SearchableMixin, Host, Item, AnsibleGroup
from . import common, factories
register(factories.UserFactory)
......@@ -24,6 +27,7 @@ register(factories.ModelFactory)
register(factories.LocationFactory)
register(factories.StatusFactory)
register(factories.ItemFactory)
register(factories.ItemCommentFactory)
register(factories.NetworkScopeFactory)
register(factories.NetworkFactory)
register(factories.InterfaceFactory)
......@@ -33,7 +37,6 @@ register(factories.HostFactory)
register(factories.MacFactory)
register(factories.DomainFactory)
register(factories.CnameFactory)
register(factories.TagFactory)
register(factories.TaskFactory)
......@@ -44,10 +47,25 @@ def app(request):
"TESTING": True,
"WTF_CSRF_ENABLED": False,
"SQLALCHEMY_DATABASE_URI": "postgresql://ics:icspwd@postgres/csentry_db_test",
"RQ_REDIS_URL": "redis://redis:6379/4",
"ELASTICSEARCH_INDEX_SUFFIX": "-test",
"ELASTICSEARCH_REFRESH": "true",
"CACHE_TYPE": "null",
"CACHE_NO_NULL_WARNING": True,
"CSENTRY_LDAP_GROUPS": {
"admin": ["CSEntry Admin"],
"create": ["CSEntry User", "CSEntry Consultant"],
"auditor": ["CSEntry Auditor"],
"inventory": ["CSEntry User"],
},
"CSENTRY_NETWORK_SCOPES_LDAP_GROUPS": {
"ProdNetworks": ["CSEntry Prod"],
"LabNetworks": ["CSEntry Lab"],
"FooNetworks": ["CSEntry User", "CSEntry Consultant"],
},
"AWX_URL": "https://awx.example.org",
"ALLOWED_VM_CREATION_NETWORK_SCOPES": ["LabNetworks"],
"ALLOWED_SET_BOOT_PROFILE_NETWORK_SCOPES": ["LabNetworks"],
"MAX_PER_PAGE": 25,
}
app = create_app(config=config)
ctx = app.app_context()
......@@ -74,7 +92,9 @@ def db(app, request):
_db.drop_all()
_db.app = app
_db.app.elasticsearch.indices.delete("*-test", ignore=404)
_db.engine.execute("CREATE EXTENSION IF NOT EXISTS citext")
_db.drop_all()
_db.create_all()
request.addfinalizer(teardown)
......@@ -104,10 +124,34 @@ def session(db, request):
session.expire_all()
session.begin_nested()
# We have to register the before_flush/after_flush_postexec/after_commit events
# because we use a specific session to run the tests (not the same used in models.py)
db.event.listen(session(), "before_flush", SearchableMixin.before_flush)
db.event.listen(
session(), "after_flush_postexec", SearchableMixin.after_flush_postexec
)
db.event.listen(session(), "after_commit", SearchableMixin.after_commit)
db.session = session
# Create the elasticsearch indices
Item.create_index()
Host.create_index()
AnsibleGroup.create_index()
# Setup RQ redis connection
redis_connection = redis.from_url(db.app.config["RQ_REDIS_URL"])
redis_connection.flushdb()
push_connection(redis_connection)
yield session
# Clean RQ redis connnection
redis_connection.flushdb()
pop_connection()
# ELASTICSEARCH_INDEX_SUFFIX is set to "-test"
# Delete all "*-test" indices after each test
db.app.elasticsearch.indices.delete("*-test", ignore=404)
session.remove()
transaction.rollback()
connection.close()
......@@ -129,6 +173,21 @@ def patch_ldap_authenticate(monkeypatch):
response.status = AuthenticationResponseStatus.success
response.user_info = {"cn": "Admin User", "mail": "admin@example.com"}
response.user_groups = [{"cn": "CSEntry Admin"}]
elif username == "audit" and password == "auditpasswd":
response.status = AuthenticationResponseStatus.success
response.user_dn = "uid=audit,ou=Service accounts,dc=esss,dc=lu,dc=se"
response.user_info = {
"uid": ["audit"],
"cn": [],
"mail": [],
"dn": "uid=audit,ou=Service accounts,dc=esss,dc=lu,dc=se",
}
response.user_groups = [
{
"cn": ["CSEntry Auditor"],
"dn": "cn=CSEntry Auditor,ou=ICS,ou=Groups,dc=esss,dc=lu,dc=se",
}
]
elif username == "user_rw" and password == "userrw":
response.status = AuthenticationResponseStatus.success
response.user_info = {"cn": "User RW", "mail": "user_rw@example.com"}
......@@ -141,6 +200,14 @@ def patch_ldap_authenticate(monkeypatch):
response.status = AuthenticationResponseStatus.success
response.user_info = {"cn": "User RO", "mail": "user_ro@example.com"}
response.user_groups = [{"cn": "ESS Employees"}]
elif username == "user_prod" and password == "userprod":
response.status = AuthenticationResponseStatus.success
response.user_info = {"cn": "User Prod", "mail": "user_prod@example.com"}
response.user_groups = [{"cn": "CSEntry Prod"}]
elif username == "user_lab" and password == "userlab":
response.status = AuthenticationResponseStatus.success
response.user_info = {"cn": "User Lab", "mail": "user_lab@example.com"}
response.user_groups = [{"cn": "CSEntry Lab"}]
else:
response.status = AuthenticationResponseStatus.fail
return response
......
......@@ -85,6 +85,17 @@ class ItemFactory(factory.alchemy.SQLAlchemyModelFactory):
user = factory.SubFactory(UserFactory)
class ItemCommentFactory(factory.alchemy.SQLAlchemyModelFactory):
class Meta:
model = models.ItemComment
sqlalchemy_session = common.Session
sqlalchemy_session_persistence = "commit"
body = factory.Sequence(lambda n: f"comment{n}")
user = factory.SubFactory(UserFactory)
item = factory.SubFactory(ItemFactory)
class DomainFactory(factory.alchemy.SQLAlchemyModelFactory):
class Meta:
model = models.Domain
......@@ -104,7 +115,7 @@ class NetworkScopeFactory(factory.alchemy.SQLAlchemyModelFactory):
name = factory.Sequence(lambda n: f"scope{n}")
first_vlan = factory.Sequence(lambda n: 1600 + 10 * n)
last_vlan = factory.Sequence(lambda n: 1609 + 10 * n)
supernet = factory.Faker("ipv4", network=True)
supernet = factory.Sequence(lambda n: str(ipaddress.ip_network(f"172.{n}.0.0/16")))
user = factory.SubFactory(UserFactory)
domain = factory.SubFactory(DomainFactory)
......@@ -117,7 +128,6 @@ class NetworkFactory(factory.alchemy.SQLAlchemyModelFactory):
vlan_name = factory.Sequence(lambda n: f"vlan{n}")
vlan_id = factory.Sequence(lambda n: 1600 + n)
address = factory.Sequence(lambda n: f"192.168.{n}.0/24")
scope = factory.SubFactory(NetworkScopeFactory)
user = factory.SubFactory(UserFactory)
domain = factory.SubFactory(DomainFactory)
......@@ -134,19 +144,15 @@ class NetworkFactory(factory.alchemy.SQLAlchemyModelFactory):
hosts = list(net.hosts())
return str(hosts[-5])
@factory.lazy_attribute
def gateway(self):
net = ipaddress.ip_network(self.address)
hosts = list(net.hosts())
return str(hosts[-1])
class InterfaceFactory(factory.alchemy.SQLAlchemyModelFactory):
class Meta:
model = models.Interface
sqlalchemy_session = common.Session
sqlalchemy_session_persistence = "commit"
name = factory.Sequence(lambda n: f"host{n}")
network = factory.SubFactory(NetworkFactory)
ip = factory.LazyAttributeSequence(
lambda o, n: str(ipaddress.ip_address(o.network.first_ip) + n)
)
user = factory.SubFactory(UserFactory)
@factory.lazy_attribute
def address(self):
return self.scope.available_subnets(24)[0]
class DeviceTypeFactory(factory.alchemy.SQLAlchemyModelFactory):
......@@ -179,6 +185,22 @@ class HostFactory(factory.alchemy.SQLAlchemyModelFactory):
device_type = factory.SubFactory(DeviceTypeFactory)
class InterfaceFactory(factory.alchemy.SQLAlchemyModelFactory):
class Meta:
model = models.Interface
sqlalchemy_session = common.Session
sqlalchemy_session_persistence = "commit"
host = factory.SubFactory(HostFactory)
name = factory.LazyAttributeSequence(lambda o, n: f"{o.host.name}-{n}")
network = factory.SubFactory(NetworkFactory)
ip = factory.LazyAttributeSequence(
lambda o, n: str(ipaddress.ip_address(o.network.first_ip) + n)
)
mac = factory.Faker("mac_address")
user = factory.SubFactory(UserFactory)
class MacFactory(factory.alchemy.SQLAlchemyModelFactory):
class Meta:
model = models.Mac
......@@ -194,20 +216,11 @@ class CnameFactory(factory.alchemy.SQLAlchemyModelFactory):
sqlalchemy_session = common.Session
sqlalchemy_session_persistence = "commit"
name = factory.Sequence(lambda n: f"host{n}")
name = factory.Sequence(lambda n: f"cname{n}")
interface = factory.SubFactory(InterfaceFactory)
user = factory.SubFactory(UserFactory)
class TagFactory(factory.alchemy.SQLAlchemyModelFactory):
class Meta:
model = models.Tag
sqlalchemy_session = common.Session
sqlalchemy_session_persistence = "commit"
name = factory.Sequence(lambda n: f"Tag{n}")
class TaskFactory(factory.alchemy.SQLAlchemyModelFactory):
class Meta:
model = models.Task
......
......@@ -10,6 +10,7 @@ This module defines API tests.
"""
import datetime
import ipaddress
import json
import pytest
from app import models
......@@ -23,28 +24,84 @@ ENDPOINT_MODEL = {
"inventory/locations": models.Location,
"inventory/statuses": models.Status,
"inventory/items": models.Item,
"inventory/macs": models.Mac,
"network/networks": models.Network,
"network/interfaces": models.Interface,
"network/hosts": models.Host,
"network/groups": models.AnsibleGroup,
"network/macs": models.Mac,
"network/domains": models.Domain,
"network/cnames": models.Cname,
}
GENERIC_GET_ENDPOINTS = [
key
for key in ENDPOINT_MODEL.keys()
if key.startswith("inventory") and key != "inventory/items"
if key.startswith("inventory") and key not in ("inventory/items", "inventory/macs")
]
GENERIC_CREATE_ENDPOINTS = [
key
for key in ENDPOINT_MODEL.keys()
if key.startswith("inventory")
and key not in ("inventory/items", "inventory/actions")
and key not in ("inventory/items", "inventory/actions", "inventory/macs")
]
CREATE_AUTH_ENDPOINTS = [
key for key in ENDPOINT_MODEL.keys() if key != "inventory/actions"
]
HOST_KEYS = {
"id",
"name",
"fqdn",
"is_ioc",
"device_type",
"model",
"description",
"items",
"interfaces",
"ansible_vars",
"ansible_groups",
"created_at",
"updated_at",
"user",
"scope",
"sensitive",
}
INTERFACE_KEYS = {
"id",
"is_main",
"network",
"ip",
"netmask",
"name",
"description",
"mac",
"domain",
"host",
"device_type",
"model",
"cnames",
"created_at",
"updated_at",
"user",
}
NETWORK_KEYS = {
"id",
"vlan_name",
"vlan_id",
"address",
"netmask",
"broadcast",
"first_ip",
"last_ip",
"gateway",
"description",
"admin_only",
"sensitive",
"scope",
"domain",
"interfaces",
"created_at",
"updated_at",
"user",
}
def get(client, url, token=None):
......@@ -92,6 +149,28 @@ def get_token(client, username, password):
return response.get_json()["access_token"]
def create_hosts(number, host_factory, interface_factory, *args):
"""Helper function to create a number of hosts
:param number: number of hosts to create
:param host_factory: host_factory fixture
:param interface_factory: interface_factory fixture
:param *args: list of networks (host will have one interface per network)
"""
for i in range(number):
host = host_factory()
for n, network in enumerate(args):
ip = ipaddress.ip_address(network.first_ip) + i
if n == 0:
interface_factory(
name=host.name, host=host, network=network, ip=str(ip)
)
else:
interface_factory(
name=f"{host.name}-{n}", host=host, network=network, ip=str(ip)
)
@pytest.fixture()
def readonly_token(client):
return get_token(client, "user_ro", "userro")
......@@ -102,6 +181,11 @@ def user_token(client):
return get_token(client, "user_rw", "userrw")
@pytest.fixture()
def user_prod_token(client):
return get_token(client, "user_prod", "userprod")
@pytest.fixture()
def consultant_token(client):
return get_token(client, "consultant", "consultantpwd")
......@@ -112,6 +196,27 @@ def admin_token(client):
return get_token(client, "admin", "adminpasswd")
@pytest.fixture
def no_login_check_token(request, app):
app.config["LOGIN_DISABLED"] = True
client = app.test_client()
# We still need to login, otherwise an AnonymousUserMixin is returned
# An AnonymousUser doesn't have all the User methods
yield get_token(client, "user_ro", "userro")
app.config["LOGIN_DISABLED"] = False
@pytest.fixture
def network_192_168_1(network_scope_factory, network_factory):
scope = network_scope_factory(supernet="192.168.0.0/16")
return network_factory(
address="192.168.1.0/24",
first_ip="192.168.1.10",
last_ip="192.168.1.250",
scope=scope,
)
def check_response_message(response, msg, status_code=400):
assert response.status_code == status_code
try:
......@@ -142,6 +247,12 @@ def check_input_is_subset_of_response(response, inputs):
assert d2[key] == value
def check_delete_success(clt, token, instance, endpoint, model):
response = delete(clt, f"{API_URL}/{endpoint}/{instance.id}", token=token)
assert response.status_code == 204
assert len(model.query.all()) == 0
def test_login(client):
response = client.post(f"{API_URL}/user/login")
check_response_message(response, "Body should be a JSON object")
......@@ -192,7 +303,7 @@ def test_create_model_auth_fail(endpoint, client, readonly_token):
@pytest.mark.parametrize("endpoint", GENERIC_CREATE_ENDPOINTS)
def test_create_generic_model(endpoint, client, user_token):
response = post(client, f"{API_URL}/{endpoint}", data={}, token=user_token)
check_response_message(response, "Missing mandatory field 'name'", 422)
check_response_message(response, "At least one field is required", 422)
data = {"name": "Foo"}
response = post(client, f"{API_URL}/{endpoint}", data=data, token=user_token)
assert response.status_code == 201
......@@ -201,7 +312,7 @@ def test_create_generic_model(endpoint, client, user_token):
response = post(client, f"{API_URL}/{endpoint}", data=data, token=user_token)
check_response_message(
response,
"(psycopg2.IntegrityError) duplicate key value violates unique constraint",
"(psycopg2.errors.UniqueViolation) duplicate key value violates unique constraint",
422,
)
response = post(
......@@ -209,7 +320,7 @@ def test_create_generic_model(endpoint, client, user_token):
)
check_response_message(
response,
"(psycopg2.IntegrityError) duplicate key value violates unique constraint",
"(psycopg2.errors.UniqueViolation) duplicate key value violates unique constraint",
422,
)
response = post(
......@@ -217,7 +328,7 @@ def test_create_generic_model(endpoint, client, user_token):
)
check_response_message(
response,
"(psycopg2.IntegrityError) duplicate key value violates unique constraint",
"(psycopg2.errors.UniqueViolation) duplicate key value violates unique constraint",
422,
)
data = {"name": "Bar", "description": "Bar description"}
......@@ -247,7 +358,7 @@ def test_create_generic_model_invalid_param(endpoint, client, user_token):
def test_create_item(client, user_token):
# check that serial_number is mandatory
response = post(client, f"{API_URL}/inventory/items", data={}, token=user_token)
check_response_message(response, "Missing mandatory field 'serial_number'", 422)
check_response_message(response, "At least one field is required", 422)
# check create with only serial_number
data = {"serial_number": "123456"}
......@@ -286,7 +397,7 @@ def test_create_item(client, user_token):
response = post(client, f"{API_URL}/inventory/items", data=data2, token=user_token)
check_response_message(
response,
"(psycopg2.IntegrityError) duplicate key value violates unique constraint",
"(psycopg2.errors.UniqueViolation) duplicate key value violates unique constraint",
422,
)
......@@ -432,7 +543,7 @@ def test_patch_item_integrity_error(client, user_token, item_factory):
)
check_response_message(
response,
"(psycopg2.IntegrityError) duplicate key value violates unique constraint",
"(psycopg2.errors.UniqueViolation) duplicate key value violates unique constraint",
422,
)
......@@ -567,19 +678,35 @@ def test_get_items(client, location_factory, item_factory, readonly_token):
check_response_message(response, "Invalid query arguments", 422)
def test_get_networks(client, network_factory, readonly_token):
@pytest.mark.parametrize("url", ["/network/networks", "/network/scopes"])
def test_get_restricted_url_as_non_admin(url, client, readonly_token):
response = get(client, f"{API_URL}{url}", token=readonly_token)
assert response.status_code == 403
def test_get_networks(client, network_scope_factory, network_factory, admin_token):
# Create some networks
scope = network_scope_factory(supernet="172.16.0.0/16")
network1 = network_factory(
address="172.16.1.0/24", first_ip="172.16.1.1", last_ip="172.16.1.254"
address="172.16.1.0/24",
first_ip="172.16.1.1",
last_ip="172.16.1.254",
scope=scope,
)
network2 = network_factory(
address="172.16.20.0/22", first_ip="172.16.20.11", last_ip="172.16.20.250"
address="172.16.20.0/22",
first_ip="172.16.20.11",
last_ip="172.16.20.250",
scope=scope,
)
network3 = network_factory(
address="172.16.5.0/24", first_ip="172.16.5.10", last_ip="172.16.5.254"
address="172.16.5.0/24",
first_ip="172.16.5.10",
last_ip="172.16.5.254",
scope=scope,
)
response = get(client, f"{API_URL}/network/networks", token=readonly_token)
response = get(client, f"{API_URL}/network/networks", token=admin_token)
assert response.status_code == 200
assert len(response.get_json()) == 3
check_input_is_subset_of_response(
......@@ -588,15 +715,92 @@ def test_get_networks(client, network_factory, readonly_token):
# test filtering by address
response = get(
client,
f"{API_URL}/network/networks?address=172.16.20.0/22",
token=readonly_token,
client, f"{API_URL}/network/networks?address=172.16.20.0/22", token=admin_token
)
assert response.status_code == 200
assert len(response.get_json()) == 1
check_input_is_subset_of_response(response, (network2.to_dict(),))
def test_get_networks_normal_user(
client, network_scope_factory, network_factory, user_token
):
# ProdNetworks scope not available to user - can't see sensitive networks
prod_scope = network_scope_factory(name="ProdNetworks", supernet="172.16.0.0/16")
network_factory(
vlan_name="network-prod",
address="172.16.3.0/24",
first_ip="172.16.3.10",
last_ip="172.16.3.250",
scope=prod_scope,
)
network_factory(
vlan_name="network-prod-sensitive",
address="172.16.4.0/24",
first_ip="172.16.4.10",
last_ip="172.16.4.250",
scope=prod_scope,
sensitive=True,
)
# User can view all networks on FooNetworks scopes (including sensitive)
foo_scope = network_scope_factory(name="FooNetworks", supernet="192.168.0.0/16")
network_factory(
vlan_name="admin-network",
address="192.168.1.0/24",
first_ip="192.168.1.10",
last_ip="192.168.1.250",
admin_only=True,
scope=foo_scope,
)
network_factory(
vlan_name="sensitive-network",
address="192.168.2.0/24",
first_ip="192.168.2.10",
last_ip="192.168.2.250",
admin_only=False,
sensitive=True,
scope=foo_scope,
)
# Create 30 networks
for n in range(3, 33):
network_factory(
vlan_name=f"network{n}",
address=f"192.168.{n}.0/24",
first_ip=f"192.168.{n}.10",
last_ip=f"192.168.{n}.250",
scope=foo_scope,
)
url = f"{API_URL}/network/networks"
response = get(client, url, token=user_token)
assert response.status_code == 200
assert response.headers["x-total-count"] == "33"
networks1 = response.get_json()
assert len(networks1) == 20
assert (
f'{url}?per_page=20&page=2&recursive=False>; rel="next",'
in response.headers["link"]
)
# Get second page
response = get(
client,
f"{url}?per_page=20&page=2&recursive=False",
token=user_token,
)
assert response.status_code == 200
networks2 = response.get_json()
assert len(networks2) == 13
# Check that the sensitive prod network isn't part of the result
retrieved_networks = sorted(
[network["vlan_name"] for network in networks1 + networks2]
)
expected = sorted(
[f"network{n}" for n in range(3, 33)]
+ ["network-prod", "admin-network", "sensitive-network"]
)
assert "network-prod-sensitive" not in set(retrieved_networks)
assert retrieved_networks == expected
def test_create_network_auth_fail(client, session, user_token):
# admin is required to create networks
response = post(client, f"{API_URL}/network/networks", data={}, token=user_token)
......@@ -605,9 +809,9 @@ def test_create_network_auth_fail(client, session, user_token):
def test_create_network(client, admin_token, network_scope_factory):
scope = network_scope_factory(supernet="172.16.0.0/16")
# check that vlan_name, vlan_id, address, first_ip, last_ip and scope are mandatory
# check that vlan_name, vlan_id, address, first_ip, last_ip, gateway and scope are mandatory
response = post(client, f"{API_URL}/network/networks", data={}, token=admin_token)
check_response_message(response, "Missing mandatory field 'vlan_name'", 422)
check_response_message(response, "At least one field is required", 422)
response = post(
client,
f"{API_URL}/network/networks",
......@@ -648,6 +852,20 @@ def test_create_network(client, admin_token, network_scope_factory):
token=admin_token,
)
check_response_message(response, "Missing mandatory field 'last_ip'", 422)
response = post(
client,
f"{API_URL}/network/networks",
data={
"vlan_name": "network1",
"vlan_id": 1600,
"address": "172.16.1.0/24",
"first_ip": "172.16.1.10",
"last_ip": "172.16.1.250",
"scope": scope.name,
},
token=admin_token,
)
check_response_message(response, "Missing mandatory field 'gateway'", 422)
data = {
"vlan_name": "network1",
......@@ -655,50 +873,29 @@ def test_create_network(client, admin_token, network_scope_factory):
"address": "172.16.1.0/24",
"first_ip": "172.16.1.10",
"last_ip": "172.16.1.250",
"gateway": "172.16.1.254",
"scope": scope.name,
}
response = post(client, f"{API_URL}/network/networks", data=data, token=admin_token)
assert response.status_code == 201
assert {
"id",
"vlan_name",
"vlan_id",
"address",
"netmask",
"first_ip",
"last_ip",
"description",
"admin_only",
"scope",
"domain",
"interfaces",
"created_at",
"updated_at",
"user",
} == set(response.get_json().keys())
assert NETWORK_KEYS == set(response.get_json().keys())
assert response.get_json()["vlan_name"] == "network1"
assert response.get_json()["vlan_id"] == 1600
assert response.get_json()["address"] == "172.16.1.0/24"
assert response.get_json()["first_ip"] == "172.16.1.10"
assert response.get_json()["last_ip"] == "172.16.1.250"
assert response.get_json()["gateway"] == "172.16.1.254"
assert response.get_json()["netmask"] == "255.255.255.0"
assert response.get_json()["broadcast"] == "172.16.1.255"
# Check that address and name shall be unique
response = post(client, f"{API_URL}/network/networks", data=data, token=admin_token)
check_response_message(
response,
"(psycopg2.IntegrityError) duplicate key value violates unique constraint",
422,
)
data_same_address = data.copy()
data_same_address["vlan_name"] = "networkX"
response = post(
client, f"{API_URL}/network/networks", data=data_same_address, token=admin_token
)
check_response_message(
response,
"(psycopg2.IntegrityError) duplicate key value violates unique constraint",
422,
response, "172.16.1.0/24 overlaps network1 (172.16.1.0/24)", 422
)
data_same_name = {
"vlan_name": "network1",
......@@ -706,6 +903,7 @@ def test_create_network(client, admin_token, network_scope_factory):
"address": "172.16.2.0/24",
"first_ip": "172.16.2.10",
"last_ip": "172.16.2.250",
"gateway": "172.16.2.254",
"scope": scope.name,
}
response = post(
......@@ -713,7 +911,7 @@ def test_create_network(client, admin_token, network_scope_factory):
)
check_response_message(
response,
"(psycopg2.IntegrityError) duplicate key value violates unique constraint",
"(psycopg2.errors.UniqueViolation) duplicate key value violates unique constraint",
422,
)
......@@ -724,6 +922,7 @@ def test_create_network(client, admin_token, network_scope_factory):
"address": "172.16.5.0/24",
"first_ip": "172.16.5.11",
"last_ip": "172.16.5.250",
"gateway": "172.16.5.254",
"description": "long description",
"scope": scope.name,
}
......@@ -745,6 +944,7 @@ def test_create_network_invalid_address(client, admin_token, network_scope):
"address": "foo",
"first_ip": "172.16.1.10",
"last_ip": "172.16.1.250",
"gateway": "172.16.1.254",
"scope": network_scope.name,
}
response = post(client, f"{API_URL}/network/networks", data=data, token=admin_token)
......@@ -773,6 +973,7 @@ def test_create_network_invalid_ip(
"address": "192.168.0.0/24",
"first_ip": address,
"last_ip": "192.168.0.250",
"gateway": "192.168.0.254",
"scope": network_scope.name,
}
response = post(client, f"{API_URL}/network/networks", data=data, token=admin_token)
......@@ -786,6 +987,7 @@ def test_create_network_invalid_ip(
"address": "192.168.0.0/24",
"first_ip": "192.168.0.250",
"last_ip": address,
"gateway": "192.168.0.254",
"scope": network_scope.name,
}
response = post(client, f"{API_URL}/network/networks", data=data, token=admin_token)
......@@ -802,6 +1004,7 @@ def test_create_network_invalid_range(client, session, admin_token, network_scop
"address": "172.16.1.0/24",
"first_ip": "172.16.2.10",
"last_ip": "172.16.1.250",
"gateway": "172.16.1.254",
"scope": network_scope.name,
}
response = post(client, f"{API_URL}/network/networks", data=data, token=admin_token)
......@@ -815,6 +1018,7 @@ def test_create_network_invalid_range(client, session, admin_token, network_scop
"address": "172.16.1.0/24",
"first_ip": "172.16.1.10",
"last_ip": "172.16.5.250",
"gateway": "172.16.1.1",
"scope": network_scope.name,
}
response = post(client, f"{API_URL}/network/networks", data=data, token=admin_token)
......@@ -828,6 +1032,7 @@ def test_create_network_invalid_range(client, session, admin_token, network_scop
"address": "172.16.1.0/24",
"first_ip": "172.16.1.10",
"last_ip": "172.16.1.9",
"gateway": "172.16.1.1",
"scope": network_scope.name,
}
response = post(client, f"{API_URL}/network/networks", data=data, token=admin_token)
......@@ -838,17 +1043,26 @@ def test_create_network_invalid_range(client, session, admin_token, network_scop
)
def test_get_interfaces(client, network_factory, interface_factory, readonly_token):
def test_get_interfaces(
client, network_scope_factory, network_factory, interface_factory, readonly_token
):
# Create some interfaces
scope = network_scope_factory(supernet="192.168.0.0/16")
network1 = network_factory(
address="192.168.1.0/24", first_ip="192.168.1.10", last_ip="192.168.1.250"
address="192.168.1.0/24",
first_ip="192.168.1.10",
last_ip="192.168.1.250",
scope=scope,
)
network2 = network_factory(
address="192.168.2.0/24", first_ip="192.168.2.10", last_ip="192.168.2.250"
address="192.168.2.0/24",
first_ip="192.168.2.10",
last_ip="192.168.2.250",
scope=scope,
)
interface1 = interface_factory(network=network1, ip="192.168.1.10")
interface2 = interface_factory(
network=network1, ip="192.168.1.11", name="interface2"
network=network1, ip="192.168.1.11", host=interface1.host
)
interface3 = interface_factory(network=network2, ip="192.168.2.10")
......@@ -871,8 +1085,14 @@ def test_get_interfaces(client, network_factory, interface_factory, readonly_tok
def test_get_interfaces_by_domain(
client, domain_factory, network_factory, interface_factory, readonly_token
client,
domain_factory,
network_scope_factory,
network_factory,
interface_factory,
readonly_token,
):
scope = network_scope_factory(supernet="192.168.0.0/16")
# Create some interfaces
domain1 = domain_factory(name="tn.esss.lu.se")
domain2 = domain_factory(name="ics.esss.lu.se")
......@@ -881,17 +1101,17 @@ def test_get_interfaces_by_domain(
first_ip="192.168.1.10",
last_ip="192.168.1.250",
domain=domain1,
scope=scope,
)
network2 = network_factory(
address="192.168.2.0/24",
first_ip="192.168.2.10",
last_ip="192.168.2.250",
domain=domain2,
scope=scope,
)
interface1 = interface_factory(network=network1, ip="192.168.1.10")
interface2 = interface_factory(
network=network1, ip="192.168.1.11", name="interface2"
)
interface2 = interface_factory(network=network1, ip="192.168.1.11")
interface3 = interface_factory(network=network2, ip="192.168.2.10")
# test filtering by domain
......@@ -917,25 +1137,26 @@ def test_get_interfaces_by_domain(
def test_get_interfaces_by_network(
client, network_factory, interface_factory, readonly_token
client, network_scope_factory, network_factory, interface_factory, readonly_token
):
scope = network_scope_factory(supernet="192.168.0.0/16")
# Create some interfaces
network1 = network_factory(
vlan_name="MyNetwork1",
address="192.168.1.0/24",
first_ip="192.168.1.10",
last_ip="192.168.1.250",
scope=scope,
)
network2 = network_factory(
vlan_name="MyNetwork2",
address="192.168.2.0/24",
first_ip="192.168.2.10",
last_ip="192.168.2.250",
scope=scope,
)
interface1 = interface_factory(network=network1, ip="192.168.1.10")
interface2 = interface_factory(
network=network1, ip="192.168.1.11", name="interface2"
)
interface2 = interface_factory(network=network1, ip="192.168.1.11")
interface3 = interface_factory(network=network2, ip="192.168.2.10")
# test filtering by network name
......@@ -962,116 +1183,196 @@ def test_get_interfaces_with_model(
host1 = host_factory()
model1 = model_factory(name="EX3400")
item_factory(model=model1, host_id=host1.id)
interface_factory(host_id=host1.id)
interface_factory(host=host1)
response = get(client, f"{API_URL}/network/interfaces", token=readonly_token)
assert response.get_json()[0]["model"] == "EX3400"
def test_create_interface(client, network_factory, user_token):
network = network_factory(
address="192.168.1.0/24", first_ip="192.168.1.10", last_ip="192.168.1.250"
@pytest.mark.parametrize(
"username, password, expected_interfaces, expected_host3_result",
[
("user_ro", "userro", ["host2", "host1"], 0),
("user_prod", "userprod", ["host2", "host1", "host3"], 1),
("admin", "adminpasswd", ["host2", "host1", "host3"], 1),
("audit", "auditpasswd", ["host2", "host1", "host3"], 1),
],
)
def test_get_interfaces_with_sensitive_network(
client,
network_scope_factory,
network_factory,
interface_factory,
host_factory,
username,
password,
expected_interfaces,
expected_host3_result,
):
# Create some interfaces
scope = network_scope_factory(supernet="192.168.0.0/16", name="ProdNetworks")
network1 = network_factory(
address="192.168.1.0/24",
first_ip="192.168.1.10",
last_ip="192.168.1.250",
scope=scope,
)
network2 = network_factory(
address="192.168.2.0/24",
first_ip="192.168.2.10",
last_ip="192.168.2.250",
scope=scope,
sensitive=True,
)
host1 = host_factory(name="host1")
host2 = host_factory(name="host2")
host3 = host_factory(name="host3")
interface_factory(name=host1.name, network=network1, ip="192.168.1.15", host=host1)
interface_factory(name=host2.name, network=network1, ip="192.168.1.10", host=host2)
interface_factory(name=host3.name, network=network2, ip="192.168.2.10", host=host3)
# Retrieve interfaces
token = get_token(client, username, password)
response = get(client, f"{API_URL}/network/interfaces", token=token)
assert response.status_code == 200
interfaces = [interface["name"] for interface in response.get_json()]
assert interfaces == expected_interfaces
response = get(client, f"{API_URL}/network/interfaces?name=host1", token=token)
assert response.status_code == 200
assert len(response.get_json()) == 1
response = get(client, f"{API_URL}/network/interfaces?name=host3", token=token)
assert response.status_code == 200
assert len(response.get_json()) == expected_host3_result
def test_create_interface_fails(client, host, network_192_168_1, no_login_check_token):
# check that network_id and ip are mandatory
response = post(client, f"{API_URL}/network/interfaces", data={}, token=user_token)
check_response_message(response, "Missing mandatory field 'network'", 422)
response = post(
client,
f"{API_URL}/network/interfaces",
data={"ip": "192.168.1.20"},
token=user_token,
client, f"{API_URL}/network/interfaces", data={}, token=no_login_check_token
)
check_response_message(response, "Missing mandatory field 'network'", 422)
check_response_message(response, "At least one field is required", 422)
response = post(
client,
f"{API_URL}/network/interfaces",
data={"network": network.address},
token=user_token,
data={"ip": "192.168.1.20"},
token=no_login_check_token,
)
check_response_message(response, "Missing mandatory field 'ip'", 422)
check_response_message(response, "Missing mandatory field 'network'", 422)
data = {"network": network.vlan_name, "ip": "192.168.1.20", "name": "interface1"}
data = {
"network": network_192_168_1.vlan_name,
"ip": "192.168.1.20",
"name": "interface1",
}
response = post(
client, f"{API_URL}/network/interfaces", data=data, token=user_token
client, f"{API_URL}/network/interfaces", data=data, token=no_login_check_token
)
assert response.status_code == 201
assert {
"id",
"network",
"ip",
"name",
"mac",
"domain",
"host",
"device_type",
"model",
"cnames",
"tags",
"created_at",
"updated_at",
"user",
} == set(response.get_json().keys())
assert response.get_json()["network"] == network.vlan_name
assert response.get_json()["ip"] == "192.168.1.20"
assert response.get_json()["name"] == "interface1"
check_response_message(response, "Missing mandatory field 'host'", 422)
# Check that IP and name shall be unique
data["host"] = host.name
response = post(
client, f"{API_URL}/network/interfaces", data=data, token=user_token
client, f"{API_URL}/network/interfaces", data=data, token=no_login_check_token
)
check_response_message(
response,
"(psycopg2.IntegrityError) duplicate key value violates unique constraint",
422,
response, f"Interface name shall start with the host name '{host.name}'", 422
)
def test_create_interface(client, host, network_192_168_1, no_login_check_token):
data = {
"network": network_192_168_1.vlan_name,
"ip": "192.168.1.20",
"name": host.name,
"host": host.name,
}
response = post(
client, f"{API_URL}/network/interfaces", data=data, token=no_login_check_token
)
assert response.status_code == 201
assert INTERFACE_KEYS == set(response.get_json().keys())
assert response.get_json()["network"] == network_192_168_1.vlan_name
assert response.get_json()["ip"] == "192.168.1.20"
assert response.get_json()["name"] == host.name
# This is the main interface
assert response.get_json()["is_main"]
# Check that all parameters can be passed
data2 = {"network": network.vlan_name, "ip": "192.168.1.21", "name": "myhostname"}
data2 = {
"network": network_192_168_1.vlan_name,
"ip": "192.168.1.21",
"name": host.name + "-2",
"description": "The second interface",
"host": host.name,
"mac": "7c:e2:ca:64:d0:68",
}
response = post(
client, f"{API_URL}/network/interfaces", data=data2, token=user_token
client, f"{API_URL}/network/interfaces", data=data2, token=no_login_check_token
)
assert response.status_code == 201
# This is not the main interface
assert not response.get_json()["is_main"]
# check all items that were created
assert models.Interface.query.count() == 2
# Check that IP and name shall be unique
response = post(
client, f"{API_URL}/network/interfaces", data=data, token=no_login_check_token
)
check_response_message(
response,
"(psycopg2.errors.UniqueViolation) duplicate key value violates unique constraint",
422,
)
@pytest.mark.parametrize("ip", ("", "foo", "192.168"))
def test_create_interface_invalid_ip(ip, client, network_factory, user_token):
network = network_factory(
address="192.168.1.0/24", first_ip="192.168.1.10", last_ip="192.168.1.250"
)
def test_create_interface_invalid_ip(
ip, client, host, network_192_168_1, no_login_check_token
):
# invalid IP address
data = {"network": network.vlan_name, "ip": ip, "name": "hostname"}
data = {
"network": network_192_168_1.vlan_name,
"ip": ip,
"name": host.name,
"host": host.name,
}
response = post(
client, f"{API_URL}/network/interfaces", data=data, token=user_token
client, f"{API_URL}/network/interfaces", data=data, token=no_login_check_token
)
check_response_message(
response, f"'{ip}' does not appear to be an IPv4 or IPv6 address", 422
)
def test_create_interface_ip_not_in_network(client, network_factory, user_token):
network = network_factory(
address="192.168.1.0/24", first_ip="192.168.1.10", last_ip="192.168.1.250"
)
def test_create_interface_ip_not_in_network(
client, host, network_192_168_1, no_login_check_token
):
# IP address not in range
data = {"network": network.vlan_name, "ip": "192.168.2.4", "name": "hostname"}
data = {
"network": network_192_168_1.vlan_name,
"ip": "192.168.2.4",
"name": host.name,
"host": host.name,
}
response = post(
client, f"{API_URL}/network/interfaces", data=data, token=user_token
client, f"{API_URL}/network/interfaces", data=data, token=no_login_check_token
)
check_response_message(
response, "IP address 192.168.2.4 is not in network 192.168.1.0/24", 422
)
def test_create_interface_ip_not_in_range(client, network_factory, user_token):
network = network_factory(
address="192.168.1.0/24", first_ip="192.168.1.10", last_ip="192.168.1.250"
)
def test_create_interface_ip_not_in_range(
client, host, network_192_168_1, no_login_check_token
):
# IP address not in range
data = {"network": network.vlan_name, "ip": "192.168.1.4", "name": "hostname"}
data = {
"network": network_192_168_1.vlan_name,
"ip": "192.168.1.4",
"name": host.name,
"host": host.name,
}
response = post(
client, f"{API_URL}/network/interfaces", data=data, token=user_token
client, f"{API_URL}/network/interfaces", data=data, token=no_login_check_token
)
check_response_message(
response,
......@@ -1081,35 +1382,98 @@ def test_create_interface_ip_not_in_range(client, network_factory, user_token):
def test_create_interface_ip_not_in_range_as_admin(
client, network_factory, admin_token
client, host, network_192_168_1, admin_token
):
network = network_factory(
address="192.168.1.0/24", first_ip="192.168.1.10", last_ip="192.168.1.250"
)
# IP address not in range
data = {"network": network.vlan_name, "ip": "192.168.1.4", "name": "hostname"}
data = {
"network": network_192_168_1.vlan_name,
"ip": "192.168.1.4",
"name": host.name,
"host": host.name,
}
response = post(
client, f"{API_URL}/network/interfaces", data=data, token=admin_token
)
assert response.status_code == 201
def test_delete_interface_invalid_credentials(client, interface_factory, user_token):
interface1 = interface_factory()
response = delete(
client, f"{API_URL}/network/interfaces/{interface1.id}", token=user_token
def test_normal_user_can_create_interface_on_empty_host(
client, host, network_scope_factory, network_factory, user_prod_token
):
scope = network_scope_factory(name="ProdNetworks", supernet="192.168.0.0/16")
network = network_factory(
address="192.168.1.0/24",
first_ip="192.168.1.10",
last_ip="192.168.1.250",
scope=scope,
)
assert response.status_code == 403
assert len(models.Interface.query.all()) == 1
data = {
"network": network.vlan_name,
"ip": "192.168.1.20",
"name": host.name,
"host": host.name,
}
response = post(
client, f"{API_URL}/network/interfaces", data=data, token=user_prod_token
)
assert response.status_code == 201
assert response.get_json()["network"] == network.vlan_name
assert response.get_json()["ip"] == "192.168.1.20"
assert response.get_json()["name"] == host.name
# This is the main interface
assert response.get_json()["is_main"]
def test_create_interface_with_ip(
client, host, network_192_168_1, no_login_check_token
):
data = {
"network": network_192_168_1.vlan_name,
"name": host.name,
"host": host.name,
"ip": "192.168.1.12",
}
response = post(
client, f"{API_URL}/network/interfaces", data=data, token=no_login_check_token
)
assert response.status_code == 201
assert response.get_json()["ip"] == "192.168.1.12"
def test_create_interface_without_ip(
client, host, network_192_168_1, no_login_check_token
):
data = {
"network": network_192_168_1.vlan_name,
"name": host.name,
"host": host.name,
}
response = post(
client, f"{API_URL}/network/interfaces", data=data, token=no_login_check_token
)
assert response.status_code == 201
assert response.get_json()["ip"] == "192.168.1.10"
def test_delete_interface_success(client, interface_factory, admin_token):
def test_delete_interface_normal_user(client, interface_factory, user_token):
interface1 = interface_factory()
response = delete(
client, f"{API_URL}/network/interfaces/{interface1.id}", token=admin_token
check_delete_success(
client,
user_token,
interface1,
"network/interfaces",
models.Interface,
)
def test_delete_interface_success(client, admin_token, interface):
check_delete_success(
client,
admin_token,
interface,
"network/interfaces",
models.Interface,
)
assert response.status_code == 204
assert len(models.Interface.query.all()) == 0
def test_delete_interface_invalid_id(client, interface_factory, admin_token):
......@@ -1127,7 +1491,7 @@ def test_get_macs(client, mac_factory, readonly_token):
mac1 = mac_factory()
mac2 = mac_factory()
response = get(client, f"{API_URL}/network/macs", token=readonly_token)
response = get(client, f"{API_URL}/inventory/macs", token=readonly_token)
assert response.status_code == 200
assert len(response.get_json()) == 2
check_input_is_subset_of_response(response, (mac1.to_dict(), mac2.to_dict()))
......@@ -1136,26 +1500,26 @@ def test_get_macs(client, mac_factory, readonly_token):
def test_create_mac(client, item_factory, user_token):
item = item_factory()
# check that address is mandatory
response = post(client, f"{API_URL}/network/macs", data={}, token=user_token)
check_response_message(response, "Missing mandatory field 'address'", 422)
response = post(client, f"{API_URL}/inventory/macs", data={}, token=user_token)
check_response_message(response, "At least one field is required", 422)
data = {"address": "b5:4b:7d:a4:23:43"}
response = post(client, f"{API_URL}/network/macs", data=data, token=user_token)
response = post(client, f"{API_URL}/inventory/macs", data=data, token=user_token)
assert response.status_code == 201
assert {"id", "address", "item", "interfaces"} == set(response.get_json().keys())
assert {"id", "address", "item"} == set(response.get_json().keys())
assert response.get_json()["address"] == data["address"]
# Check that address shall be unique
response = post(client, f"{API_URL}/network/macs", data=data, token=user_token)
response = post(client, f"{API_URL}/inventory/macs", data=data, token=user_token)
check_response_message(
response,
"(psycopg2.IntegrityError) duplicate key value violates unique constraint",
"(psycopg2.errors.UniqueViolation) duplicate key value violates unique constraint",
422,
)
# Check that all parameters can be passed
data2 = {"address": "b5:4b:7d:a4:23:44", "item_id": item.id}
response = post(client, f"{API_URL}/network/macs", data=data2, token=user_token)
response = post(client, f"{API_URL}/inventory/macs", data=data2, token=user_token)
assert response.status_code == 201
# check that all items were created
......@@ -1165,7 +1529,7 @@ def test_create_mac(client, item_factory, user_token):
@pytest.mark.parametrize("address", ("", "foo", "b5:4b:7d:a4:23"))
def test_create_mac_invalid_address(address, client, user_token):
data = {"address": address}
response = post(client, f"{API_URL}/network/macs", data=data, token=user_token)
response = post(client, f"{API_URL}/inventory/macs", data=data, token=user_token)
check_response_message(
response, f"'{address}' does not appear to be a MAC address", 422
)
......@@ -1184,7 +1548,7 @@ def test_get_ansible_groups(client, ansible_group_factory, readonly_token):
def test_create_ansible_group(client, admin_token):
# check that name is mandatory
response = post(client, f"{API_URL}/network/groups", data={}, token=admin_token)
check_response_message(response, "Missing mandatory field 'name'", 422)
check_response_message(response, "At least one field is required", 422)
data = {"name": "mygroup"}
response = post(client, f"{API_URL}/network/groups", data=data, token=admin_token)
......@@ -1206,7 +1570,7 @@ def test_create_ansible_group(client, admin_token):
response = post(client, f"{API_URL}/network/groups", data=data, token=admin_token)
check_response_message(
response,
"(psycopg2.IntegrityError) duplicate key value violates unique constraint",
"Group name matches an existing group",
422,
)
......@@ -1220,83 +1584,211 @@ def test_create_ansible_group_with_vars(client, admin_token):
assert group.vars == data["vars"]
def test_get_hosts(client, host_factory, readonly_token):
def test_get_hosts(client, host_factory, admin_token):
# Create some hosts
host1 = host_factory()
host2 = host_factory()
response = get(client, f"{API_URL}/network/hosts", token=readonly_token)
response = get(client, f"{API_URL}/network/hosts", token=admin_token)
assert response.status_code == 200
assert len(response.get_json()) == 2
assert HOST_KEYS == set(response.get_json()[0].keys())
check_input_is_subset_of_response(response, (host1.to_dict(), host2.to_dict()))
def test_get_hosts_with_ansible_vars(client, host_factory, readonly_token):
def test_get_hosts_with_ansible_vars(client, host_factory, admin_token):
vars = {"foo": "hello", "mylist": [1, 2, 3]}
host_factory(ansible_vars=vars)
response = get(client, f"{API_URL}/network/hosts", token=readonly_token)
response = get(client, f"{API_URL}/network/hosts", token=admin_token)
assert response.status_code == 200
assert response.get_json()[0]["ansible_vars"] == vars
def test_get_hosts_with_model(
client, model_factory, item_factory, host_factory, readonly_token
client, model_factory, item_factory, host_factory, admin_token
):
host1 = host_factory()
model1 = model_factory(name="EX3400")
item_factory(model=model1, host_id=host1.id)
response = get(client, f"{API_URL}/network/hosts", token=readonly_token)
response = get(client, f"{API_URL}/network/hosts", token=admin_token)
assert response.status_code == 200
assert response.get_json()[0]["model"] == "EX3400"
def test_get_hosts_with_no_model(client, host_factory, readonly_token):
def test_get_hosts_with_no_model(client, host_factory, admin_token):
host_factory()
response = get(client, f"{API_URL}/network/hosts", token=readonly_token)
response = get(client, f"{API_URL}/network/hosts", token=admin_token)
assert response.status_code == 200
assert response.get_json()[0]["model"] is None
def test_create_host(client, device_type_factory, user_token):
device_type = device_type_factory(name="Virtual")
# check that name and device_type are mandatory
response = post(client, f"{API_URL}/network/hosts", data={}, token=user_token)
check_response_message(response, "Missing mandatory field 'name'", 422)
response = post(
client, f"{API_URL}/network/hosts", data={"name": "myhost"}, token=user_token
)
check_response_message(response, "Missing mandatory field 'device_type'", 422)
response = post(
client,
f"{API_URL}/network/hosts",
data={"device_type": "Physical"},
token=user_token,
def test_get_hosts_recursive_interfaces(
client, host_factory, interface_factory, readonly_token
):
# Create some hosts with interfaces
host1 = host_factory()
interface11 = interface_factory(name=host1.name, host=host1)
interface12 = interface_factory(host=host1)
host2 = host_factory()
interface21 = interface_factory(host=host2)
# Without recursive, we only get the name of the interfaces
response = get(client, f"{API_URL}/network/hosts", token=readonly_token)
assert response.status_code == 200
assert len(response.get_json()) == 2
rhost1, rhost2 = response.get_json()
# Interfaces shall be sorted
assert rhost1["interfaces"] == sorted([interface11.name, interface12.name])
assert rhost2["interfaces"] == [interface21.name]
# With recursive, interfaces are expanded
response = get(
client, f"{API_URL}/network/hosts?recursive=true", token=readonly_token
)
check_response_message(response, "Missing mandatory field 'name'", 422)
assert response.status_code == 200
assert len(response.get_json()) == 2
rhost1, rhost2 = response.get_json()
assert len(rhost1["interfaces"]) == 2
rinterface11, rinterface12 = rhost1["interfaces"]
assert INTERFACE_KEYS == set(rinterface11.keys())
assert INTERFACE_KEYS == set(rinterface12.keys())
assert len(rhost2["interfaces"]) == 1
rinterface21 = rhost2["interfaces"][0]
assert INTERFACE_KEYS == set(rinterface21.keys())
assert rinterface21["network"] == interface21.network.vlan_name
def test_get_hosts_recursive_items(client, item_factory, host_factory, admin_token):
host1 = host_factory()
item11 = item_factory(ics_id="AAA001", host_id=host1.id, stack_member=1)
item12 = item_factory(ics_id="AAA002", host_id=host1.id, stack_member=0)
host2 = host_factory()
item21 = item_factory(ics_id="AAB001", host_id=host2.id)
item22 = item_factory(ics_id="AAB002", host_id=host2.id)
# Without recursive, we only get the ics_id of the items
response = get(client, f"{API_URL}/network/hosts", token=admin_token)
assert response.status_code == 200
assert len(response.get_json()) == 2
rhost1, rhost2 = response.get_json()
# items are sorted by stack_member
assert rhost1["items"] == ["AAA002", "AAA001"]
# or by ics_id when stack_member is None
assert rhost2["items"] == ["AAB001", "AAB002"]
# With recursive, items are expanded
response = get(client, f"{API_URL}/network/hosts?recursive=true", token=admin_token)
assert response.status_code == 200
assert len(response.get_json()) == 2
rhost1, rhost2 = response.get_json()
assert len(rhost1["items"]) == 2
# Items shall be sorted by stack_member
ritem11, ritem12 = rhost1["items"]
assert ritem11 == {
"ics_id": item12.ics_id,
"serial_number": item12.serial_number,
"stack_member": 0,
}
assert ritem12 == {
"ics_id": item11.ics_id,
"serial_number": item11.serial_number,
"stack_member": 1,
}
assert len(rhost2["items"]) == 2
# or ics_id when no stack_member
ritem21, ritem22 = rhost2["items"]
assert ritem21 == {
"ics_id": item21.ics_id,
"serial_number": item21.serial_number,
"stack_member": None,
}
assert ritem22 == {
"ics_id": item22.ics_id,
"serial_number": item22.serial_number,
"stack_member": None,
}
data = {"name": "my-hostname", "device_type": device_type.name}
@pytest.mark.parametrize(
"username, password, expected_hosts, expected_host3_result",
[
("admin", "adminpasswd", ["host1", "host2", "host3", "host4"], 1),
("audit", "auditpasswd", ["host1", "host2", "host3", "host4"], 1),
("user_prod", "userprod", ["host1", "host2", "host3"], 1),
("user_ro", "userro", ["host1", "host2"], 0),
],
)
def test_get_hosts_sensitive(
client,
network_scope_factory,
network_factory,
interface_factory,
host_factory,
username,
password,
expected_hosts,
expected_host3_result,
):
# Create some hosts
scope = network_scope_factory(supernet="192.168.0.0/16", name="ProdNetworks")
network1 = network_factory(
address="192.168.1.0/24",
first_ip="192.168.1.10",
last_ip="192.168.1.250",
scope=scope,
)
network2 = network_factory(
address="192.168.2.0/24",
first_ip="192.168.2.10",
last_ip="192.168.2.250",
scope=scope,
sensitive=True,
)
host1 = host_factory(name="host1")
host2 = host_factory(name="host2")
host3 = host_factory(name="host3")
# Add a host without interface
host_factory(name="host4")
interface_factory(name=host1.name, network=network1, ip="192.168.1.10", host=host1)
interface_factory(name=host2.name, network=network1, ip="192.168.1.11", host=host2)
interface_factory(name=host3.name, network=network2, ip="192.168.2.10", host=host3)
# Retrieve the hosts
token = get_token(client, username, password)
response = get(client, f"{API_URL}/network/hosts", token=token)
assert response.status_code == 200
hosts = [host["name"] for host in response.get_json()]
assert hosts == expected_hosts
response = get(client, f"{API_URL}/network/hosts?name=host1", token=token)
assert response.status_code == 200
assert len(response.get_json()) == 1
response = get(client, f"{API_URL}/network/hosts?name=host3", token=token)
assert response.status_code == 200
assert len(response.get_json()) == expected_host3_result
def test_create_host(client, device_type_factory, user_token):
device_type = device_type_factory(name="Virtual")
# check that name and device_type are mandatory
response = post(client, f"{API_URL}/network/hosts", data={}, token=user_token)
check_response_message(response, "At least one field is required", 422)
response = post(
client, f"{API_URL}/network/hosts", data={"name": "myhost"}, token=user_token
)
check_response_message(response, "Missing mandatory field 'device_type'", 422)
response = post(
client,
f"{API_URL}/network/hosts",
data={"device_type": "Physical"},
token=user_token,
)
check_response_message(response, "Missing mandatory field 'name'", 422)
data = {"name": "my-hostname", "device_type": device_type.name}
response = post(client, f"{API_URL}/network/hosts", data=data, token=user_token)
assert response.status_code == 201
assert {
"id",
"name",
"device_type",
"model",
"description",
"items",
"interfaces",
"ansible_vars",
"ansible_groups",
"created_at",
"updated_at",
"user",
} == set(response.get_json().keys())
assert HOST_KEYS == set(response.get_json().keys())
assert response.get_json()["name"] == data["name"]
# Check that name shall be unique
response = post(client, f"{API_URL}/network/hosts", data=data, token=user_token)
check_response_message(
response,
"(psycopg2.IntegrityError) duplicate key value violates unique constraint",
"(psycopg2.errors.UniqueViolation) duplicate key value violates unique constraint",
422,
)
......@@ -1370,11 +1862,14 @@ def test_delete_host_invalid_credentials(client, host_factory, user_token):
assert len(models.Host.query.all()) == 1
def test_delete_host_success(client, host_factory, admin_token):
host1 = host_factory()
response = delete(client, f"{API_URL}/network/hosts/{host1.id}", token=admin_token)
assert response.status_code == 204
assert len(models.Host.query.all()) == 0
def test_delete_host_success(client, admin_token, host):
check_delete_success(
client,
admin_token,
host,
"network/hosts",
models.Host,
)
def test_delete_host_invalid_id(client, host_factory, admin_token):
......@@ -1387,15 +1882,12 @@ def test_delete_host_invalid_id(client, host_factory, admin_token):
assert len(models.Host.query.all()) == 1
def test_delete_host_with_interfaces(
client, interface_factory, host_factory, admin_token
):
interface1 = interface_factory()
interface2 = interface_factory()
host1 = host_factory(interfaces=[interface1, interface2])
assert len(host1.interfaces) == 2
def test_delete_host_with_interfaces(client, interface_factory, host, admin_token):
interface_factory(host=host)
interface_factory(host=host)
assert len(host.interfaces) == 2
assert len(models.Interface.query.all()) == 2
response = delete(client, f"{API_URL}/network/hosts/{host1.id}", token=admin_token)
response = delete(client, f"{API_URL}/network/hosts/{host.id}", token=admin_token)
assert response.status_code == 204
assert len(models.Host.query.all()) == 0
assert len(models.Interface.query.all()) == 0
......@@ -1424,7 +1916,7 @@ def test_get_domains(client, domain_factory, readonly_token):
def test_create_domain(client, admin_token):
# check that name is mandatory
response = post(client, f"{API_URL}/network/domains", data={}, token=admin_token)
check_response_message(response, "Missing mandatory field 'name'", 422)
check_response_message(response, "At least one field is required", 422)
data = {"name": "tn.esss.lu.se"}
response = post(client, f"{API_URL}/network/domains", data=data, token=admin_token)
......@@ -1444,7 +1936,7 @@ def test_create_domain(client, admin_token):
response = post(client, f"{API_URL}/network/domains", data=data, token=admin_token)
check_response_message(
response,
"(psycopg2.IntegrityError) duplicate key value violates unique constraint",
"(psycopg2.errors.UniqueViolation) duplicate key value violates unique constraint",
422,
)
......@@ -1504,7 +1996,7 @@ def test_get_cnames_by_domain(
def test_create_cname(client, interface, admin_token):
# check that name and interface_id are mandatory
response = post(client, f"{API_URL}/network/cnames", data={}, token=admin_token)
check_response_message(response, "Missing mandatory field 'name'", 422)
check_response_message(response, "At least one field is required", 422)
response = post(
client, f"{API_URL}/network/cnames", data={"name": "myhost"}, token=admin_token
)
......@@ -1517,7 +2009,7 @@ def test_create_cname(client, interface, admin_token):
)
check_response_message(response, "Missing mandatory field 'name'", 422)
data = {"name": "myhost.tn.esss.lu.se", "interface_id": interface.id}
data = {"name": "myhost", "interface_id": interface.id}
response = post(client, f"{API_URL}/network/cnames", data=data, token=admin_token)
assert response.status_code == 201
assert {"id", "name", "interface", "created_at", "updated_at", "user"} == set(
......@@ -1525,10 +2017,984 @@ def test_create_cname(client, interface, admin_token):
)
assert response.get_json()["name"] == data["name"]
# Check that name shall be unique
# Check that name shall be unique by domain
response = post(client, f"{API_URL}/network/cnames", data=data, token=admin_token)
check_response_message(
response, f"Duplicate cname on the {interface.network.domain} domain", 422
)
def test_search_hosts(client, host_factory, readonly_token):
# Create some hosts
host1 = host_factory(name="test-beautiful", description="The Zen of Python")
host_factory(name="test-explicit", description="Beautiful is better than ugly.")
host_factory(name="another-host")
# When no query is passed, all hosts are returned
response = get(client, f"{API_URL}/network/hosts/search", token=readonly_token)
assert response.status_code == 200
assert len(response.get_json()) == 3
# a keyword is searched in all fields by default
response = get(
client, f"{API_URL}/network/hosts/search?q=beautiful", token=readonly_token
)
assert response.status_code == 200
assert len(response.get_json()) == 2
# a search can be restricted to a specific field
response = get(
client, f"{API_URL}/network/hosts/search?q=name:beautiful", token=readonly_token
)
assert response.status_code == 200
r = response.get_json()
assert len(r) == 1
assert HOST_KEYS == set(r[0].keys())
assert r[0]["name"] == host1.name
assert r[0]["description"] == host1.description
@pytest.mark.parametrize(
"username,password,query,expected",
[
("admin", "adminpasswd", "", 3),
("admin", "adminpasswd", "?q=beautiful", 2),
("admin", "adminpasswd", "?q=description:beautiful", 1),
("audit", "auditpasswd", "", 3),
("audit", "auditpasswd", "?q=beautiful", 2),
("audit", "auditpasswd", "?q=description:beautiful", 1),
("user_ro", "userro", "", 1),
("user_ro", "userro", "?q=beautiful", 1),
("user_ro", "userro", "?q=description:beautiful", 0),
("user_ro", "userro", "?q=description:explicit", 1),
("user_prod", "userprod", "", 3),
("user_prod", "userprod", "?q=beautiful", 2),
("user_prod", "userprod", "?q=description:beautiful", 1),
],
)
def test_search_hosts_sensitive(
username,
password,
query,
expected,
client,
network_scope_factory,
network_factory,
interface_factory,
host_factory,
):
scope = network_scope_factory(name="ProdNetworks", supernet="192.168.0.0/16")
network1 = network_factory(
address="192.168.1.0/24",
first_ip="192.168.1.10",
last_ip="192.168.1.250",
sensitive=False,
scope=scope,
)
network2 = network_factory(
address="192.168.2.0/24",
first_ip="192.168.2.10",
last_ip="192.168.2.250",
sensitive=True,
scope=scope,
)
host1 = host_factory(name="test-beautiful", description="search explicit")
interface_factory(name=host1.name, host=host1, network=network1, ip="192.168.1.10")
host2 = host_factory(name="test-explicit")
interface_factory(name=host2.name, host=host2, network=network2, ip="192.168.2.10")
host3 = host_factory(name="another-host", description="search beautiful")
interface_factory(name=host3.name, host=host3, network=network2, ip="192.168.2.11")
token = get_token(client, username, password)
response = get(client, f"{API_URL}/network/hosts/search{query}", token=token)
assert response.status_code == 200
assert len(response.get_json()) == expected
@pytest.mark.parametrize("endpoint", ["network/hosts", "network/hosts/search"])
def test_pagination(endpoint, client, host_factory, admin_token):
# MAX_PER_PAGE set to 25 for testing
# Create 30 hosts
for i in range(30):
host_factory()
if endpoint == "network/hosts":
extra_args = "&recursive=False"
else:
extra_args = ""
# By default 20 hosts per page shall be returned
response = get(client, f"{API_URL}/{endpoint}", token=admin_token)
assert response.status_code == 200
assert len(response.get_json()) == 20
assert response.headers["x-total-count"] == "30"
assert (
f'{API_URL}/{endpoint}?per_page=20&page=2{extra_args}>; rel="next",'
in response.headers["link"]
)
assert 'rel="prev"' not in response.headers["link"]
assert 'rel="first"' not in response.headers["link"]
# Get second page (which is last)
response = get(
client,
f"{API_URL}/{endpoint}?per_page=20&page=2{extra_args}",
token=admin_token,
)
assert response.status_code == 200
assert len(response.get_json()) == 10
assert (
f'{API_URL}/{endpoint}?per_page=20&page=1{extra_args}>; rel="first",'
in response.headers["link"]
)
assert (
f'{API_URL}/{endpoint}?per_page=20&page=1{extra_args}>; rel="prev"'
in response.headers["link"]
)
assert 'rel="next"' not in response.headers["link"]
assert 'rel="last"' not in response.headers["link"]
# Request 10 elements per_page
response = get(client, f"{API_URL}/{endpoint}?per_page=10", token=admin_token)
assert response.status_code == 200
assert len(response.get_json()) == 10
assert response.headers["x-total-count"] == "30"
assert (
f'{API_URL}/{endpoint}?per_page=10&page=2{extra_args}>; rel="next",'
in response.headers["link"]
)
assert (
f'{API_URL}/{endpoint}?per_page=10&page=3{extra_args}>; rel="last"'
in response.headers["link"]
)
# You can't request more than MAX_PER_PAGE elements
response = get(client, f"{API_URL}/{endpoint}?per_page=50", token=admin_token)
assert response.status_code == 200
assert len(response.get_json()) == 25
assert response.headers["x-total-count"] == "30"
assert (
f'{API_URL}/{endpoint}?per_page=25&page=2{extra_args}>; rel="next",'
in response.headers["link"]
)
assert (
f'{API_URL}/{endpoint}?per_page=25&page=2{extra_args}>; rel="last"'
in response.headers["link"]
)
@pytest.mark.parametrize(
"username, password, expected_total",
[("user_ro", "userro", 45), ("user_prod", "userprod", 55)],
)
def test_get_hosts_normal_user(
client,
network_scope_factory,
network_factory,
interface_factory,
host_factory,
username,
password,
expected_total,
):
# Create several networks
scope = network_scope_factory(name="ProdNetworks", supernet="192.168.0.0/16")
admin_network = network_factory(
vlan_name="admin-network",
address="192.168.1.0/24",
first_ip="192.168.1.10",
last_ip="192.168.1.250",
admin_only=True,
scope=scope,
)
sensitive_network = network_factory(
vlan_name="sensitive-network",
address="192.168.2.0/24",
first_ip="192.168.2.10",
last_ip="192.168.2.250",
admin_only=True,
sensitive=True,
scope=scope,
)
user1_network = network_factory(
vlan_name="user1-network",
address="192.168.3.0/24",
first_ip="192.168.3.10",
last_ip="192.168.3.250",
scope=scope,
)
user2_network = network_factory(
vlan_name="user2-network",
address="192.168.4.0/24",
first_ip="192.168.4.10",
last_ip="192.168.4.250",
scope=scope,
)
user3_network = network_factory(
vlan_name="user3-network",
address="192.168.5.0/24",
first_ip="192.168.5.10",
last_ip="192.168.5.250",
scope=scope,
)
# Create 15 hosts on admin network
create_hosts(15, host_factory, interface_factory, admin_network)
# Create 10 hosts on sensitive network (should be filtered out for user not member of scope)
create_hosts(10, host_factory, interface_factory, sensitive_network)
# Create 20 hosts on user1 network
create_hosts(20, host_factory, interface_factory, user1_network)
# Create 10 hosts on user2 and user3 networks
# Hosts with several interfaces required to reproduce INFRA-1888
create_hosts(10, host_factory, interface_factory, user2_network, user3_network)
# Create 8 more hosts without interfaces (should be filtered out)
create_hosts(8, host_factory, interface_factory)
url = f"{API_URL}/network/hosts"
# By default 20 hosts per page shall be returned
token = get_token(client, username, password)
response = get(client, url, token=token)
assert response.status_code == 200
assert len(response.get_json()) == 20
assert response.headers["x-total-count"] == str(expected_total)
assert (
f'{url}?per_page=20&page=2&recursive=False>; rel="next",'
in response.headers["link"]
)
# Get second page
response = get(
client,
f"{url}?per_page=20&page=2&recursive=False",
token=token,
)
assert response.status_code == 200
assert len(response.get_json()) == 20
assert (
f'{url}?per_page=20&page=1&recursive=False>; rel="first",'
in response.headers["link"]
)
assert (
f'{url}?per_page=20&page=1&recursive=False>; rel="prev"'
in response.headers["link"]
)
assert (
f'{url}?per_page=20&page=3&recursive=False>; rel="next",'
in response.headers["link"]
)
assert (
f'{url}?per_page=20&page=3&recursive=False>; rel="last"'
in response.headers["link"]
)
# Get third page
response = get(
client,
f"{url}?per_page=20&page=3&recursive=False",
token=token,
)
assert response.status_code == 200
assert len(response.get_json()) == expected_total - 40
assert (
f'{url}?per_page=20&page=2&recursive=False>; rel="prev"'
in response.headers["link"]
)
assert 'rel="next"' not in response.headers["link"]
assert 'rel="last"' not in response.headers["link"]
def test_patch_host_no_data(client, host_factory, admin_token):
host = host_factory()
response = patch(
client, f"{API_URL}/network/hosts/{host.id}", data={}, token=admin_token
)
check_response_message(response, "At least one field is required", 422)
@pytest.mark.parametrize("field,value", [("foo", "xxxx"), ("name", "myhost")])
def test_patch_host_invalid_fields(client, host_factory, admin_token, field, value):
host = host_factory()
response = patch(
client,
f"{API_URL}/network/hosts/{host.id}",
data={field: value},
token=admin_token,
)
check_response_message(response, f"Invalid field '{field}'", 422)
@pytest.mark.parametrize(
"field,value",
[
("description", "This is a test"),
("ansible_vars", {"myvar": "hello", "another": "world"}),
("is_ioc", False),
("is_ioc", True),
],
)
def test_patch_host(client, host_factory, admin_token, field, value):
# Create a host
host = host_factory()
data = {field: value}
response = patch(
client, f"{API_URL}/network/hosts/{host.id}", data=data, token=admin_token
)
assert response.status_code == 200
assert response.get_json()[field] == value
updated_host = models.Host.query.get(host.id)
assert getattr(updated_host, field) == value
def test_patch_host_device_type(client, host_factory, device_type_factory, admin_token):
host = host_factory()
device_type = device_type_factory(name="MyDevice")
data = {"device_type": device_type.name}
response = patch(
client, f"{API_URL}/network/hosts/{host.id}", data=data, token=admin_token
)
assert response.status_code == 200
assert response.get_json()["device_type"] == device_type.name
updated_host = models.Host.query.get(host.id)
assert updated_host.device_type == device_type
def test_patch_host_invalid_device_type(client, host_factory, admin_token):
host = host_factory()
data = {"device_type": "foo"}
response = patch(
client, f"{API_URL}/network/hosts/{host.id}", data=data, token=admin_token
)
check_response_message(response, "foo is not a valid devicetype", 400)
@pytest.mark.parametrize("groups", (["group1"], ["group1", "group2"]))
def test_patch_host_ansible_groups(
client, host_factory, ansible_group_factory, admin_token, groups
):
host = host_factory()
for group_name in groups:
ansible_group_factory(name=group_name)
data = {"ansible_groups": groups}
response = patch(
client, f"{API_URL}/network/hosts/{host.id}", data=data, token=admin_token
)
assert response.status_code == 200
assert response.get_json()["ansible_groups"] == groups
updated_host = models.Host.query.get(host.id)
for group_name in groups:
group = models.AnsibleGroup.query.filter_by(name=group_name).first()
assert group.hosts == [updated_host]
def test_patch_host_single_ansible_group(
client, host_factory, ansible_group_factory, admin_token
):
host = host_factory()
group = ansible_group_factory(name="my_group")
data = {"ansible_groups": group.name}
response = patch(
client, f"{API_URL}/network/hosts/{host.id}", data=data, token=admin_token
)
assert response.status_code == 200
assert response.get_json()["ansible_groups"] == [group.name]
updated_host = models.Host.query.get(host.id)
assert updated_host.ansible_groups == [group]
def test_patch_host_invalid_ansible_group(client, host_factory, admin_token):
host = host_factory()
data = {"ansible_groups": "unknown_group"}
response = patch(
client, f"{API_URL}/network/hosts/{host.id}", data=data, token=admin_token
)
check_response_message(response, "unknown_group is not a valid ansiblegroup", 400)
@pytest.mark.parametrize("items", (["AAA001"], ["AAB001", "AAB002"]))
def test_patch_host_items(client, host_factory, item_factory, admin_token, items):
host = host_factory()
for ics_id in items:
item_factory(ics_id=ics_id)
data = {"items": items}
response = patch(
client, f"{API_URL}/network/hosts/{host.id}", data=data, token=admin_token
)
assert response.status_code == 200
assert response.get_json()["items"] == items
updated_host = models.Host.query.get(host.id)
for ics_id in items:
item = models.Item.query.filter_by(ics_id=ics_id).first()
assert item.host == updated_host
def test_patch_host_single_item(client, host_factory, item_factory, admin_token):
host = host_factory()
item = item_factory(ics_id="BBB001")
data = {"items": item.ics_id}
response = patch(
client, f"{API_URL}/network/hosts/{host.id}", data=data, token=admin_token
)
assert response.status_code == 200
assert response.get_json()["items"] == [item.ics_id]
updated_host = models.Host.query.get(host.id)
assert updated_host.items == [item]
def test_patch_host_invalid_item(client, host_factory, admin_token):
host = host_factory()
data = {"items": "ABC002"}
response = patch(
client, f"{API_URL}/network/hosts/{host.id}", data=data, token=admin_token
)
check_response_message(response, "ABC002 is not a valid item", 400)
def test_patch_host_network_permission(
client,
network_scope_factory,
network_factory,
host_factory,
interface_factory,
user_token,
):
scope = network_scope_factory(name="FooNetworks", supernet="192.168.0.0/16")
network = network_factory(
address="192.168.1.0/24",
first_ip="192.168.1.10",
last_ip="192.168.1.250",
scope=scope,
)
host = host_factory()
interface_factory(ip="192.168.1.11", host=host, network=network)
data = {"description": "Hello world"}
response = patch(
client, f"{API_URL}/network/hosts/{host.id}", data=data, token=user_token
)
assert response.status_code == 200
def test_patch_host_invalid_network_permission(
client,
network_scope_factory,
network_factory,
host_factory,
interface_factory,
user_token,
):
scope = network_scope_factory(name="ProdNetworks", supernet="192.168.0.0/16")
network = network_factory(
address="192.168.1.0/24",
first_ip="192.168.1.10",
last_ip="192.168.1.250",
scope=scope,
)
host = host_factory()
interface_factory(ip="192.168.1.11", host=host, network=network)
data = {"description": "Hello world"}
response = patch(
client, f"{API_URL}/network/hosts/{host.id}", data=data, token=user_token
)
check_response_message(response, "User doesn't have the required group", 403)
def test_patch_interface_no_data(client, interface_factory, admin_token):
interface = interface_factory()
response = patch(
client,
f"{API_URL}/network/interfaces/{interface.id}",
data={},
token=admin_token,
)
check_response_message(response, "At least one field is required", 422)
@pytest.mark.parametrize(
"field,value", [("foo", "xxxx"), ("host", "myhost"), ("cnames", "alias")]
)
def test_patch_interface_invalid_fields(
client, interface_factory, admin_token, field, value
):
interface = interface_factory()
response = patch(
client,
f"{API_URL}/network/interfaces/{interface.id}",
data={field: value},
token=admin_token,
)
check_response_message(response, f"Invalid field '{field}'", 422)
def test_patch_interface_mac(client, interface_factory, admin_token):
interface = interface_factory()
data = {"mac": "02:42:42:b2:01:c6"}
response = patch(
client,
f"{API_URL}/network/interfaces/{interface.id}",
data=data,
token=admin_token,
)
assert response.status_code == 200
assert response.get_json()["mac"] == data["mac"]
updated_interface = models.Interface.query.get(interface.id)
assert updated_interface.mac == data["mac"]
def test_patch_interface_ip(client, interface_factory, network_192_168_1, admin_token):
interface = interface_factory(network=network_192_168_1, ip="192.168.1.11")
data = {"ip": "192.168.2.12"}
response = patch(
client,
f"{API_URL}/network/interfaces/{interface.id}",
data=data,
token=admin_token,
)
check_response_message(
response,
"(psycopg2.IntegrityError) duplicate key value violates unique constraint",
f"IP address {data['ip']} is not in network {network_192_168_1.address}",
422,
)
data = {"ip": "192.168.1.12"}
response = patch(
client,
f"{API_URL}/network/interfaces/{interface.id}",
data=data,
token=admin_token,
)
assert response.status_code == 200
assert response.get_json()["ip"] == data["ip"]
updated_interface = models.Interface.query.get(interface.id)
assert updated_interface.ip == data["ip"]
def test_patch_interface_name(client, host_factory, interface_factory, admin_token):
host = host_factory(name="myhost")
interface = interface_factory(host=host)
data = {"name": "foo"}
response = patch(
client,
f"{API_URL}/network/interfaces/{interface.id}",
data=data,
token=admin_token,
)
check_response_message(
response, f"Interface name shall start with the host name '{host.name}'", 422
)
data = {"name": host.name + "-2"}
response = patch(
client,
f"{API_URL}/network/interfaces/{interface.id}",
data=data,
token=admin_token,
)
assert response.status_code == 200
assert response.get_json()["name"] == data["name"]
updated_interface = models.Interface.query.get(interface.id)
assert updated_interface.name == data["name"]
def test_patch_interface_network(
client, network_scope_factory, network_factory, interface_factory, admin_token
):
scope = network_scope_factory(supernet="192.168.0.0/16")
network1 = network_factory(
vlan_name="mynetwork",
address="192.168.1.0/24",
first_ip="192.168.1.10",
last_ip="192.168.1.250",
scope=scope,
)
network2 = network_factory(
vlan_name="new-network",
address="192.168.2.0/24",
first_ip="192.168.2.10",
last_ip="192.168.2.250",
scope=scope,
)
interface = interface_factory(network=network1, ip="192.168.1.20")
data = {"network": "unknown"}
response = patch(
client,
f"{API_URL}/network/interfaces/{interface.id}",
data=data,
token=admin_token,
)
check_response_message(response, "Resource not found", 404)
data = {"network": network2.vlan_name}
response = patch(
client,
f"{API_URL}/network/interfaces/{interface.id}",
data=data,
token=admin_token,
)
check_response_message(
response, f"IP address {interface.ip} is not in network {network2.address}", 422
)
data = {"network": network2.vlan_name, "ip": "192.168.5.10"}
response = patch(
client,
f"{API_URL}/network/interfaces/{interface.id}",
data=data,
token=admin_token,
)
check_response_message(
response, f"IP address {data['ip']} is not in network {network2.address}", 422
)
data = {"network": network2.vlan_name, "ip": "192.168.2.10"}
response = patch(
client,
f"{API_URL}/network/interfaces/{interface.id}",
data=data,
token=admin_token,
)
assert response.status_code == 200
assert response.get_json()["network"] == data["network"]
assert response.get_json()["ip"] == data["ip"]
updated_interface = models.Interface.query.get(interface.id)
assert updated_interface.network == network2
def test_patch_interface_current_network_permission(
client, network_scope_factory, network_factory, interface_factory, user_token
):
scope_prod = network_scope_factory(name="ProdNetworks", supernet="192.168.0.0/22")
scope_foo = network_scope_factory(name="FooNetworks", supernet="192.168.4.0/22")
network_prod = network_factory(
vlan_name="prod-network",
address="192.168.1.0/24",
first_ip="192.168.1.10",
last_ip="192.168.1.250",
scope=scope_prod,
)
network_foo = network_factory(
vlan_name="foo-network",
address="192.168.4.0/24",
first_ip="192.168.4.10",
last_ip="192.168.4.250",
scope=scope_foo,
)
# User can't update an interface part of the ProdNetworks
interface_prod = interface_factory(network=network_prod, ip="192.168.1.20")
data = {"ip": "192.168.1.21"}
response = patch(
client,
f"{API_URL}/network/interfaces/{interface_prod.id}",
data=data,
token=user_token,
)
check_response_message(response, "User doesn't have the required group", 403)
# but can on the FooNetworks
interface_foo = interface_factory(network=network_foo, ip="192.168.4.20")
data = {"ip": "192.168.4.21"}
response = patch(
client,
f"{API_URL}/network/interfaces/{interface_foo.id}",
data=data,
token=user_token,
)
assert response.status_code == 200
def test_patch_interface_new_network_permission(
client, network_scope_factory, network_factory, interface_factory, user_token
):
scope_prod = network_scope_factory(name="ProdNetworks", supernet="192.168.0.0/22")
scope_foo = network_scope_factory(name="FooNetworks", supernet="192.168.4.0/22")
network_prod = network_factory(
vlan_name="prod-network",
address="192.168.1.0/24",
first_ip="192.168.1.10",
last_ip="192.168.1.250",
scope=scope_prod,
)
network_foo1 = network_factory(
vlan_name="foo-network1",
address="192.168.4.0/24",
first_ip="192.168.4.10",
last_ip="192.168.4.250",
scope=scope_foo,
)
network_foo2 = network_factory(
vlan_name="foo-network2",
address="192.168.5.0/24",
first_ip="192.168.5.10",
last_ip="192.168.5.250",
scope=scope_foo,
)
interface_foo = interface_factory(network=network_foo1, ip="192.168.4.20")
# User can't change the network to the ProdNetworks
data = {"network": network_prod.vlan_name}
response = patch(
client,
f"{API_URL}/network/interfaces/{interface_foo.id}",
data=data,
token=user_token,
)
# but can on the same scope it has access to
check_response_message(response, "User doesn't have the required group", 403)
data = {"network": network_foo2.vlan_name, "ip": "192.168.5.10"}
response = patch(
client,
f"{API_URL}/network/interfaces/{interface_foo.id}",
data=data,
token=user_token,
)
assert response.status_code == 200
def test_patch_network_no_data(client, network_factory, admin_token):
network = network_factory()
response = patch(
client, f"{API_URL}/network/networks/{network.id}", data={}, token=admin_token
)
check_response_message(response, "At least one field is required", 422)
@pytest.mark.parametrize("field,value", [("foo", "xxxx"), ("name", "mynetwork")])
def test_patch_network_invalid_fields(
client, network_factory, admin_token, field, value
):
network = network_factory()
response = patch(
client,
f"{API_URL}/network/networks/{network.id}",
data={field: value},
token=admin_token,
)
check_response_message(response, f"Invalid field '{field}'", 422)
@pytest.mark.parametrize(
"field,value",
[
("vlan_name", "new-name"),
("description", "This is a test"),
("admin_only", True),
("sensitive", False),
("sensitive", True),
],
)
def test_patch_network(client, network_factory, admin_token, field, value):
network = network_factory()
data = {field: value}
response = patch(
client, f"{API_URL}/network/networks/{network.id}", data=data, token=admin_token
)
assert response.status_code == 200
assert response.get_json()[field] == value
updated_network = models.Network.query.get(network.id)
if isinstance(value, bool):
assert getattr(updated_network, field) is value
else:
assert getattr(updated_network, field) == value
def test_patch_network_domain(client, network_factory, domain_factory, admin_token):
network = network_factory()
domain = domain_factory(name="foo.example.org")
data = {"domain": domain.name}
response = patch(
client, f"{API_URL}/network/networks/{network.id}", data=data, token=admin_token
)
assert response.status_code == 200
assert response.get_json()["domain"] == domain.name
updated_network = models.Network.query.get(network.id)
assert updated_network.domain == domain
def test_patch_network_invalid_domain(client, network_factory, admin_token):
network = network_factory()
data = {"domain": "foo"}
response = patch(
client, f"{API_URL}/network/networks/{network.id}", data=data, token=admin_token
)
check_response_message(response, "foo is not a valid domain", 400)
@pytest.mark.parametrize(
"field,value",
[
("name", "new-name"),
("description", "This is a test"),
("first_vlan", 110),
("last_vlan", 300),
("supernet", "172.16.0.0/16"),
],
)
def test_patch_network_scope(client, network_scope_factory, admin_token, field, value):
scope = network_scope_factory(first_vlan=100, last_vlan=400)
data = {field: value}
response = patch(
client, f"{API_URL}/network/scopes/{scope.id}", data=data, token=admin_token
)
assert response.status_code == 200
assert response.get_json()[field] == value
updated_scope = models.NetworkScope.query.get(scope.id)
assert getattr(updated_scope, field) == value
def test_patch_network_scope_domain(
client, network_scope_factory, domain_factory, admin_token
):
scope = network_scope_factory()
domain = domain_factory(name="foo.example.org")
data = {"domain": domain.name}
response = patch(
client, f"{API_URL}/network/scopes/{scope.id}", data=data, token=admin_token
)
assert response.status_code == 200
assert response.get_json()["domain"] == domain.name
updated_scope = models.NetworkScope.query.get(scope.id)
assert updated_scope.domain == domain
def test_delete_item_success(client, admin_token, item):
check_delete_success(
client,
admin_token,
item,
"inventory/items",
models.Item,
)
def test_delete_item_comment_success(client, admin_token, item_comment):
check_delete_success(
client,
admin_token,
item_comment,
"inventory/items/comments",
models.ItemComment,
)
def test_delete_manufacturer_success(client, admin_token, manufacturer):
check_delete_success(
client,
admin_token,
manufacturer,
"inventory/manufacturers",
models.Manufacturer,
)
def test_delete_model_success(client, admin_token, model):
check_delete_success(
client,
admin_token,
model,
"inventory/models",
models.Model,
)
def test_delete_location_success(client, admin_token, location):
check_delete_success(
client,
admin_token,
location,
"inventory/locations",
models.Location,
)
def test_delete_status_success(client, admin_token, status):
check_delete_success(
client,
admin_token,
status,
"inventory/statuses",
models.Status,
)
def test_delete_network_scope_success(client, admin_token, network_scope):
check_delete_success(
client,
admin_token,
network_scope,
"network/scopes",
models.NetworkScope,
)
def test_delete_network_scope_with_network_fail(
client, network_scope, network_factory, admin_token
):
network_factory(scope=network_scope)
response = delete(
client, f"{API_URL}/network/scopes/{network_scope.id}", token=admin_token
)
check_response_message(
response,
'(psycopg2.errors.NotNullViolation) null value in column "scope_id" violates not-null constraint',
422,
)
def test_delete_network_success(client, admin_token, network):
check_delete_success(
client,
admin_token,
network,
"network/networks",
models.Network,
)
def test_delete_network_with_host_fail(
client, network_192_168_1, interface_factory, admin_token
):
interface_factory(network=network_192_168_1)
response = delete(
client, f"{API_URL}/network/networks/{network_192_168_1.id}", token=admin_token
)
check_response_message(
response,
'(psycopg2.errors.NotNullViolation) null value in column "network_id" violates not-null constraint',
422,
)
def test_delete_ansible_group_success(client, admin_token, ansible_group):
check_delete_success(
client,
admin_token,
ansible_group,
"network/groups",
models.AnsibleGroup,
)
def test_delete_domain_success(client, admin_token, domain):
check_delete_success(
client,
admin_token,
domain,
"network/domains",
models.Domain,
)
def test_delete_domain_with_network_fail(client, domain, network_factory, admin_token):
network_factory(domain=domain)
response = delete(
client, f"{API_URL}/network/domains/{domain.id}", token=admin_token
)
check_response_message(
response,
'(psycopg2.errors.NotNullViolation) null value in column "domain_id" violates not-null constraint',
422,
)
def test_delete_cname_success(client, admin_token, cname):
check_delete_success(
client,
admin_token,
cname,
"network/cnames",
models.Cname,
)
@pytest.mark.parametrize("url", ["/network/scopes", "/user/users"])
@pytest.mark.parametrize(
"username, password, status_code",
[
("user_rw", "userrw", 403),
("user_prod", "userprod", 403),
("admin", "adminpasswd", 200),
("audit", "auditpasswd", 200),
],
)
def test_get_admin_protected_url(client, url, username, password, status_code):
token = get_token(client, username, password)
response = get(client, f"{API_URL}{url}", token=token)
assert response.status_code == status_code
......@@ -31,26 +31,314 @@ def test_user_is_admin(user_factory):
assert user.is_admin
def test_user_is_auditor(user_factory):
user = user_factory(groups=["foo", "CSEntry User"])
assert not user.is_auditor
user = user_factory(groups=["foo", "CSEntry Auditor"])
assert user.is_auditor
def test_user_is_member_of_one_group(user_factory):
user = user_factory(groups=["one", "two"])
assert not user.is_member_of_one_group(["create", "admin"])
assert not user.is_member_of_one_group(["network", "admin"])
user = user_factory(groups=["one", "CSEntry Consultant"])
assert user.is_member_of_one_group(["create"])
assert user.is_member_of_one_group(["create", "admin"])
assert user.is_member_of_one_group(["network"])
assert user.is_member_of_one_group(["network", "admin"])
assert not user.is_member_of_one_group(["admin"])
user = user_factory(groups=["one", "CSEntry Admin"])
assert not user.is_member_of_one_group(["create"])
assert user.is_member_of_one_group(["create", "admin"])
assert not user.is_member_of_one_group(["network"])
assert user.is_member_of_one_group(["network", "admin"])
assert user.is_member_of_one_group(["admin"])
user = user_factory(groups=["CSEntry Auditor"])
assert not user.is_member_of_one_group(["network", "admin"])
assert user.is_member_of_one_group(["auditor"])
def test_user_network_scopes(user_factory):
user = user_factory(groups=["CSEntry Prod", "CSEntry User"])
assert user.csentry_network_scopes == ["ProdNetworks", "FooNetworks"]
user = user_factory(groups=["foo", "CSEntry Lab"])
assert user.csentry_network_scopes == ["LabNetworks"]
@pytest.mark.parametrize(
"groups, sensitive_filter",
[
([], "sensitive:false"),
(["foo"], "sensitive:false"),
(
["CSEntry Lab"],
"sensitive:false OR (sensitive:true AND (scope:LabNetworks))",
),
(
["CSEntry Prod", "CSEntry User"],
"sensitive:false OR (sensitive:true AND (scope:ProdNetworks OR scope:FooNetworks))",
),
],
)
def test_user_sensitive_filter(user_factory, groups, sensitive_filter):
user = user_factory(groups=groups)
assert user.sensitive_filter == sensitive_filter
def test_network_ip_properties(network_factory):
@pytest.mark.parametrize(
"scope_name, groups, sensitive, expected",
[
("ProdNetworks", ["CSEntry Admin"], False, True),
("ProdNetworks", ["CSEntry Admin"], True, True),
("ProdNetworks", ["CSEntry Prod"], False, True),
("ProdNetworks", ["CSEntry Prod"], True, True),
("ProdNetworks", ["CSEntry Lab"], False, True),
("ProdNetworks", ["CSEntry Lab"], True, False),
],
)
@pytest.mark.parametrize("admin_only", [False, True])
def test_user_can_view_host(
user_factory,
network_scope_factory,
network_factory,
interface_factory,
host_factory,
scope_name,
groups,
sensitive,
expected,
admin_only,
):
scope = network_scope_factory(name=scope_name)
network = network_factory(scope=scope, admin_only=admin_only, sensitive=sensitive)
host = host_factory()
interface_factory(name=host.name, host=host, network=network)
user = user_factory(groups=groups)
assert user.can_view_host(host) == expected
@pytest.mark.parametrize(
"scope_name, groups, sensitive, expected",
[
("ProdNetworks", ["CSEntry Admin"], False, True),
("ProdNetworks", ["CSEntry Admin"], True, True),
("ProdNetworks", ["CSEntry Prod"], False, True),
("ProdNetworks", ["CSEntry Prod"], True, True),
("ProdNetworks", ["CSEntry Lab"], False, True),
("ProdNetworks", ["CSEntry Lab"], True, False),
],
)
@pytest.mark.parametrize("admin_only", [False, True])
def test_user_can_view_network(
user_factory,
network_scope_factory,
network_factory,
scope_name,
groups,
sensitive,
expected,
admin_only,
):
scope = network_scope_factory(name=scope_name)
network = network_factory(scope=scope, admin_only=admin_only, sensitive=sensitive)
user = user_factory(groups=groups)
assert user.can_view_network(network) == expected
def test_user_has_access_to_network(
user_factory, network_scope_factory, network_factory
):
scope_prod = network_scope_factory(name="ProdNetworks")
scope_lab = network_scope_factory(name="LabNetworks")
network_prod = network_factory(scope=scope_prod)
network_lab = network_factory(scope=scope_lab)
network_lab_admin = network_factory(scope=scope_lab, admin_only=True)
user = user_factory(groups=["CSEntry Prod", "CSEntry Lab"])
assert user.has_access_to_network(network_prod)
assert user.has_access_to_network(network_lab)
assert not user.has_access_to_network(network_lab_admin)
assert user.has_access_to_network(None)
user = user_factory(groups=["foo", "CSEntry Lab"])
assert not user.has_access_to_network(network_prod)
assert user.has_access_to_network(network_lab)
assert not user.has_access_to_network(network_lab_admin)
user = user_factory(groups=["one", "two"])
assert not user.has_access_to_network(network_prod)
assert not user.has_access_to_network(network_lab)
assert not user.has_access_to_network(network_lab_admin)
user = user_factory(groups=["CSEntry Admin"])
assert user.has_access_to_network(network_prod)
assert user.has_access_to_network(network_lab)
assert user.has_access_to_network(network_lab_admin)
assert user.has_access_to_network(None)
def test_user_can_create_vm(
user_factory,
network_scope_factory,
network_factory,
device_type_factory,
host_factory,
interface_factory,
):
virtualmachine = device_type_factory(name="VirtualMachine")
scope_prod = network_scope_factory(name="ProdNetworks")
scope_lab = network_scope_factory(name="LabNetworks")
network_prod = network_factory(scope=scope_prod)
network_lab = network_factory(scope=scope_lab)
network_lab_admin = network_factory(scope=scope_lab, admin_only=True)
vm_prod = host_factory(device_type=virtualmachine)
interface_factory(name=vm_prod.name, host=vm_prod, network=network_prod)
vioc_prod = host_factory(device_type=virtualmachine, is_ioc=True)
interface_factory(name=vioc_prod.name, host=vioc_prod, network=network_prod)
vm_lab = host_factory(device_type=virtualmachine)
interface_factory(name=vm_lab.name, host=vm_lab, network=network_lab)
vioc_lab = host_factory(device_type=virtualmachine, is_ioc=True)
interface_factory(name=vioc_lab.name, host=vioc_lab, network=network_lab)
vm_lab_admin = host_factory(device_type=virtualmachine)
interface_factory(
name=vm_lab_admin.name, host=vm_lab_admin, network=network_lab_admin
)
vioc_lab_admin = host_factory(device_type=virtualmachine, is_ioc=True)
interface_factory(
name=vioc_lab_admin.name, host=vioc_lab_admin, network=network_lab_admin
)
non_vm = host_factory()
non_vm_ioc = host_factory(is_ioc=True)
interface_factory(name=non_vm.name, host=non_vm, network=network_lab)
interface_factory(name=non_vm_ioc.name, host=non_vm_ioc, network=network_lab)
# User has access to prod and lab networks but can only create a VM in the lab
# (due to ALLOWED_VM_CREATION_NETWORK_SCOPES) and VIOC in both
user = user_factory(groups=["CSEntry Prod", "CSEntry Lab"])
assert user.can_create_vm(vm_lab)
assert not user.can_create_vm(vm_prod)
assert not user.can_create_vm(vm_lab_admin)
assert not user.can_create_vm(non_vm)
assert user.can_create_vm(vioc_lab)
assert user.can_create_vm(vioc_prod)
assert not user.can_create_vm(vioc_lab_admin)
assert not user.can_create_vm(non_vm_ioc)
# User has only access to the lab networks and can only create a VM and VIOC in the lab
user = user_factory(groups=["foo", "CSEntry Lab"])
assert user.can_create_vm(vm_lab)
assert not user.can_create_vm(vm_prod)
assert not user.can_create_vm(vm_lab_admin)
assert not user.can_create_vm(non_vm)
assert user.can_create_vm(vioc_lab)
assert not user.can_create_vm(vioc_prod)
assert not user.can_create_vm(vioc_lab_admin)
assert not user.can_create_vm(non_vm_ioc)
# User can't create any VM or VIOC
user = user_factory(groups=["one", "two"])
assert not user.can_create_vm(vm_lab)
assert not user.can_create_vm(vm_prod)
assert not user.can_create_vm(vm_lab_admin)
assert not user.can_create_vm(non_vm)
assert not user.can_create_vm(vioc_lab)
assert not user.can_create_vm(vioc_prod)
assert not user.can_create_vm(vioc_lab_admin)
assert not user.can_create_vm(non_vm_ioc)
# Admin can create VM and VIOC
user = user_factory(groups=["CSEntry Admin"])
assert user.can_create_vm(vm_lab)
assert user.can_create_vm(vm_prod)
assert user.can_create_vm(vm_lab_admin)
assert not user.can_create_vm(non_vm)
assert user.can_create_vm(vioc_lab)
assert user.can_create_vm(vioc_prod)
assert user.can_create_vm(vioc_lab_admin)
assert not user.can_create_vm(non_vm_ioc)
def test_user_can_set_boot_profile(
user_factory,
network_scope_factory,
network_factory,
device_type_factory,
host_factory,
interface_factory,
):
physicalmachine = device_type_factory(name="PhysicalMachine")
scope_prod = network_scope_factory(name="ProdNetworks")
scope_lab = network_scope_factory(name="LabNetworks")
network_prod = network_factory(scope=scope_prod)
network_lab = network_factory(scope=scope_lab)
network_lab_admin = network_factory(scope=scope_lab, admin_only=True)
server_prod = host_factory(device_type=physicalmachine)
interface_factory(name=server_prod.name, host=server_prod, network=network_prod)
ioc_prod = host_factory(device_type=physicalmachine, is_ioc=True)
interface_factory(name=ioc_prod.name, host=ioc_prod, network=network_prod)
server_lab = host_factory(device_type=physicalmachine)
interface_factory(name=server_lab.name, host=server_lab, network=network_lab)
ioc_lab = host_factory(device_type=physicalmachine, is_ioc=True)
interface_factory(name=ioc_lab.name, host=ioc_lab, network=network_lab)
server_lab_admin = host_factory(device_type=physicalmachine)
interface_factory(
name=server_lab_admin.name, host=server_lab_admin, network=network_lab_admin
)
ioc_lab_admin = host_factory(device_type=physicalmachine, is_ioc=True)
interface_factory(
name=ioc_lab_admin.name, host=ioc_lab_admin, network=network_lab_admin
)
non_physical = host_factory()
non_physical_ioc = host_factory(is_ioc=True)
interface_factory(name=non_physical.name, host=non_physical, network=network_lab)
interface_factory(
name=non_physical_ioc.name, host=non_physical_ioc, network=network_lab
)
# User has access to prod and lab networks but can only set the boot profile in the lab
# (due to ALLOWED_SET_BOOT_PROFILE_NETWORK_SCOPES)
user = user_factory(groups=["CSEntry Prod", "CSEntry Lab"])
assert user.can_set_boot_profile(server_lab)
assert not user.can_set_boot_profile(server_prod)
assert not user.can_set_boot_profile(server_lab_admin)
assert not user.can_set_boot_profile(non_physical)
assert user.can_set_boot_profile(ioc_lab)
assert not user.can_set_boot_profile(ioc_prod)
assert not user.can_set_boot_profile(ioc_lab_admin)
assert not user.can_set_boot_profile(non_physical_ioc)
# User has only access to the lab networks and can only set the boot profile in the lab
user = user_factory(groups=["foo", "CSEntry Lab"])
assert user.can_set_boot_profile(server_lab)
assert not user.can_set_boot_profile(server_prod)
assert not user.can_set_boot_profile(server_lab_admin)
assert not user.can_set_boot_profile(non_physical)
assert user.can_set_boot_profile(ioc_lab)
assert not user.can_set_boot_profile(ioc_prod)
assert not user.can_set_boot_profile(ioc_lab_admin)
assert not user.can_set_boot_profile(non_physical_ioc)
# User can't set the boot profile
user = user_factory(groups=["one", "two"])
assert not user.can_set_boot_profile(server_lab)
assert not user.can_set_boot_profile(server_prod)
assert not user.can_set_boot_profile(server_lab_admin)
assert not user.can_set_boot_profile(non_physical)
assert not user.can_set_boot_profile(ioc_lab)
assert not user.can_set_boot_profile(ioc_prod)
assert not user.can_set_boot_profile(ioc_lab_admin)
assert not user.can_set_boot_profile(non_physical_ioc)
# Admin can set the boot profile on all physical machines
user = user_factory(groups=["CSEntry Admin"])
assert user.can_set_boot_profile(server_lab)
assert user.can_set_boot_profile(server_prod)
assert user.can_set_boot_profile(server_lab_admin)
assert not user.can_set_boot_profile(non_physical)
assert user.can_set_boot_profile(ioc_lab)
assert user.can_set_boot_profile(ioc_prod)
assert user.can_set_boot_profile(ioc_lab_admin)
assert not user.can_set_boot_profile(non_physical_ioc)
def test_network_ip_properties(network_scope_factory, network_factory):
scope = network_scope_factory(supernet="172.16.0.0/16")
# Create some networks
network1 = network_factory(
address="172.16.1.0/24", first_ip="172.16.1.10", last_ip="172.16.1.250"
address="172.16.1.0/24",
first_ip="172.16.1.10",
last_ip="172.16.1.250",
scope=scope,
)
network2 = network_factory(
address="172.16.20.0/26", first_ip="172.16.20.11", last_ip="172.16.20.14"
address="172.16.20.0/26",
first_ip="172.16.20.11",
last_ip="172.16.20.14",
scope=scope,
)
assert network1.network_ip == ipaddress.ip_network("172.16.1.0/24")
......@@ -74,13 +362,22 @@ def test_network_ip_properties(network_factory):
assert network2.used_ips() == []
def test_network_available_and_used_ips(network_factory, interface_factory):
def test_network_available_and_used_ips(
network_factory, interface_factory, network_scope_factory
):
scope = network_scope_factory(supernet="172.16.0.0/16")
# Create some networks and interfaces
network1 = network_factory(
address="172.16.1.0/24", first_ip="172.16.1.10", last_ip="172.16.1.250"
address="172.16.1.0/24",
first_ip="172.16.1.10",
last_ip="172.16.1.250",
scope=scope,
)
network2 = network_factory(
address="172.16.20.0/26", first_ip="172.16.20.11", last_ip="172.16.20.14"
address="172.16.20.0/26",
first_ip="172.16.20.11",
last_ip="172.16.20.14",
scope=scope,
)
for i in range(10, 20):
interface_factory(network=network1, ip=f"172.16.1.{i}")
......@@ -118,10 +415,16 @@ def test_network_available_and_used_ips(network_factory, interface_factory):
assert list(network2.available_ips()) == []
def test_network_gateway(network_factory):
network = network_factory(address="192.168.0.0/24")
assert str(network.gateway) == "192.168.0.254"
network = network_factory(address="172.16.110.0/23")
def test_network_gateway(network_scope_factory, network_factory):
scope1 = network_scope_factory(supernet="192.168.0.0/16")
scope2 = network_scope_factory(supernet="172.16.0.0/16")
network = network_factory(
address="192.168.0.0/24", gateway="192.168.0.1", scope=scope1
)
assert str(network.gateway) == "192.168.0.1"
network = network_factory(
address="172.16.110.0/23", gateway="172.16.111.254", scope=scope2
)
assert str(network.gateway) == "172.16.111.254"
......@@ -162,14 +465,6 @@ def test_device_type_validation(device_type_factory):
assert "'Physical Machine' is an invalid device type name" in str(excinfo.value)
def test_tag_validation(tag_factory):
tag = tag_factory(name="IOC")
assert tag.name == "IOC"
with pytest.raises(ValidationError) as excinfo:
tag = tag_factory(name="My tag")
assert "'My tag' is an invalid tag name" in str(excinfo.value)
def test_ansible_groups(ansible_group_factory, host_factory):
group1 = ansible_group_factory()
group2 = ansible_group_factory()
......@@ -189,19 +484,133 @@ def test_ansible_group_is_dynamic(ansible_group_factory):
assert group2.is_dynamic
group3 = ansible_group_factory(type=models.AnsibleGroupType.NETWORK)
assert group3.is_dynamic
group4 = ansible_group_factory(type=models.AnsibleGroupType.IOC)
assert group4.is_dynamic
group5 = ansible_group_factory(type=models.AnsibleGroupType.HOSTNAME)
assert group5.is_dynamic
def test_ansible_groups_children(ansible_group_factory, host_factory):
group1 = ansible_group_factory()
group2 = ansible_group_factory()
group3 = ansible_group_factory()
group1.children.append(group2)
group1.children.append(group3)
assert group1.children == [group2, group3]
group1.children = [group2, group3]
assert group1.children == sorted([group2, group3], key=lambda grp: grp.name)
assert group2.parents == [group1]
assert group3.parents == [group1]
group4 = ansible_group_factory(parents=[group1])
assert group1.children == [group2, group3, group4]
assert group1.children == sorted([group2, group3, group4], key=lambda grp: grp.name)
def test_ansible_groups_children_all_forbidden(ansible_group_factory):
all = ansible_group_factory(name="all")
group1 = ansible_group_factory()
group2 = ansible_group_factory()
with pytest.raises(ValidationError) as excinfo:
group1.children = [all]
assert (
f"Adding group 'all' as child to '{group1.name}' creates a recursive dependency loop"
in str(excinfo.value)
)
with pytest.raises(ValidationError) as excinfo:
ansible_group_factory(children=[all])
assert "creates a recursive dependency loop" in str(excinfo.value)
with pytest.raises(ValidationError) as excinfo:
ansible_group_factory(children=[group2, all])
assert "creates a recursive dependency loop" in str(excinfo.value)
def test_ansible_groups_no_recursive_dependency(ansible_group_factory):
group3 = ansible_group_factory()
group2 = ansible_group_factory(children=[group3])
group1 = ansible_group_factory(children=[group2])
with pytest.raises(ValidationError) as excinfo:
group3.children = [group1]
assert (
f"Adding group '{group1.name}' as child to '{group3.name}' creates a recursive dependency loop"
in str(excinfo.value)
)
def test_ansible_groups_no_child_of_itself(ansible_group_factory):
group1 = ansible_group_factory()
with pytest.raises(ValidationError) as excinfo:
group1.children = [group1]
assert f"Group '{group1.name}' can't be a child of itself" in str(excinfo.value)
@pytest.mark.parametrize(
"grp_type",
[
models.AnsibleGroupType.STATIC,
models.AnsibleGroupType.DEVICE_TYPE,
models.AnsibleGroupType.IOC,
models.AnsibleGroupType.HOSTNAME,
],
)
def test_ansible_group_network_scope_children(ansible_group_factory, grp_type):
group = ansible_group_factory(type=grp_type)
scope_group = ansible_group_factory(
type=models.AnsibleGroupType.NETWORK_SCOPE, children=[group]
)
assert scope_group.children == [group]
assert group.parents == [scope_group]
@pytest.mark.parametrize(
"grp_type",
[
models.AnsibleGroupType.NETWORK,
models.AnsibleGroupType.NETWORK_SCOPE,
],
)
def test_ansible_group_network_scope_children_forbidden(
ansible_group_factory, grp_type
):
child_group = ansible_group_factory(name="mygroup", type=grp_type)
with pytest.raises(ValidationError) as excinfo:
ansible_group_factory(
type=models.AnsibleGroupType.NETWORK_SCOPE, children=[child_group]
)
assert (
f"can't set {str(grp_type).lower()} group 'mygroup' as a network scope child"
in str(excinfo.value)
)
@pytest.mark.parametrize(
"grp_type",
[
models.AnsibleGroupType.STATIC,
models.AnsibleGroupType.DEVICE_TYPE,
models.AnsibleGroupType.IOC,
models.AnsibleGroupType.HOSTNAME,
],
)
def test_ansible_group_network_parent(ansible_group_factory, grp_type):
group = ansible_group_factory(type=grp_type)
network_group = ansible_group_factory(
type=models.AnsibleGroupType.NETWORK, parents=[group]
)
assert network_group.parents == [group]
assert group.children == [network_group]
@pytest.mark.parametrize(
"grp_type",
[
models.AnsibleGroupType.NETWORK,
models.AnsibleGroupType.NETWORK_SCOPE,
],
)
def test_ansible_group_network_parent_forbidden(ansible_group_factory, grp_type):
group = ansible_group_factory(name="mygroup", type=grp_type)
with pytest.raises(ValidationError) as excinfo:
ansible_group_factory(type=models.AnsibleGroupType.NETWORK, parents=[group])
assert (
f"can't set {str(grp_type).lower()} group 'mygroup' as a network parent"
in str(excinfo.value)
)
def test_host_model(model_factory, item_factory, host_factory):
......@@ -238,27 +647,28 @@ def test_ansible_dynamic_network_group(
):
network1 = network_factory(vlan_name="network1")
network2 = network_factory(vlan_name="network2")
interface1_n1 = interface_factory(network=network1)
interface2_n1 = interface_factory(network=network1)
interface1_n2 = interface_factory(network=network2)
host1_n1 = host_factory(name="host1", interfaces=[interface1_n1])
host2_n1 = host_factory(name="host2", interfaces=[interface2_n1])
host1_n2 = host_factory(interfaces=[interface1_n2])
group_n1 = ansible_group_factory(
host1 = host_factory(name="host1")
host2 = host_factory(name="host2")
host3 = host_factory(name="host3")
interface_factory(name="host1", host=host1, network=network1)
interface_factory(name="host2", host=host2, network=network1)
interface_factory(name="host2-2", host=host2, network=network2)
interface_factory(name="host3", host=host3, network=network2)
group1 = ansible_group_factory(
name="network1", type=models.AnsibleGroupType.NETWORK
)
group_n2 = ansible_group_factory(
group2 = ansible_group_factory(
name="network2", type=models.AnsibleGroupType.NETWORK
)
group_n3 = ansible_group_factory(
name="unknown", type=models.AnsibleGroupType.NETWORK
)
assert group_n1.hosts == [host1_n1, host2_n1]
assert group_n2.hosts == [host1_n2]
assert group_n3.hosts == []
group3 = ansible_group_factory(name="unknown", type=models.AnsibleGroupType.NETWORK)
# host2 has an interface on network1 and one on network2.
# It's only in group1, because its main interface (same name as host) is on network1.
assert group1.hosts == [host1, host2]
assert group2.hosts == [host3]
assert group3.hosts == []
def test_ansible_dynamic_network_scope_group(
def test_ansible_dynamic_network_scope_group_hosts(
ansible_group_factory,
network_scope_factory,
network_factory,
......@@ -267,27 +677,115 @@ def test_ansible_dynamic_network_scope_group(
):
scope1 = network_scope_factory(name="scope1")
scope2 = network_scope_factory(name="scope2")
network1_s1 = network_factory(scope=scope1)
network2_s1 = network_factory(scope=scope1)
network1_s2 = network_factory(scope=scope2)
interface1_s1 = interface_factory(network=network1_s1)
interface2_s1 = interface_factory(network=network2_s1)
interface1_s2 = interface_factory(network=network1_s2)
host1_s1 = host_factory(name="host1", interfaces=[interface1_s1])
host2_s1 = host_factory(name="host2", interfaces=[interface2_s1])
host1_s2 = host_factory(interfaces=[interface1_s2])
group_s1 = ansible_group_factory(
network1 = network_factory(scope=scope1)
network2 = network_factory(scope=scope1)
network3 = network_factory(scope=scope2)
host1 = host_factory(name="host1")
host2 = host_factory(name="host2")
host3 = host_factory(name="host3")
interface_factory(name="host1", host=host1, network=network1)
interface_factory(name="host2", host=host2, network=network2)
interface_factory(name="host2-2", host=host2, network=network3)
interface_factory(name="host3", host=host3, network=network3)
group1 = ansible_group_factory(
name="scope1", type=models.AnsibleGroupType.NETWORK_SCOPE
)
group_s2 = ansible_group_factory(
group2 = ansible_group_factory(
name="scope2", type=models.AnsibleGroupType.NETWORK_SCOPE
)
group_s3 = ansible_group_factory(
group3 = ansible_group_factory(
name="unknown", type=models.AnsibleGroupType.NETWORK_SCOPE
)
assert group_s1.hosts == [host1_s1, host2_s1]
assert group_s2.hosts == [host1_s2]
assert group_s3.hosts == []
# host2 has an interface on scope1 and one on scope2.
# It's only in group1, because its main interface (same name as host) is on scope1.
assert group1.hosts == [host1, host2]
assert group2.hosts == [host3]
assert group3.hosts == []
@pytest.mark.parametrize(
"networks, groups, expected_names",
[
(
["network1", "network2", "network3"],
[],
[],
),
(
["network1", "network2", "network3"],
[("network2", models.AnsibleGroupType.NETWORK)],
["network2"],
),
(
["network1", "network2", "network3"],
[
("network2", models.AnsibleGroupType.NETWORK),
("network3", models.AnsibleGroupType.NETWORK),
],
["network2", "network3"],
),
(
["network1"],
[
("network2", models.AnsibleGroupType.NETWORK),
],
[],
),
(
["network1", "network2"],
[
("mygroup1", models.AnsibleGroupType.DEVICE_TYPE),
("network2", models.AnsibleGroupType.NETWORK),
("mygroup2", models.AnsibleGroupType.STATIC),
],
["mygroup1", "mygroup2", "network2"],
),
],
)
def test_ansible_dynamic_network_scope_group_children(
ansible_group_factory,
network_scope_factory,
network_factory,
networks,
groups,
expected_names,
):
name = "myscope"
scope = network_scope_factory(name=name)
group = ansible_group_factory(name=name, type=models.AnsibleGroupType.NETWORK_SCOPE)
for network in networks:
network_factory(vlan_name=network, scope=scope)
for grp, grp_type in groups:
ag = ansible_group_factory(name=grp, type=grp_type)
if grp_type != models.AnsibleGroupType.NETWORK:
ag.parents = [group]
expected = [
grp for grp in models.AnsibleGroup.query.all() if grp.name in expected_names
]
assert group.children == sorted(expected, key=lambda grp: grp.name)
def test_ansible_dynamic_ioc_group(ansible_group_factory, host_factory):
host1 = host_factory(name="host1", is_ioc=True)
host2 = host_factory(name="host2", is_ioc=True)
host_factory(name="host3", is_ioc=False)
group = ansible_group_factory(name="iocs", type=models.AnsibleGroupType.IOC)
assert group.hosts == [host1, host2]
def test_ansible_dynamic_hostname_group(ansible_group_factory, host_factory):
host1 = host_factory(name="sw-gpn-rtp-01")
host2 = host_factory(name="sw-gpn-cso-campus-01")
host3 = host_factory(name="sw-tn-g02-lcr-01")
host4 = host_factory(name="sw-tn-g02-vbox-01")
host_factory(name="host1", is_ioc=False)
host_factory(name="foo-sw-tn-g02", is_ioc=False)
group1 = ansible_group_factory(name="sw-gpn", type=models.AnsibleGroupType.HOSTNAME)
group2 = ansible_group_factory(
name="sw-tn-g02", type=models.AnsibleGroupType.HOSTNAME
)
assert group1.hosts == [host2, host1]
assert group2.hosts == [host3, host4]
@pytest.mark.parametrize("status", [None, "FINISHED", "FAILED", "STARTED"])
......@@ -305,8 +803,8 @@ def test_task_waiting(status, user, task_factory):
assert user.is_task_waiting("my-task")
@pytest.mark.parametrize("minutes", [5, 10, 29])
def test_task_waiting_with_recent_deferred(minutes, user, task_factory):
@pytest.mark.parametrize("minutes", [5, 10, 29, 31, 60, 7200])
def test_task_waiting_with_recent_or_old_deferred(minutes, user, task_factory):
minutes_ago = datetime.datetime.utcnow() - datetime.timedelta(minutes=minutes)
task_factory(
created_at=minutes_ago, name="my-task", status=models.JobStatus.DEFERRED
......@@ -314,15 +812,6 @@ def test_task_waiting_with_recent_deferred(minutes, user, task_factory):
assert user.is_task_waiting("my-task")
@pytest.mark.parametrize("minutes", [31, 60, 7200])
def test_no_task_waiting_with_old_deferred(minutes, user, task_factory):
minutes_ago = datetime.datetime.utcnow() - datetime.timedelta(minutes=minutes)
task_factory(
created_at=minutes_ago, name="my-task", status=models.JobStatus.DEFERRED
)
assert not user.is_task_waiting("my-task")
@pytest.mark.parametrize("minutes", [5, 30, 7200])
def test_task_waiting_with_old_queued(minutes, user, task_factory):
minutes_ago = datetime.datetime.utcnow() - datetime.timedelta(minutes=minutes)
......@@ -352,3 +841,484 @@ def test_get_tasks_in_progress(user, task_factory):
task_factory(name="my-task", status=models.JobStatus.FAILED)
task3 = task_factory(name="my-task", status=models.JobStatus.DEFERRED)
assert user.get_tasks_in_progress("my-task") == [task1, task2, task3]
def test_update_task_reverse_dependencies(user, task_factory):
task1 = task_factory(name="my-task", status=models.JobStatus.STARTED)
task2 = task_factory(
name="my-task", status=models.JobStatus.DEFERRED, depends_on=task1
)
task3 = task_factory(
name="my-task", status=models.JobStatus.DEFERRED, depends_on=task1
)
task4 = task_factory(
name="my-task", status=models.JobStatus.DEFERRED, depends_on=task3
)
task1.update_reverse_dependencies()
for task in (task2, task3, task4):
assert task.status == models.JobStatus.FAILED
def test_host_indexed(db, host_factory):
host1 = host_factory(name="myhost")
res = db.app.elasticsearch.search(index="host-test", q="*")
assert res["hits"]["total"]["value"] == 1
assert res["hits"]["hits"][0]["_id"] == str(host1.id)
assert res["hits"]["hits"][0]["_id"] == str(host1.id)
def test_host_with_interfaces_indexed(db, host_factory, interface_factory):
host1 = host_factory(name="myhost")
interface_factory(name="myhost", host=host1)
interface_factory(name="myhost2", host=host1)
for name in ("myhost", "myhost2"):
res = db.app.elasticsearch.search(index="host-test", q=name)
assert res["hits"]["total"]["value"] == 1
assert res["hits"]["hits"][0]["_id"] == str(host1.id)
def test_interface_name_starts_with_host_name(db, host_factory, interface_factory):
host1 = host_factory(name="myhost")
interface = interface_factory(name="myhost-1", host=host1)
assert host1.interfaces[0] == interface
with pytest.raises(ValidationError) as excinfo:
interface_factory(name="myinterface", host=host1)
assert "Interface name shall start with the host name 'myhost'" in str(
excinfo.value
)
def test_interface_name_existing_host(db, host_factory, interface_factory):
host1 = host_factory(name="myhost")
host_factory(name="myhost2")
with pytest.raises(ValidationError) as excinfo:
interface_factory(name="myhost2", host=host1)
assert "Interface name matches an existing host" in str(excinfo.value)
def test_interface_name_existing_cname(
db, host_factory, interface_factory, cname_factory
):
host1 = host_factory(name="myhost")
cname_factory(name="myhost2")
with pytest.raises(ValidationError) as excinfo:
interface_factory(name="myhost2", host=host1)
assert "Interface name matches an existing cname" in str(excinfo.value)
def test_interface_is_main(host_factory, interface_factory):
# The interface with the same name as the host is the main one
host1 = host_factory(name="myhost")
interface11 = interface_factory(name=host1.name + "-2", host=host1)
interface12 = interface_factory(name=host1.name, host=host1)
interface13 = interface_factory(name=host1.name + "-3", host=host1)
assert interface12.is_main
assert not interface11.is_main
assert not interface13.is_main
# Interfaces are sorted by name
assert host1.interfaces == [interface12, interface11, interface13]
host2 = host_factory(name="anotherhost")
interface21 = interface_factory(name=host2.name + "-1", host=host2)
# If no interface has the same name as the host, the first one is the main
assert interface21.is_main
interface22 = interface_factory(name=host2.name + "-2", host=host2)
# The first interface in the list is the main one
assert host2.interfaces == [interface21, interface22]
assert interface21.is_main
assert not interface22.is_main
interface23 = interface_factory(name=host2.name, host=host2)
# The new interface has the same name as the host, so this is the main one
assert not interface21.is_main
assert not interface22.is_main
assert interface23.is_main
assert host2.interfaces == [interface23, interface21, interface22]
def test_host_existing_interface(db, host_factory, interface):
with pytest.raises(ValidationError) as excinfo:
host_factory(name=interface.name)
assert "Host name matches an existing interface" in str(excinfo.value)
def test_host_existing_cname(db, host_factory, cname):
with pytest.raises(ValidationError) as excinfo:
host_factory(name=cname.name)
assert "Host name matches an existing cname" in str(excinfo.value)
def test_host_fqdn(host_factory, interface_factory):
host1 = host_factory(name="myhost")
interface1 = interface_factory(name=host1.name, host=host1)
interface2 = interface_factory(name=host1.name + "-2", host=host1)
assert interface1.network.domain != interface2.network.domain
# The domain is the one from the main interface
assert interface1.is_main
assert host1.fqdn == f"{host1.name}.{interface1.network.domain.name}"
def test_host_is_ioc(host_factory, interface_factory):
host1 = host_factory(is_ioc=True)
interface1 = interface_factory(name=host1.name, host=host1)
interface2 = interface_factory(host=host1)
assert host1.is_ioc
assert interface1.is_ioc
assert not interface2.is_ioc
host2 = host_factory()
interface3 = interface_factory(name=host2.name, host=host2)
assert not host2.is_ioc
assert not interface3.is_ioc
def test_host_items_sorted_with_stack_member(host_factory, item_factory):
host1 = host_factory()
item1 = item_factory(ics_id="AAA001", host_id=host1.id, stack_member=1)
item2 = item_factory(ics_id="AAA002", host_id=host1.id, stack_member=0)
data = host1.to_dict()
# Items sorted by stack_member
assert data["items"] == ["AAA002", "AAA001"]
data = host1.to_dict(recursive=True)
assert data["items"] == [
{"ics_id": "AAA002", "serial_number": item2.serial_number, "stack_member": 0},
{"ics_id": "AAA001", "serial_number": item1.serial_number, "stack_member": 1},
]
def test_host_items_sorted_without_stack_member(host_factory, item_factory):
host1 = host_factory()
item1 = item_factory(ics_id="AAA001", host_id=host1.id)
item2 = item_factory(ics_id="AAA002", host_id=host1.id)
data = host1.to_dict()
# Items sorted by ics_id
assert data["items"] == ["AAA001", "AAA002"]
data = host1.to_dict(recursive=True)
assert data["items"] == [
{
"ics_id": "AAA001",
"serial_number": item1.serial_number,
"stack_member": None,
},
{
"ics_id": "AAA002",
"serial_number": item2.serial_number,
"stack_member": None,
},
]
def test_host_items_sorted_with_mixed_stack_member(host_factory, item_factory):
host1 = host_factory()
item1 = item_factory(ics_id="AAA001", host_id=host1.id, stack_member=1)
item2 = item_factory(ics_id="AAA002", host_id=host1.id)
item3 = item_factory(ics_id="AAA003", host_id=host1.id, stack_member=0)
data = host1.to_dict()
# Items sorted by stack_member and then ics_id
assert data["items"] == ["AAA003", "AAA001", "AAA002"]
data = host1.to_dict(recursive=True)
assert data["items"] == [
{"ics_id": "AAA003", "serial_number": item3.serial_number, "stack_member": 0},
{"ics_id": "AAA001", "serial_number": item1.serial_number, "stack_member": 1},
{
"ics_id": "AAA002",
"serial_number": item2.serial_number,
"stack_member": None,
},
]
def test_host_no_scope(host_factory):
host = host_factory()
assert host.scope is None
def test_host_scope(
host_factory, network_scope_factory, network_factory, interface_factory
):
scope = network_scope_factory()
network = network_factory(scope=scope)
host = host_factory()
interface_factory(name=host.name, host=host, network=network)
assert host.scope == scope
def test_cname_existing_host(db, host_factory, cname_factory):
host_factory(name="myhost")
with pytest.raises(ValidationError) as excinfo:
cname_factory(name="myhost")
assert "cname matches an existing host" in str(excinfo.value)
def test_cname_existing_interface(db, interface, cname_factory):
with pytest.raises(ValidationError) as excinfo:
cname_factory(name=interface.name)
assert "cname matches an existing interface" in str(excinfo.value)
def test_cname_unique_by_domain(db, interface_factory, network_factory, cname_factory):
network1 = network_factory()
network2 = network_factory()
assert network1.domain.name != network2.domain.name
interface1 = interface_factory(network=network1)
interface2 = interface_factory(network=network2)
# We can have identical cname on different domains
cname1 = cname_factory(name="mycname", interface=interface1)
cname2 = cname_factory(name="mycname", interface=interface2)
assert cname1.fqdn == f"mycname.{network1.domain}"
assert cname2.fqdn == f"mycname.{network2.domain}"
assert cname1.fqdn != cname2.fqdn
# cname must be unique by domain
interface3 = interface_factory(network=network1)
with pytest.raises(ValidationError) as excinfo:
cname_factory(name="mycname", interface=interface3)
assert f"Duplicate cname on the {network1.domain} domain" in str(excinfo.value)
def test_task_awx_job_url(db, task_factory):
task1 = task_factory(awx_resource="job", awx_job_id=42)
assert task1.awx_job_url == "https://awx.example.org/#/jobs/playbook/42"
task2 = task_factory(awx_resource="workflow_job", awx_job_id=43)
assert task2.awx_job_url == "https://awx.example.org/#/workflows/43"
task3 = task_factory(awx_resource="foo", awx_job_id=44)
assert task3.awx_job_url is None
task4 = task_factory(awx_job_id=45)
assert task4.awx_job_url is None
task5 = task_factory(awx_resource="inventory_source", awx_job_id=12)
assert task5.awx_job_url == "https://awx.example.org/#/jobs/inventory/12"
@pytest.mark.parametrize("length", (1, 25, 50))
def test_hostname_invalid_length(db, host_factory, length):
with pytest.raises(ValidationError) as excinfo:
host_factory(name="x" * length)
assert r"Host name shall match ^[a-z0-9\-]{2,24}" in str(excinfo.value)
@pytest.mark.parametrize("name", ("my_host", "host@", "foo:bar", "U02.K02"))
def test_hostname_invalid_characters(db, host_factory, name):
with pytest.raises(ValidationError) as excinfo:
host_factory(name=name)
assert r"Host name shall match ^[a-z0-9\-]{2,24}" in str(excinfo.value)
@pytest.mark.parametrize("length", (1, 30, 50))
def test_interface_name_invalid_length(db, interface_factory, length):
with pytest.raises(ValidationError) as excinfo:
interface_factory(name="x" * length)
assert r"Interface name shall match ^[a-z0-9\-]{2,29}" in str(excinfo.value)
def test_interface_name_length(db, host_factory, interface_factory):
hostname = "x" * 24
interface_name = hostname + "-yyyy"
host1 = host_factory(name=hostname)
interface_factory(name=interface_name, host=host1)
assert host1.interfaces[0].name == interface_name
with pytest.raises(ValidationError) as excinfo:
interface_factory(name=interface_name + "y", host=host1)
assert r"Interface name shall match ^[a-z0-9\-]{2,29}" in str(excinfo.value)
@pytest.mark.parametrize("ics_id", ("123", "AA123", "AAA1234"))
def test_item_invalid_ics_id(db, item_factory, ics_id):
with pytest.raises(ValidationError) as excinfo:
item_factory(ics_id=ics_id)
assert r"ICS id shall match [A-Z]{3}[0-9]{3}" in str(excinfo.value)
@pytest.mark.parametrize(
"address", ("172.30.0.0/25", "172.30.1.0/24", "172.30.0.0/22", "172.30.0.192/26")
)
def test_network_overlapping(address, network_scope_factory, network_factory):
scope = network_scope_factory(
first_vlan=3800, last_vlan=4000, supernet="172.30.0.0/16"
)
network1 = network_factory(
vlan_id=3800,
address="172.30.0.0/23",
first_ip="172.30.0.3",
last_ip="172.30.1.240",
scope=scope,
)
with pytest.raises(ValidationError) as excinfo:
network_factory(vlan_id=3801, address=address, scope=scope)
assert f"{address} overlaps {network1} ({network1.network_ip})" in str(
excinfo.value
)
@pytest.mark.parametrize("address", ("172.30.2.0/25", "172.30.0.0/16"))
def test_network_not_subnet_of_scope(address, network_scope_factory, network_factory):
scope = network_scope_factory(
first_vlan=3800, last_vlan=4000, supernet="172.30.0.0/23"
)
with pytest.raises(ValidationError) as excinfo:
network_factory(vlan_id=3800, address=address, scope=scope)
assert f"{address} is not a subnet of 172.30.0.0/23" in str(excinfo.value)
def test_scope_available_subnets(network_scope_factory, network_factory):
address = "172.30.0.0/16"
scope_ip = ipaddress.ip_network(address)
scope = network_scope_factory(first_vlan=3800, last_vlan=4000, supernet=address)
full_24 = [str(subnet) for subnet in scope_ip.subnets(new_prefix=24)]
assert scope.available_subnets(24) == full_24
network1 = network_factory(vlan_id=3800, address="172.30.60.0/24", scope=scope)
expected1 = [subnet for subnet in full_24 if subnet != network1.address]
assert scope.available_subnets(24) == expected1
network_factory(vlan_id=3801, address="172.30.244.0/22", scope=scope)
network2_24 = [
"172.30.244.0/24",
"172.30.245.0/24",
"172.30.246.0/24",
"172.30.247.0/24",
]
expected2 = [subnet for subnet in expected1 if subnet not in network2_24]
assert scope.available_subnets(24) == expected2
network_factory(vlan_id=3802, address="172.30.238.64/26", scope=scope)
expected3 = [subnet for subnet in expected2 if subnet != "172.30.238.0/24"]
assert scope.available_subnets(24) == expected3
@pytest.mark.parametrize("address", ("172.30.0.0/22", "172.30.244.0/22"))
def test_network_scope_overlapping(address, network_scope_factory):
scope = network_scope_factory(
first_vlan=3800, last_vlan=4000, supernet="172.30.0.0/16"
)
with pytest.raises(ValidationError) as excinfo:
network_scope_factory(first_vlan=2000, last_vlan=2200, supernet=address)
assert f"{address} overlaps {scope.name} ({scope.supernet_ip})" in str(
excinfo.value
)
def test_network_scope_supernet_validation(network_scope_factory, network_factory):
scope = network_scope_factory(
first_vlan=3800, last_vlan=4000, supernet="172.30.0.0/16"
)
network1 = network_factory(
vlan_id=3800,
address="172.30.0.0/23",
first_ip="172.30.0.3",
last_ip="172.30.1.240",
scope=scope,
)
address = "192.168.0.0/16"
with pytest.raises(ValidationError) as excinfo:
scope.supernet = "192.168.0.0/16"
assert f"{network1.network_ip} is not a subnet of {address}" in str(excinfo.value)
def test_network_scope_first_vlan_validation(network_scope_factory, network_factory):
scope = network_scope_factory(
first_vlan=200, last_vlan=400, supernet="172.30.0.0/16"
)
network1 = network_factory(
vlan_id=220,
address="172.30.0.0/23",
first_ip="172.30.0.3",
last_ip="172.30.1.240",
scope=scope,
)
with pytest.raises(ValidationError) as excinfo:
scope.first_vlan = 230
assert f"First vlan shall be lower than {network1} vlan: 220" in str(excinfo.value)
def test_network_scope_last_vlan_validation(network_scope_factory, network_factory):
scope = network_scope_factory(
first_vlan=200, last_vlan=400, supernet="172.30.0.0/16"
)
network1 = network_factory(
vlan_id=220,
address="172.30.0.0/23",
first_ip="172.30.0.3",
last_ip="172.30.1.240",
scope=scope,
)
with pytest.raises(ValidationError) as excinfo:
scope.last_vlan = 210
assert f"Last vlan shall be greater than {network1} vlan: 220" in str(excinfo.value)
def test_host_sensitive_field_update_on_network_change(
network_scope_factory, network_factory, interface_factory, host_factory
):
scope = network_scope_factory(
first_vlan=3800, last_vlan=4000, supernet="192.168.0.0/16"
)
network = network_factory(
vlan_id=3800,
address="192.168.1.0/24",
first_ip="192.168.1.10",
last_ip="192.168.1.250",
sensitive=False,
scope=scope,
)
name = "host1"
host = host_factory(name=name)
interface_factory(name=name, host=host, ip="192.168.1.20", network=network)
# There is no sensitive host
instances, nb = models.Host.search("sensitive:true")
assert nb == 0
# Updating the network should update the host in the elasticsearch index
network.sensitive = True
instances, nb = models.Host.search("sensitive:true")
assert nb == 1
assert instances[0].name == name
@pytest.mark.parametrize(
"dn,username,user_info,user_groups,expected_display_name,expected_email,expected_groups",
[
(
"uid=johndoe,ou=Users,dc=esss,dc=lu,dc=se",
"johndoe",
{"mail": "john.doe@example.org", "cn": "John Doe"},
[{"cn": "group2"}, {"cn": "group1"}],
"John Doe",
"john.doe@example.org",
["group1", "group2"],
),
(
"uid=johndoe,ou=Users,dc=esss,dc=lu,dc=se",
"johndoe",
{"mail": ["john.doe@example.org"], "cn": ["John Doe"]},
[{"cn": ["group2"]}, {"cn": ["group1"]}],
"John Doe",
"john.doe@example.org",
["group1", "group2"],
),
(
"uid=auditor,ou=Service accounts,dc=esss,dc=lu,dc=se",
"auditor",
{
"uid": ["auditor"],
"cn": [],
"mail": [],
"dn": "uid=csentry_svc,ou=Service accounts,dc=esss,dc=lu,dc=se",
},
[
{
"cn": ["csentry_auditors"],
"dn": "cn=csentry_auditors,ou=ICS,ou=Groups,dc=esss,dc=lu,dc=se",
}
],
"auditor",
"",
["csentry_auditors"],
),
],
)
def test_save_user(
dn,
username,
user_info,
user_groups,
expected_display_name,
expected_email,
expected_groups,
):
user = models.save_user(dn, username, user_info, user_groups)
assert user.username == username
assert user.display_name == expected_display_name
assert user.email == expected_email
assert user.groups == expected_groups
# -*- coding: utf-8 -*-
"""
tests.functional.test_search
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
This module defines search tests.
:copyright: (c) 2017 European Spallation Source ERIC
:license: BSD 2-Clause, see LICENSE for more details.
"""
import elasticsearch
import pytest
from app import search
class MyModel:
def __init__(self, id, name, description=""):
self.id = id
self.name = name
self.description = description
def to_dict(self, recursive=False):
return {"id": self.id, "name": self.name, "description": self.description}
def test_add_to_index(db):
model1 = MyModel(2, "foo", "This is a test")
search.add_to_index("index-test", model1.to_dict())
res = db.app.elasticsearch.get(index="index-test", id=2)
assert res["_source"] == {"name": "foo", "description": "This is a test"}
def test_remove_from_index(db):
model1 = MyModel(3, "hello world!")
search.add_to_index("index-test", model1.to_dict())
res = db.app.elasticsearch.search(index="index-test", q="*")
assert res["hits"]["total"]["value"] == 1
search.remove_from_index("index-test", model1.id)
res = db.app.elasticsearch.search(index="index-test", q="*")
assert res["hits"]["total"]["value"] == 0
def test_remove_from_index_non_existing():
model1 = MyModel(1, "hello world!")
with pytest.raises(elasticsearch.NotFoundError):
search.remove_from_index("index-test", model1.id)
def test_query_index():
model1 = MyModel(1, "Python", "Python is my favorite language")
search.add_to_index("index-test", model1.to_dict())
model1 = MyModel(2, "Java", "Your should switch to Python!")
search.add_to_index("index-test", model1.to_dict())
# Test query all
ids, total = search.query_index("index-test", "*")
assert sorted(ids) == [1, 2]
assert total == 2
# Test query string
ids, total = search.query_index("index-test", "java")
assert ids == [2]
assert total == 1
ids, total = search.query_index("index-test", "python")
assert sorted(ids) == [1, 2]
# Test query specific field
ids, total = search.query_index("index-test", "name:python")
assert ids == [1]
# Test query sort
ids, total = search.query_index("index-test", "*", sort="name.keyword")
assert ids == [2, 1]
def test_update_document(db):
# Create a document
index = "index-test"
id = 4
name = "a name"
description = "just an example"
model = MyModel(id, name, description)
search.add_to_index(index, model.to_dict())
res = db.app.elasticsearch.get(index="index-test", id=id)
assert res["_source"] == {"name": name, "description": description}
# Update the name field (description doesn't change)
new_name = "new name"
search.update_document(index, id, {"name": new_name})
res = db.app.elasticsearch.get(index=index, id=id)
assert res["_source"] == {"name": new_name, "description": description}
import pytest
@pytest.mark.parametrize(
"name, func, input_kwargs, output_args",
[
("my task1", "my_func1", {}, ""),
(
"my task2",
"my_func2",
{"arg1": "foo", "arg2": True},
"arg1='foo', arg2=True",
),
# job_timeout is used by enqueue for the job
("another task", "func_to_run", {"job_timeout": 180}, ""),
# timeout is NOT used by enqueue for the job (deprecated in RQ >= 1.0)
# it's passed to the function
("task4", "my_func4", {"timeout": 60}, "timeout=60"),
],
)
def test_launch_task_kwargs(user, name, func, input_kwargs, output_args):
task = user.launch_task(name, func=func, **input_kwargs)
assert task.name == name
assert task.command == f"app.tasks.{func}({output_args})"
......@@ -11,6 +11,7 @@ This module defines basic web tests.
"""
import pytest
import re
from app import models
def login(client, username, password):
......@@ -25,49 +26,212 @@ def logout(client):
@pytest.fixture
def logged_client(client):
login(client, "user_ro", "userro")
return client
yield client
logout(client)
@pytest.fixture
def logged_rw_client(client):
login(client, "user_rw", "userrw")
yield client
logout(client)
@pytest.fixture
def logged_admin_client(client):
login(client, "admin", "adminpasswd")
yield client
logout(client)
@pytest.fixture
def no_login_check_client(request, app):
app.config["LOGIN_DISABLED"] = True
client = app.test_client()
# We still need to login, otherwise an AnonymousUserMixin is returned
# An AnonymousUser doesn't have all the User methods
login(client, "user_ro", "userro")
yield client
app.config["LOGIN_DISABLED"] = False
logout(client)
def test_login_logout(client):
response = login(client, "unknown", "invalid")
assert b"<title>Login - CSEntry</title>" in response.data
assert b"<title>Login</title>" in response.data
response = login(client, "user_rw", "invalid")
assert b"<title>Login - CSEntry</title>" in response.data
assert b"<title>Login</title>" in response.data
response = login(client, "user_rw", "userrw")
assert b"Welcome to CSEntry!" in response.data
assert b"Control System Entry" in response.data
assert b"User RW" in response.data
response = logout(client)
assert b"<title>Login - CSEntry</title>" in response.data
assert b"<title>Login</title>" in response.data
def test_index(logged_client):
response = logged_client.get("/")
assert b"Welcome to CSEntry!" in response.data
assert b"Control System Entry" in response.data
assert b"User RO" in response.data
@pytest.mark.parametrize("url", ["/", "/inventory/items", "/network/hosts"])
def test_protected_url_get(url, client):
response = client.get(url)
assert response.status_code == 302
assert "/user/login" in response.headers["Location"]
login(client, "user_ro", "userro")
response = client.get(url)
assert response.status_code == 200
@pytest.mark.parametrize("url", ["/network/scopes"])
@pytest.mark.parametrize(
"url", ["/", "/inventory/items", "/inventory/_retrieve_items", "/network/networks"]
"username, password, status_code",
[
("user_rw", "userrw", 403),
("user_prod", "userprod", 403),
("admin", "adminpasswd", 200),
("audit", "auditpasswd", 200),
],
)
def test_protected_url(url, client):
def test_admin_protected_url_get(client, url, username, password, status_code):
login(client, username, password)
response = client.get(url)
assert response.status_code == status_code
logout(client)
@pytest.mark.parametrize(
"url", ["/inventory/_retrieve_items", "/network/_retrieve_hosts"]
)
def test_protected_url_post(url, client):
response = client.post(url)
assert response.status_code == 302
assert "/user/login" in response.headers["Location"]
login(client, "user_ro", "userro")
response = client.get(url)
response = client.post(url)
assert response.status_code == 200
def test_retrieve_items(logged_client, item_factory):
response = logged_client.get("/inventory/_retrieve_items")
response = logged_client.post("/inventory/_retrieve_items")
assert response.get_json()["data"] == []
serial_numbers = ("12345", "45678")
for sn in serial_numbers:
item_factory(serial_number=sn)
response = logged_client.get("/inventory/_retrieve_items")
response = logged_client.post("/inventory/_retrieve_items")
items = response.get_json()["data"]
assert set(serial_numbers) == set(item["serial_number"] for item in items)
assert len(items[0]) == 18
def test_retrieve_items_pagination(logged_client, item_factory):
for sn in range(1000, 1030):
item_factory(serial_number=sn)
response = logged_client.post(
"/inventory/_retrieve_items", data={"draw": "50", "length": 10, "start": 0}
)
r = response.get_json()
assert r["draw"] == 50
assert r["recordsTotal"] == 30
assert r["recordsFiltered"] == 30
assert len(r["data"]) == 10
serial_numbers = [item["serial_number"] for item in r["data"]]
response = logged_client.post(
"/inventory/_retrieve_items", data={"draw": "51", "length": 10, "start": 10}
)
serial_numbers.extend(
[item["serial_number"] for item in response.get_json()["data"]]
)
response = logged_client.post(
"/inventory/_retrieve_items", data={"draw": "52", "length": 10, "start": 20}
)
serial_numbers.extend(
[item["serial_number"] for item in response.get_json()["data"]]
)
assert sorted(serial_numbers) == list(str(i) for i in range(1000, 1030))
def test_retrieve_items_filter(logged_client, item_factory):
for sn in range(1000, 1010):
item_factory(serial_number=sn)
response = logged_client.post(
"/inventory/_retrieve_items",
data={
"draw": "50",
"length": 20,
"start": 0,
"search[value]": "serial_number:1005",
},
)
r = response.get_json()
assert r["recordsTotal"] == 10
assert r["recordsFiltered"] == 1
assert len(r["data"]) == 1
assert r["data"][0]["serial_number"] == "1005"
def test_retrieve_items_sort(logged_client, item_factory):
serial_numbers = ["AAA001", "AAB034", "AAA100", "AAB001"]
for sn in serial_numbers:
item_factory(serial_number=sn)
response = logged_client.post(
"/inventory/_retrieve_items",
data={
"draw": "50",
"length": 20,
"start": 0,
"order[0][column]": "3",
"columns[3][data]": "serial_number",
},
)
items = response.get_json()["data"]
assert set(serial_numbers) == set(item[4] for item in items)
assert len(items[0]) == 11
assert sorted(serial_numbers) == [item["serial_number"] for item in items]
response = logged_client.post(
"/inventory/_retrieve_items",
data={
"draw": "50",
"length": 20,
"start": 0,
"order[0][column]": "3",
"order[0][dir]": "desc",
"columns[3][data]": "serial_number",
},
)
items = response.get_json()["data"]
assert sorted(serial_numbers, reverse=True) == [
item["serial_number"] for item in items
]
def test_retrieve_items_case_insensitive(logged_client, model_factory, item_factory):
juniper_model = model_factory(name="Juniper")
item_factory(serial_number="BBB001", model=juniper_model)
item_factory(serial_number="ABB042")
response = logged_client.post(
"/inventory/_retrieve_items",
data={"draw": "50", "length": 20, "start": 0, "search[value]": "juniper"},
)
r = response.get_json()
assert r["recordsTotal"] == 2
assert r["recordsFiltered"] == 1
assert len(r["data"]) == 1
assert r["data"][0]["model"] == "Juniper"
def test_retrieve_items_one_word(logged_client, manufacturer_factory, item_factory):
manufacturer = manufacturer_factory(name="Concurrent Technologies")
item_factory(serial_number="AAA001", manufacturer=manufacturer)
item_factory(serial_number="ABB042")
response = logged_client.post(
"/inventory/_retrieve_items",
data={"draw": "50", "length": 20, "start": 0, "search[value]": "concurrent"},
)
r = response.get_json()
assert r["recordsTotal"] == 2
assert r["recordsFiltered"] == 1
assert len(r["data"]) == 1
assert r["data"][0]["manufacturer"] == "Concurrent Technologies"
def test_generate_random_mac(logged_client):
......@@ -75,3 +239,831 @@ def test_generate_random_mac(logged_client):
mac = response.get_json()["data"]["mac"]
assert re.match("^(?:[0-9a-fA-F]{2}:){5}[0-9a-fA-F]{2}$", mac) is not None
assert mac.startswith("02:42:42")
def test_retrieve_hosts(logged_client, interface_factory, host_factory):
response = logged_client.post("/network/_retrieve_hosts")
assert response.get_json()["data"] == []
host1 = host_factory(name="host1")
host2 = host_factory(name="host2")
interface_factory(name="host1", host=host1)
interface_factory(name="host2", host=host2)
response = logged_client.post("/network/_retrieve_hosts")
hosts = response.get_json()["data"]
assert {host1.name, host2.name} == set(host["name"] for host in hosts)
assert len(hosts[0]) == 16
assert len(hosts[0]["interfaces"][0]) == 16
def test_retrieve_hosts_by_ip(logged_client, interface_factory):
interface1 = interface_factory()
interface_factory()
response = logged_client.post(
"/network/_retrieve_hosts",
data={"draw": "50", "length": 20, "start": 0, "search[value]": interface1.ip},
)
r = response.get_json()
assert r["recordsTotal"] == 2
assert r["recordsFiltered"] == 1
assert len(r["data"]) == 1
assert r["data"][0]["name"] == interface1.host.name
def test_retrieve_sensitive_hosts(
client, network_scope_factory, network_factory, host_factory, interface_factory
):
scope = network_scope_factory(name="ProdNetworks", supernet="192.168.0.0/16")
network1 = network_factory(
address="192.168.1.0/24",
first_ip="192.168.1.10",
last_ip="192.168.1.250",
sensitive=False,
scope=scope,
)
network2 = network_factory(
address="192.168.2.0/24",
first_ip="192.168.2.10",
last_ip="192.168.2.250",
sensitive=True,
scope=scope,
)
host1 = host_factory()
interface_factory(name=host1.name, host=host1, network=network1, ip="192.168.1.10")
host2 = host_factory()
interface_factory(name=host2.name, host=host2, network=network2, ip="192.168.2.10")
host3 = host_factory()
interface_factory(name=host3.name, host=host3, network=network2, ip="192.168.2.11")
# Normal users can't see hosts on sensitive networks if they aren't member of the scope
login(client, "user_lab", "userlab")
response = client.post("/network/_retrieve_hosts")
r = response.get_json()
assert r["recordsTotal"] == 3
assert r["recordsFiltered"] == 1
assert len(r["data"]) == 1
assert r["data"][0]["name"] == host1.name
logout(client)
# Users member of the scope can see sensitive hosts
# Same for admin and auditor users
for (user, passwd) in (
("user_prod", "userprod"),
("admin", "adminpasswd"),
("audit", "auditpasswd"),
):
login(client, user, passwd)
response = client.post(
"/network/_retrieve_hosts",
)
r = response.get_json()
assert r["recordsTotal"] == 3
assert r["recordsFiltered"] == 3
assert len(r["data"]) == 3
logout(client)
def test_delete_interface_from_index(
no_login_check_client, interface_factory, host_factory
):
host1 = host_factory(name="host1")
interface_factory(name="host1", host=host1)
interface2 = interface_factory(name="host1b", host=host1)
# The interface is in the index
instances, nb = models.Host.search("host1b")
assert list(instances) == [host1]
assert nb == 1
# Delete the interface
response = no_login_check_client.post(
"/network/interfaces/delete", data={"interface_id": interface2.id}
)
assert response.status_code == 302
# It's not in the database anymore
assert models.Interface.query.get(interface2.id) is None
# Neither in the index
instances, nb = models.Host.search("host1b")
assert list(instances) == []
assert nb == 0
# But host1 is still in the index
instances, nb = models.Host.search("host1")
assert list(instances) == [host1]
assert nb == 1
def test_edit_item_comment_in_index(
logged_rw_client, item_factory, item_comment_factory
):
item1 = item_factory(ics_id="AAA001")
comment = item_comment_factory(body="Hello", item=item1)
assert item1.comments == [comment]
# Edit the comment
body = "Hello world!"
response = logged_rw_client.post(
f"/inventory/items/comment/edit/{comment.id}", data={"body": body}
)
assert response.status_code == 302
# The comment was updated in the database
updated_comment = models.ItemComment.query.get(comment.id)
assert updated_comment.body == body
# And in the index
instances, nb = models.Item.search("world")
assert list(instances) == [item1]
def test_create_host(client, network_scope_factory, network_factory, device_type):
scope = network_scope_factory(name="ProdNetworks", supernet="192.168.0.0/16")
network = network_factory(
address="192.168.1.0/24",
first_ip="192.168.1.10",
last_ip="192.168.1.250",
scope=scope,
)
name = "myhost"
ip = "192.168.1.11"
mac = "02:42:42:45:3c:89"
form = {
"network_id": network.id,
"name": name,
"device_type_id": device_type.id,
"is_ioc": False,
"ip": ip,
"mac": mac,
"description": "test",
"ansible_vars": "foo: hello",
"ansible_groups": [],
"random_mac": False,
"cnames_string": "",
}
# Invalid network_id with user_lab user
# (form validation error because the network is not part of the choices
# for this user)
login(client, "user_lab", "userlab")
response = client.post("/network/hosts/create", data=form)
assert response.status_code == 200
# The host wasn't created
assert models.Host.query.filter_by(name=name).first() is None
logout(client)
# Success with user_prod user
login(client, "user_prod", "userprod")
response = client.post("/network/hosts/create", data=form, follow_redirects=True)
assert response.status_code == 200
# The host was created
assert b"created!" in response.data
assert b"View host" in response.data
host = models.Host.query.filter_by(name=name).first()
assert host is not None
assert host.interfaces[0].ip == ip
assert host.interfaces[0].mac == mac
assert host.interfaces[0].name == name
def test_create_host_invalid_fields(
session, client, network_scope_factory, network_factory, device_type
):
scope = network_scope_factory(name="ProdNetworks", supernet="192.168.0.0/16")
network = network_factory(
address="192.168.1.0/24",
first_ip="192.168.1.10",
last_ip="192.168.1.250",
scope=scope,
)
name = "myhost"
ip = "192.168.1.11"
mac = "02:42:42:45:3c:89"
form = {
"network_id": network.id,
"name": name,
"device_type_id": device_type.id,
"is_ioc": False,
"ip": ip,
"mac": mac,
"description": "test",
"ansible_vars": "",
"ansible_groups": [],
"random_mac": False,
"cnames_string": "",
}
login(client, "user_prod", "userprod")
# Invalid mac
data = form.copy()
data["mac"] = "ea:ea:60:45:a8:96:se"
response = client.post("/network/hosts/create", data=data, follow_redirects=True)
assert response.status_code == 200
assert b"Register new host" in response.data
assert b"Invalid MAC address" in response.data
# An exception was raised during validation (on Select in the Unique Validator),
# so we need to rollback.
session.rollback()
# Invalid hostname
data = form.copy()
data["name"] = "invalid_host"
response = client.post("/network/hosts/create", data=data, follow_redirects=True)
assert response.status_code == 200
assert b"Register new host" in response.data
assert b"Invalid input" in response.data
def test_create_interface(
client, host_factory, network_scope_factory, network_factory, interface_factory
):
host = host_factory(name="myhost")
scope = network_scope_factory(name="ProdNetworks", supernet="192.168.0.0/16")
network1 = network_factory(scope=scope)
interface_factory(network=network1, host=host)
network2 = network_factory(
address="192.168.2.0/24",
first_ip="192.168.2.10",
last_ip="192.168.2.250",
scope=scope,
)
name = host.name + "-2"
ip = "192.168.2.11"
mac = "02:42:42:46:3c:75"
form = {
"host_id": host.id,
"interface_name": name,
"network_id": network2.id,
"random_mac": False,
"ip": ip,
"mac": mac,
"cnames_string": "",
}
# Permission denied
# user_lab doesn't have permissions for the host domain: prod.example.org
login(client, "user_lab", "userlab")
response = client.post(f"/network/interfaces/create/{host.name}", data=form)
assert response.status_code == 403
# The host wasn't created
assert models.Interface.query.filter_by(name=name).first() is None
logout(client)
# Success with user_prod user
login(client, "user_prod", "userprod")
response = client.post(f"/network/interfaces/create/{host.name}", data=form)
assert response.status_code == 302
# The interface was created
interface = models.Interface.query.filter_by(name=name).first()
assert interface is not None
assert interface.ip == ip
assert interface.mac == mac
assert interface.name == name
assert interface.host == host
def test_add_interface_to_empty_host(
client, host_factory, network_scope_factory, network_factory
):
host = host_factory(name="myhost")
scope = network_scope_factory(name="ProdNetworks", supernet="192.168.0.0/16")
network = network_factory(
address="192.168.2.0/24",
first_ip="192.168.2.10",
last_ip="192.168.2.250",
scope=scope,
)
name = host.name
ip = "192.168.2.11"
mac = "02:42:42:46:3c:75"
form = {
"host_id": host.id,
"interface_name": name,
"network_id": network.id,
"random_mac": False,
"ip": ip,
"mac": mac,
"cnames_string": "",
}
# user_lab doesn't have permissions for the network domain: prod.example.org
# form validation will fail because the network_id won't be in the choices
login(client, "user_lab", "userlab")
response = client.post(f"/network/interfaces/create/{host.name}", data=form)
assert response.status_code == 200
# The host wasn't created
assert models.Interface.query.filter_by(name=name).first() is None
logout(client)
# Success with user_prod user
login(client, "user_prod", "userprod")
response = client.post(f"/network/interfaces/create/{host.name}", data=form)
assert response.status_code == 302
# The interface was created
interface = models.Interface.query.filter_by(name=name).first()
assert interface.ip == ip
assert interface.mac == mac
assert interface.name == name
assert interface.host == host
def check_vm_creation_response(response, success=True):
assert response.status_code == 200
assert (b"View task" in response.data) is success
assert (b"View host" in response.data) is not success
assert (b"Please contact an admin user" in response.data) is not success
def test_create_vm(
client,
network_scope_factory,
network_factory,
device_type_factory,
host_factory,
interface_factory,
):
virtualmachine = device_type_factory(name="VirtualMachine")
scope_prod = network_scope_factory(name="ProdNetworks")
scope_lab = network_scope_factory(name="LabNetworks")
network_prod = network_factory(scope=scope_prod)
network_lab = network_factory(scope=scope_lab)
vm_prod = host_factory(device_type=virtualmachine)
interface_factory(name=vm_prod.name, host=vm_prod, network=network_prod)
vioc_prod = host_factory(device_type=virtualmachine, is_ioc=True)
interface_factory(name=vioc_prod.name, host=vioc_prod, network=network_prod)
vm_lab = host_factory(device_type=virtualmachine)
interface_factory(name=vm_lab.name, host=vm_lab, network=network_lab)
vioc_lab = host_factory(device_type=virtualmachine, is_ioc=True)
interface_factory(name=vioc_lab.name, host=vioc_lab, network=network_lab)
form = {"cores": 1, "memory": 4, "disk": 15, "osversion": "centos7"}
# User has access to the lab networks and can create VM and VIOC there
login(client, "user_lab", "userlab")
response = client.post(
f"/network/hosts/view/{vm_prod.name}", data=form, follow_redirects=True
)
check_vm_creation_response(response, success=False)
response = client.post(
f"/network/hosts/view/{vioc_prod.name}", data=form, follow_redirects=True
)
check_vm_creation_response(response, success=False)
response = client.post(
f"/network/hosts/view/{vm_lab.name}", data=form, follow_redirects=True
)
check_vm_creation_response(response, success=True)
response = client.post(
f"/network/hosts/view/{vioc_lab.name}", data=form, follow_redirects=True
)
check_vm_creation_response(response, success=True)
logout(client)
# User has access to the prod networks but can only create VIOC due to ALLOWED_VM_CREATION_DOMAINS
login(client, "user_prod", "userprod")
response = client.post(
f"/network/hosts/view/{vm_prod.name}", data=form, follow_redirects=True
)
check_vm_creation_response(response, success=False)
response = client.post(
f"/network/hosts/view/{vioc_prod.name}", data=form, follow_redirects=True
)
check_vm_creation_response(response, success=True)
response = client.post(
f"/network/hosts/view/{vm_lab.name}", data=form, follow_redirects=True
)
check_vm_creation_response(response, success=False)
response = client.post(
f"/network/hosts/view/{vioc_lab.name}", data=form, follow_redirects=True
)
check_vm_creation_response(response, success=False)
def test_delete_host_as_admin(logged_admin_client, host_factory, user_factory):
# admin can delete any host
admin = models.User.query.filter_by(username="admin").first()
user1 = user_factory(username="user1")
host1 = host_factory(name="host1", user=admin)
host2 = host_factory(name="host2", user=user1)
assert len(models.Host.query.all()) == 2
response = logged_admin_client.post(
"/network/hosts/delete", data={"host_id": host1.id}
)
assert response.status_code == 302
response = logged_admin_client.post(
"/network/hosts/delete", data={"host_id": host2.id}
)
assert response.status_code == 302
assert len(models.Host.query.all()) == 0
def test_delete_host_as_normal_user(logged_rw_client, host_factory, user_factory):
# a normal user can only delete its own hosts
user_rw = models.User.query.filter_by(username="user_rw").first()
user1 = user_factory(username="user1")
host1 = host_factory(name="host1", user=user_rw)
host2 = host_factory(name="host2", user=user1)
assert len(models.Host.query.all()) == 2
# user_rw can delete its host
response = logged_rw_client.post(
"/network/hosts/delete", data={"host_id": host1.id}
)
assert response.status_code == 302
# user_rw can't delete host owned by user1
response = logged_rw_client.post(
"/network/hosts/delete", data={"host_id": host2.id}
)
assert response.status_code == 403
assert len(models.Host.query.all()) == 1
def test_create_network_scope(logged_admin_client, domain_factory):
domain = domain_factory(name="prod.example.org")
name = "MyNetworks"
first_vlan = 200
last_vlan = 300
supernet = "192.168.0.0/16"
form = {
"name": name,
"first_vlan": first_vlan,
"last_vlan": last_vlan,
"supernet": supernet,
"domain_id": domain.id,
}
response = logged_admin_client.post("/network/scopes/create", data=form)
assert response.status_code == 302
# The network scope was created
scope = models.NetworkScope.query.filter_by(name=name).first()
assert scope is not None
assert scope.name == name
assert scope.first_vlan == first_vlan
assert scope.last_vlan == last_vlan
assert scope.supernet == supernet
def test_create_network_scope_no_vlan(logged_admin_client, domain_factory):
domain = domain_factory(name="lab.example.org")
name = "NoVlan"
supernet = "192.168.0.0/16"
form = {
"name": name,
"first_vlan": "",
"last_vlan": None,
"supernet": supernet,
"domain_id": domain.id,
}
response = logged_admin_client.post("/network/scopes/create", data=form)
assert response.status_code == 302
# The network scope was created
scope = models.NetworkScope.query.filter_by(name=name).first()
assert scope is not None
assert scope.name == name
assert scope.first_vlan is None
assert scope.last_vlan is None
assert scope.supernet == supernet
def test_create_network(logged_admin_client, domain_factory, network_scope_factory):
domain = domain_factory(name="lab.example.org")
scope = network_scope_factory(
name="MyNetworks",
first_vlan=100,
last_vlan=200,
supernet="192.168.0.0/16",
domain_id=domain.id,
)
vlan_name = "my-network"
form = {
"scope_id": scope.id,
"vlan_name": vlan_name,
"vlan_id": 101,
"address": "192.168.0.0/24",
"first_ip": "192.168.0.11",
"last_ip": "192.168.0.249",
"gateway": "192.168.0.254",
"domain_id": domain.id,
"admin_only": False,
}
response = logged_admin_client.post("/network/networks/create", data=form)
assert response.status_code == 302
# The network was created
network = models.Network.query.filter_by(vlan_name=vlan_name).first()
assert network is not None
assert network.vlan_name == vlan_name
assert network.address == form["address"]
assert network.vlan_id == form["vlan_id"]
def test_create_network_no_vlan(
logged_admin_client, domain_factory, network_scope_factory
):
domain = domain_factory(name="lab.example.org")
scope = network_scope_factory(
name="NoVlanNetworks",
first_vlan=None,
last_vlan=None,
supernet="192.168.0.0/16",
domain_id=domain.id,
)
vlan_name = "my-network"
form = {
"scope_id": scope.id,
"vlan_name": vlan_name,
"vlan_id": "",
"address": "192.168.0.0/24",
"first_ip": "192.168.0.11",
"last_ip": "192.168.0.249",
"gateway": "192.168.0.254",
"domain_id": domain.id,
"admin_only": False,
}
response = logged_admin_client.post("/network/networks/create", data=form)
assert response.status_code == 302
# The network was created
network = models.Network.query.filter_by(vlan_name=vlan_name).first()
assert network is not None
assert network.vlan_name == vlan_name
assert network.address == form["address"]
assert network.vlan_id is None
def test_edit_network(
logged_admin_client, domain_factory, network_scope_factory, network_factory
):
domain = domain_factory(name="lab.example.org")
scope = network_scope_factory(
name="MyNetworks",
first_vlan=100,
last_vlan=200,
supernet="192.168.0.0/16",
domain_id=domain.id,
)
vlan_name = "my-network"
network = network_factory(
vlan_name=vlan_name,
domain=domain,
scope=scope,
vlan_id=100,
address="192.168.0.0/24",
first_ip="192.168.0.11",
last_ip="192.168.0.249",
gateway="192.168.0.254",
admin_only=False,
sensitive=False,
)
new_first_ip = "192.168.0.10"
form = {
"vlan_name": vlan_name,
"vlan_id": network.vlan_id,
"address": network.address,
"first_ip": new_first_ip,
"last_ip": network.last_ip,
"gateway": network.gateway,
"domain_id": network.domain_id,
"admin_only": True,
"sensitive": True,
}
response = logged_admin_client.post(
f"/network/networks/edit/{vlan_name}", data=form
)
assert response.status_code == 302
# The network was updated
network = models.Network.query.filter_by(vlan_name=vlan_name).first()
assert network is not None
assert network.vlan_name == vlan_name
assert network.first_ip == new_first_ip
assert network.admin_only is True
assert network.sensitive is True
def test_edit_network_scope(
logged_admin_client, domain_factory, network_scope_factory, network_factory
):
domain1 = domain_factory(name="lab.example.org")
name = "MyNetworks"
scope = network_scope_factory(
name=name,
first_vlan=100,
last_vlan=200,
supernet="192.168.0.0/16",
domain_id=domain1.id,
)
network_factory(
vlan_name="my-network",
domain=domain1,
scope=scope,
vlan_id=110,
address="192.168.0.0/24",
first_ip="192.168.0.11",
last_ip="192.168.0.249",
gateway="192.168.0.254",
)
new_domain = domain_factory(name="lab.example.eu")
form = {
"name": name,
"description": "Scope for MyNetworks",
"first_vlan": 105,
"last_vlan": 150,
"supernet": scope.supernet,
"domain_id": new_domain.id,
}
response = logged_admin_client.post(f"/network/scopes/edit/{name}", data=form)
assert response.status_code == 302
# The scope was updated
updated_scope = models.NetworkScope.query.filter_by(name=name).first()
assert updated_scope is not None
assert updated_scope.first_vlan == 105
assert updated_scope.last_vlan == 150
assert updated_scope.domain == new_domain
def test_create_item_invalid_ics_id(logged_rw_client):
ics_id = "AAA1100"
form = {"ics_id": ics_id, "serial_number": "12345"}
response = logged_rw_client.post(
"/inventory/items/create", data=form, follow_redirects=True
)
assert response.status_code == 200
assert b"Register new item" in response.data
assert b"The ICS id shall be composed of 3 letters and 3 digits" in response.data
def test_create_item_with_stack_member(
host_factory, device_type_factory, item_factory, logged_rw_client
):
# Test for JIRA INFRA-1648
network_type = device_type_factory(name="NETWORK")
host = host_factory(device_type=network_type)
item1 = item_factory(ics_id="AAA001", host=host, stack_member=0)
ics_id = "AAA042"
form = {
"ics_id": ics_id,
"serial_number": "12345",
"host_id": host.id,
"stack_member": 1,
}
response = logged_rw_client.post("/inventory/items/create", data=form)
assert response.status_code == 302
item2 = models.Item.query.filter_by(ics_id=ics_id).first()
assert host.stack_members() == [item1, item2]
def test_create_item_with_host_and_no_stack_member(
host_factory, device_type_factory, item_factory, logged_rw_client
):
network_type = device_type_factory(name="NETWORK")
host = host_factory(device_type=network_type)
ics_id = "AAA042"
form = {
"ics_id": ics_id,
"serial_number": "12345",
"host_id": host.id,
"stack_member": "",
}
response = logged_rw_client.post("/inventory/items/create", data=form)
assert response.status_code == 302
item = models.Item.query.filter_by(ics_id=ics_id).first()
assert item.host == host
assert item.stack_member is None
def test_ansible_groups_no_recursive_dependency(
ansible_group_factory, logged_admin_client
):
group3 = ansible_group_factory()
group2 = ansible_group_factory(children=[group3])
group1 = ansible_group_factory(children=[group2])
form = {"name": group3.name, "type": group3.type, "children": [group1.id]}
response = logged_admin_client.post(
f"/network/groups/edit/{group3.name}", data=form
)
assert response.status_code == 200
assert b"creates a recursive dependency loop" in response.data
def test_create_network_overlapping(
network_scope_factory, network_factory, logged_admin_client
):
scope = network_scope_factory(
first_vlan=3800, last_vlan=4000, supernet="172.30.0.0/16"
)
network_factory(
vlan_name="network1",
vlan_id=3800,
address="172.30.0.0/23",
first_ip="172.30.0.3",
last_ip="172.30.1.240",
scope=scope,
)
form = {
"vlan_name": "network2",
"vlan_id": 3842,
"scope_id": scope.id,
"address": "172.30.1.0/24",
"first_ip": "172.30.1.5",
"last_ip": "172.30.1.245",
"gateway": "172.30.1.248",
"domain_id": scope.domain_id,
}
response = logged_admin_client.post("/network/networks/create", data=form)
assert response.status_code == 200
assert b"172.30.1.0/24 overlaps network1 (172.30.0.0/23)" in response.data
def test_create_network_scope_overlapping(network_scope_factory, logged_admin_client):
scope1 = network_scope_factory(
name="scope1", first_vlan=3800, last_vlan=4000, supernet="172.30.0.0/16"
)
form = {
"name": "scope2",
"first_vlan": 200,
"last_vlan": 500,
"supernet": "172.30.200.0/22",
"domain_id": scope1.domain_id,
}
response = logged_admin_client.post("/network/scopes/create", data=form)
assert response.status_code == 200
assert b"172.30.200.0/22 overlaps scope1 (172.30.0.0/16)" in response.data
@pytest.mark.parametrize(
"user, password, sensitive, status_code",
[
("user_lab", "userlab", False, 200),
("user_lab", "userlab", True, 403),
("user_prod", "userprod", False, 200),
("user_prod", "userprod", True, 200),
("audit", "auditpasswd", False, 200),
("audit", "auditpasswd", True, 200),
],
)
@pytest.mark.parametrize("admin_only", [False, True])
def test_view_network_restriction(
client,
network_scope_factory,
network_factory,
user,
password,
sensitive,
status_code,
admin_only,
):
# admin_only doesn't matter to view networks
# To view sensitive networks, the user has to be member of the scope
scope = network_scope_factory(name="ProdNetworks", supernet="192.168.0.0/16")
network = network_factory(
address="192.168.1.0/24",
first_ip="192.168.1.10",
last_ip="192.168.1.250",
scope=scope,
sensitive=sensitive,
admin_only=admin_only,
)
login(client, user, password)
response = client.get(f"/network/networks/view/{network}")
assert response.status_code == status_code
logout(client)
def test_view_networks(client, network_scope_factory, network_factory):
scope = network_scope_factory(name="ProdNetworks", supernet="192.168.0.0/16")
network1 = network_factory(
address="192.168.1.0/24",
first_ip="192.168.1.10",
last_ip="192.168.1.250",
scope=scope,
)
network2 = network_factory(
address="192.168.2.0/24",
first_ip="192.168.2.10",
last_ip="192.168.2.250",
admin_only=True,
scope=scope,
)
network3 = network_factory(
address="192.168.3.0/24",
first_ip="192.168.3.10",
last_ip="192.168.3.250",
sensitive=True,
scope=scope,
)
# user_lab can't see sensitive networks
login(client, "user_lab", "userlab")
response = client.get("/network/networks")
assert response.status_code == 200
assert network1.vlan_name in str(response.data)
assert network2.vlan_name in str(response.data)
assert network3.vlan_name not in str(response.data)
logout(client)
# user_prod user can see all networks
# Same for admin and auditor users
for (user, passwd) in (
("user_prod", "userprod"),
("admin", "adminpasswd"),
("audit", "auditpasswd"),
):
login(client, "user_prod", "userprod")
response = client.get("/network/networks")
assert network1.vlan_name in str(response.data)
assert network2.vlan_name in str(response.data)
assert network3.vlan_name in str(response.data)
logout(client)
def test_retrieve_groups(logged_client, ansible_group_factory):
response = logged_client.post("/network/_retrieve_groups")
assert response.get_json()["data"] == []
group1 = ansible_group_factory(name="group1")
group2 = ansible_group_factory(name="group2")
response = logged_client.post("/network/_retrieve_groups")
groups = response.get_json()["data"]
assert {group1.name, group2.name} == set(group["name"] for group in groups)
def test_generate_excel_file(logged_client):
response = logged_client.get("/inventory/items/_generate_excel_file")
assert response.status_code == 202
assert "/status/" in response.headers["Location"]
job_id = response.headers["Location"].split("/")[-1]
task = models.Task.query.get(job_id)
assert task is not None
assert task.name == "generate_items_excel_file"
assert task.command == "app.tasks.generate_items_excel_file()"
......@@ -9,7 +9,13 @@ This module defines fields tests.
:license: BSD 2-Clause, see LICENSE for more details.
"""
from app.fields import yaml
import pytest
from wtforms.form import Form
from app.fields import yaml, YAMLField
class MyForm(Form):
vars = YAMLField("Ansible vars")
def test_vault_yaml_tag_load():
......@@ -33,3 +39,33 @@ def test_vault_yaml_tag_load():
"""
}
}
@pytest.mark.parametrize(
"text_input,expected",
[
("foo: hello", {"foo": "hello"}),
("foo:\n - a\n - b", {"foo": ["a", "b"]}),
("", None),
(" ", None),
],
)
def test_yamlfield_process_formdata(text_input, expected):
form = MyForm()
YAMLField.process_formdata(form.vars, [text_input])
assert form.vars.data == expected
def test_yamlfield_process_formdata_invalid_yaml():
form = MyForm()
with pytest.raises(ValueError, match="This field contains invalid YAML"):
YAMLField.process_formdata(form.vars, ["foo: hello: world"])
@pytest.mark.parametrize("text_input", ("foo", "- a\n- b"))
def test_yamlfield_process_formdata_non_dict(text_input):
form = MyForm()
with pytest.raises(
ValueError, match="This field shall only contain key-value-pairs"
):
YAMLField.process_formdata(form.vars, [text_input])
......@@ -9,6 +9,7 @@ This module defines utils tests.
:license: BSD 2-Clause, see LICENSE for more details.
"""
import pytest
from pathlib import Path
from app import utils
......@@ -45,3 +46,17 @@ class TestUniqueFilename:
p = tmpdir.join("test")
p.write("Hello")
assert utils.unique_filename(p) == Path(tmpdir.join("test-1"))
@pytest.mark.parametrize(
"input,expected",
[
([], ""),
(["foo"], "foo"),
(["foo", "bar"], "foo"),
("hello", "hello"),
("", ""),
],
)
def test_attribute_to_string(input, expected):
assert utils.attribute_to_string(input) == expected