Files
prowler/tests/providers/azure/services/sqlserver/sqlserver_service_test.py

179 lines
6.2 KiB
Python

from unittest.mock import patch
from azure.mgmt.sql.models import (
EncryptionProtector,
ServerVulnerabilityAssessment,
TransparentDataEncryption,
)
from prowler.providers.azure.services.sqlserver.sqlserver_service import (
Database,
Server,
SQLServer,
)
from tests.providers.azure.azure_fixtures import (
AZURE_SUBSCRIPTION,
set_mocked_azure_audit_info,
)
def mock_sqlserver_get_sql_servers(_):
    """Return a fixed SQL server inventory for the service tests.

    Intended as a replacement for ``SQLServer.__get_sql_servers__``; the
    single positional argument is accepted only so the call signature lines
    up and is otherwise ignored.

    Returns:
        A dict keyed by ``AZURE_SUBSCRIPTION`` whose value is a one-element
        list containing a ``Server`` that owns a single ``Database``.
    """
    # The only database attached to the server; its TDE status is "Disabled".
    mocked_database = Database(
        id="id",
        name="name",
        type="type",
        location="location",
        managed_by="managed_by",
        tde_encryption=TransparentDataEncryption(status="Disabled"),
    )

    # Build the nested Azure SDK models first so the Server call stays flat.
    key_vault_protector = EncryptionProtector(server_key_type="AzureKeyVault")
    assessment = ServerVulnerabilityAssessment(
        storage_container_path="/subcription_id/resource_group/sql_server"
    )

    mocked_server = Server(
        id="id",
        name="name",
        public_network_access="public_network_access",
        minimal_tls_version="minimal_tls_version",
        administrators=None,
        auditing_policies=None,
        firewall_rules=None,
        encryption_protector=key_vault_protector,
        databases=[mocked_database],
        vulnerability_assessment=assessment,
    )
    return {AZURE_SUBSCRIPTION: [mocked_server]}
@patch(
    "prowler.providers.azure.services.sqlserver.sqlserver_service.SQLServer.__get_sql_servers__",
    new=mock_sqlserver_get_sql_servers,
)
class Test_SqlServer_Service:
    """Tests for the ``SQLServer`` service wrapper.

    The class-level ``patch`` swaps ``SQLServer.__get_sql_servers__`` for
    ``mock_sqlserver_get_sql_servers`` in every test method, so
    ``sql_servers`` is expected to hold the fixture returned by that mock
    (presumably to avoid real Azure API calls -- confirm against the service).
    """

    def test__get_client__(self):
        """The per-subscription client is a ``SqlManagementClient``."""
        sql_server = SQLServer(set_mocked_azure_audit_info())
        assert (
            sql_server.clients[AZURE_SUBSCRIPTION].__class__.__name__
            == "SqlManagementClient"
        )

    def test__get_sql_servers__(self):
        """The first server of the subscription matches the mocked fixture."""
        # Expected database, built with the same values the mock uses.
        database = Database(
            id="id",
            name="name",
            type="type",
            location="location",
            managed_by="managed_by",
            tde_encryption=TransparentDataEncryption(status="Disabled"),
        )
        sql_server = SQLServer(set_mocked_azure_audit_info())
        # Hoisted once instead of repeating the subscript chain per assert.
        server = sql_server.sql_servers[AZURE_SUBSCRIPTION][0]

        assert server.__class__.__name__ == "Server"
        assert server.id == "id"
        assert server.name == "name"
        assert server.public_network_access == "public_network_access"
        assert server.minimal_tls_version == "minimal_tls_version"
        assert server.administrators is None
        assert server.auditing_policies is None
        assert server.firewall_rules is None
        assert server.encryption_protector.__class__.__name__ == "EncryptionProtector"
        assert server.databases == [database]
        assert (
            server.vulnerability_assessment.__class__.__name__
            == "ServerVulnerabilityAssessment"
        )

    def test__get_databases__(self):
        """The first database of the first server carries the mocked fields."""
        sql_server = SQLServer(set_mocked_azure_audit_info())
        database = sql_server.sql_servers[AZURE_SUBSCRIPTION][0].databases[0]

        assert database.__class__.__name__ == "Database"
        assert database.id == "id"
        assert database.name == "name"
        assert database.type == "type"
        assert database.location == "location"
        assert database.managed_by == "managed_by"
        assert (
            database.tde_encryption.__class__.__name__
            == "TransparentDataEncryption"
        )

    def test__get_transparent_data_encryption__(self):
        """The database's TDE object has the mocked "Disabled" status."""
        sql_server = SQLServer(set_mocked_azure_audit_info())
        tde_encryption = (
            sql_server.sql_servers[AZURE_SUBSCRIPTION][0].databases[0].tde_encryption
        )

        assert tde_encryption.__class__.__name__ == "TransparentDataEncryption"
        assert tde_encryption.status == "Disabled"

    def test__get_encryption_protectors__(self):
        """The server's encryption protector has the mocked key type."""
        sql_server = SQLServer(set_mocked_azure_audit_info())
        encryption_protector = sql_server.sql_servers[AZURE_SUBSCRIPTION][
            0
        ].encryption_protector

        assert encryption_protector.__class__.__name__ == "EncryptionProtector"
        assert encryption_protector.server_key_type == "AzureKeyVault"

    def test__get_resource_group__(self):
        """``__get_resource_group__`` returns "resource_group" for this ID."""
        # Named resource_id rather than id to avoid shadowing the builtin.
        resource_id = "/subscriptions/subscription_id/resourceGroups/resource_group/providers/Microsoft.Sql/servers/sql_server"
        sql_server = SQLServer(set_mocked_azure_audit_info())
        assert sql_server.__get_resource_group__(resource_id) == "resource_group"

    def test__get_vulnerability_assessment__(self):
        """The vulnerability assessment exposes the mocked storage path."""
        sql_server = SQLServer(set_mocked_azure_audit_info())
        # Same literal (including its spelling) that the mock fixture uses.
        storage_container_path = "/subcription_id/resource_group/sql_server"
        vulnerability_assessment = sql_server.sql_servers[AZURE_SUBSCRIPTION][
            0
        ].vulnerability_assessment

        assert (
            vulnerability_assessment.__class__.__name__
            == "ServerVulnerabilityAssessment"
        )
        assert (
            vulnerability_assessment.storage_container_path
            == storage_container_path
        )