fix(list_checks): arn filtering checks after audit_info set (#1887)

This commit is contained in:
Nacho Rivera
2023-02-13 14:57:42 +01:00
committed by GitHub
parent 674332fddd
commit 6da45b5c2b
4 changed files with 102 additions and 34 deletions

View File

@@ -10,6 +10,7 @@ from prowler.lib.check.check import (
exclude_checks_to_run,
exclude_services_to_run,
execute_checks,
get_checks_from_input_arn,
list_categories,
list_services,
print_categories,
@@ -99,9 +100,6 @@ def prowler():
)
sys.exit()
# Set the audit info based on the selected provider
audit_info = set_provider_audit_info(provider, args.__dict__)
# Load checks to execute
checks_to_execute = load_checks_to_execute(
bulk_checks_metadata,
@@ -113,7 +111,6 @@ def prowler():
compliance_framework,
categories,
provider,
audit_info,
)
# Exclude checks if -e/--excluded-checks
@@ -134,6 +131,15 @@ def prowler():
print_checks(provider, checks_to_execute, bulk_checks_metadata)
sys.exit()
# Set the audit info based on the selected provider
audit_info = set_provider_audit_info(provider, args.__dict__)
# Once the audit_info is set and we have the eventual checks from arn, it is time to exclude the others
if audit_info.audit_resources:
checks_to_execute = get_checks_from_input_arn(
audit_info.audit_resources, provider
)
# Parse content from Allowlist file and get it, if necessary, from S3
if provider == "aws" and args.allowlist_file:
allowlist_file = parse_allowlist_file(audit_info, args.allowlist_file)

View File

@@ -371,7 +371,6 @@ def execute_checks(
# If check does not exists in the provider or is from another provider
except ModuleNotFoundError:
logger.critical(
f"Check '{check_name}' was not found for the {provider.upper()} provider"
)
@@ -486,3 +485,44 @@ def update_audit_metadata(
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def recover_checks_from_service(service_list: list, provider: str) -> list:
checks = set()
for service in service_list:
modules = recover_checks_from_provider(provider, service)
if not modules:
logger.error(f"Service '{service}' does not have checks.")
else:
for check_module in modules:
# Recover check name and module name from import path
# Format: "providers.{provider}.services.{service}.{check_name}.{check_name}"
check_name = check_module[0].split(".")[-1]
# If the service is present in the group list passed as parameters
# if service_name in group_list: checks_from_arn.add(check_name)
checks.add(check_name)
return checks
def get_checks_from_input_arn(audit_resources: list, provider: str) -> set:
"""get_checks_from_input_arn gets the list of checks from the input arns"""
checks_from_arn = set()
# Handle if there are audit resources so only their services are executed
if audit_resources:
service_list = []
for resource in audit_resources:
service = resource.split(":")[2]
# Parse services when they are different in the ARNs
if service == "lambda":
service = "awslambda"
if service == "elasticloadbalancing":
service = "elb"
elif service == "logs":
service = "cloudwatch"
service_list.append(service)
checks_from_arn = recover_checks_from_service(service_list, provider)
# Return final checks list
return checks_from_arn

View File

@@ -2,9 +2,9 @@ from prowler.lib.check.check import (
parse_checks_from_compliance_framework,
parse_checks_from_file,
recover_checks_from_provider,
recover_checks_from_service,
)
from prowler.lib.logger import logger
from prowler.providers.aws.lib.audit_info.models import AWS_Audit_Info
# Generate the list of checks to execute
@@ -19,25 +19,10 @@ def load_checks_to_execute(
compliance_frameworks: list,
categories: set,
provider: str,
audit_info: AWS_Audit_Info,
) -> set:
"""Generate the list of checks to execute based on the cloud provider and input arguments specified"""
checks_to_execute = set()
# Handle if there are audit resources so only their services are executed
if audit_info.audit_resources:
service_list = []
for resource in audit_info.audit_resources:
service = resource.split(":")[2]
# Parse services when they are different in the ARNs
if service == "lambda":
service = "awslambda"
if service == "elasticloadbalancing":
service = "elb"
elif service == "logs":
service = "cloudwatch"
service_list.append(service)
# Handle if there are checks passed using -c/--checks
if check_list:
for check_name in check_list:
@@ -59,19 +44,7 @@ def load_checks_to_execute(
# Handle if there are services passed using -s/--services
elif service_list:
# Loaded dynamically from modules within provider/services
for service in service_list:
modules = recover_checks_from_provider(provider, service)
if not modules:
logger.error(f"Service '{service}' does not have checks.")
else:
for check_module in modules:
# Recover check name and module name from import path
# Format: "providers.{provider}.services.{service}.{check_name}.{check_name}"
check_name = check_module[0].split(".")[-1]
# If the service is present in the group list passed as parameters
# if service_name in group_list: checks_to_execute.add(check_name)
checks_to_execute.add(check_name)
checks_to_execute = recover_checks_from_service(service_list, provider)
# Handle if there are compliance frameworks passed using --compliance
elif compliance_frameworks:

View File

@@ -8,10 +8,12 @@ from mock import patch
from prowler.lib.check.check import (
exclude_checks_to_run,
exclude_services_to_run,
get_checks_from_input_arn,
list_modules,
list_services,
parse_checks_from_file,
recover_checks_from_provider,
recover_checks_from_service,
update_audit_metadata,
)
from prowler.lib.check.models import load_check_metadata
@@ -104,6 +106,23 @@ def mock_recover_checks_from_aws_provider(*_):
]
def mock_recover_checks_from_aws_provider_lambda_service(*_):
return [
(
"awslambda_function_invoke_api_operations_cloudtrail_logging_enabled",
"/root_dir/fake_path/awslambda/awslambda_function_invoke_api_operations_cloudtrail_logging_enabled",
),
(
"awslambda_function_url_cors_policy",
"/root_dir/fake_path/awslambda/awslambda_function_url_cors_policy",
),
(
"awslambda_function_no_secrets_in_code",
"/root_dir/fake_path/awslambda/awslambda_function_no_secrets_in_code",
),
]
class Test_Check:
def test_load_check_metadata(self):
test_cases = [
@@ -247,6 +266,36 @@ class Test_Check:
expected_modules = list_modules(provider, service)
assert expected_modules == expected_packages
@patch(
"prowler.lib.check.check.recover_checks_from_provider",
new=mock_recover_checks_from_aws_provider,
)
def test_recover_checks_from_service(self):
service_list = ["accessanalyzer", "awslambda", "ec2"]
provider = "aws"
expected_checks = {
"accessanalyzer_enabled_without_findings",
"awslambda_function_url_cors_policy",
"ec2_securitygroup_allow_ingress_from_internet_to_any_port",
}
recovered_checks = recover_checks_from_service(service_list, provider)
assert recovered_checks == expected_checks
@patch(
"prowler.lib.check.check.recover_checks_from_provider",
new=mock_recover_checks_from_aws_provider_lambda_service,
)
def test_get_checks_from_input_arn(self):
audit_resources = ["arn:aws:lambda:us-east-1:123456789:function:test-lambda"]
provider = "aws"
expected_checks = {
"awslambda_function_url_cors_policy",
"awslambda_function_invoke_api_operations_cloudtrail_logging_enabled",
"awslambda_function_no_secrets_in_code",
}
recovered_checks = get_checks_from_input_arn(audit_resources, provider)
assert recovered_checks == expected_checks
# def test_parse_checks_from_compliance_framework_two(self):
# test_case = {
# "input": {"compliance_frameworks": ["cis_v1.4_aws", "ens_v3_aws"]},