fix(list_checks): arn filtering checks after audit_info set (#1887)

This commit is contained in:
Nacho Rivera
2023-02-13 14:57:42 +01:00
committed by GitHub
parent 674332fddd
commit 6da45b5c2b
4 changed files with 102 additions and 34 deletions

View File

@@ -10,6 +10,7 @@ from prowler.lib.check.check import (
exclude_checks_to_run, exclude_checks_to_run,
exclude_services_to_run, exclude_services_to_run,
execute_checks, execute_checks,
get_checks_from_input_arn,
list_categories, list_categories,
list_services, list_services,
print_categories, print_categories,
@@ -99,9 +100,6 @@ def prowler():
) )
sys.exit() sys.exit()
# Set the audit info based on the selected provider
audit_info = set_provider_audit_info(provider, args.__dict__)
# Load checks to execute # Load checks to execute
checks_to_execute = load_checks_to_execute( checks_to_execute = load_checks_to_execute(
bulk_checks_metadata, bulk_checks_metadata,
@@ -113,7 +111,6 @@ def prowler():
compliance_framework, compliance_framework,
categories, categories,
provider, provider,
audit_info,
) )
# Exclude checks if -e/--excluded-checks # Exclude checks if -e/--excluded-checks
@@ -134,6 +131,15 @@ def prowler():
print_checks(provider, checks_to_execute, bulk_checks_metadata) print_checks(provider, checks_to_execute, bulk_checks_metadata)
sys.exit() sys.exit()
# Set the audit info based on the selected provider
audit_info = set_provider_audit_info(provider, args.__dict__)
# Once the audit_info is set and we have the eventual checks from arn, it is time to exclude the others
if audit_info.audit_resources:
checks_to_execute = get_checks_from_input_arn(
audit_info.audit_resources, provider
)
# Parse content from Allowlist file and get it, if necessary, from S3 # Parse content from Allowlist file and get it, if necessary, from S3
if provider == "aws" and args.allowlist_file: if provider == "aws" and args.allowlist_file:
allowlist_file = parse_allowlist_file(audit_info, args.allowlist_file) allowlist_file = parse_allowlist_file(audit_info, args.allowlist_file)

View File

@@ -371,7 +371,6 @@ def execute_checks(
# If check does not exists in the provider or is from another provider # If check does not exists in the provider or is from another provider
except ModuleNotFoundError: except ModuleNotFoundError:
logger.critical( logger.critical(
f"Check '{check_name}' was not found for the {provider.upper()} provider" f"Check '{check_name}' was not found for the {provider.upper()} provider"
) )
@@ -486,3 +485,44 @@ def update_audit_metadata(
logger.error( logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
) )
def recover_checks_from_service(service_list: list, provider: str) -> list:
checks = set()
for service in service_list:
modules = recover_checks_from_provider(provider, service)
if not modules:
logger.error(f"Service '{service}' does not have checks.")
else:
for check_module in modules:
# Recover check name and module name from import path
# Format: "providers.{provider}.services.{service}.{check_name}.{check_name}"
check_name = check_module[0].split(".")[-1]
# If the service is present in the group list passed as parameters
# if service_name in group_list: checks_from_arn.add(check_name)
checks.add(check_name)
return checks
def get_checks_from_input_arn(audit_resources: list, provider: str) -> set:
"""get_checks_from_input_arn gets the list of checks from the input arns"""
checks_from_arn = set()
# Handle if there are audit resources so only their services are executed
if audit_resources:
service_list = []
for resource in audit_resources:
service = resource.split(":")[2]
# Parse services when they are different in the ARNs
if service == "lambda":
service = "awslambda"
if service == "elasticloadbalancing":
service = "elb"
elif service == "logs":
service = "cloudwatch"
service_list.append(service)
checks_from_arn = recover_checks_from_service(service_list, provider)
# Return final checks list
return checks_from_arn

View File

@@ -2,9 +2,9 @@ from prowler.lib.check.check import (
parse_checks_from_compliance_framework, parse_checks_from_compliance_framework,
parse_checks_from_file, parse_checks_from_file,
recover_checks_from_provider, recover_checks_from_provider,
recover_checks_from_service,
) )
from prowler.lib.logger import logger from prowler.lib.logger import logger
from prowler.providers.aws.lib.audit_info.models import AWS_Audit_Info
# Generate the list of checks to execute # Generate the list of checks to execute
@@ -19,25 +19,10 @@ def load_checks_to_execute(
compliance_frameworks: list, compliance_frameworks: list,
categories: set, categories: set,
provider: str, provider: str,
audit_info: AWS_Audit_Info,
) -> set: ) -> set:
"""Generate the list of checks to execute based on the cloud provider and input arguments specified""" """Generate the list of checks to execute based on the cloud provider and input arguments specified"""
checks_to_execute = set() checks_to_execute = set()
# Handle if there are audit resources so only their services are executed
if audit_info.audit_resources:
service_list = []
for resource in audit_info.audit_resources:
service = resource.split(":")[2]
# Parse services when they are different in the ARNs
if service == "lambda":
service = "awslambda"
if service == "elasticloadbalancing":
service = "elb"
elif service == "logs":
service = "cloudwatch"
service_list.append(service)
# Handle if there are checks passed using -c/--checks # Handle if there are checks passed using -c/--checks
if check_list: if check_list:
for check_name in check_list: for check_name in check_list:
@@ -59,19 +44,7 @@ def load_checks_to_execute(
# Handle if there are services passed using -s/--services # Handle if there are services passed using -s/--services
elif service_list: elif service_list:
# Loaded dynamically from modules within provider/services checks_to_execute = recover_checks_from_service(service_list, provider)
for service in service_list:
modules = recover_checks_from_provider(provider, service)
if not modules:
logger.error(f"Service '{service}' does not have checks.")
else:
for check_module in modules:
# Recover check name and module name from import path
# Format: "providers.{provider}.services.{service}.{check_name}.{check_name}"
check_name = check_module[0].split(".")[-1]
# If the service is present in the group list passed as parameters
# if service_name in group_list: checks_to_execute.add(check_name)
checks_to_execute.add(check_name)
# Handle if there are compliance frameworks passed using --compliance # Handle if there are compliance frameworks passed using --compliance
elif compliance_frameworks: elif compliance_frameworks:

View File

@@ -8,10 +8,12 @@ from mock import patch
from prowler.lib.check.check import ( from prowler.lib.check.check import (
exclude_checks_to_run, exclude_checks_to_run,
exclude_services_to_run, exclude_services_to_run,
get_checks_from_input_arn,
list_modules, list_modules,
list_services, list_services,
parse_checks_from_file, parse_checks_from_file,
recover_checks_from_provider, recover_checks_from_provider,
recover_checks_from_service,
update_audit_metadata, update_audit_metadata,
) )
from prowler.lib.check.models import load_check_metadata from prowler.lib.check.models import load_check_metadata
@@ -104,6 +106,23 @@ def mock_recover_checks_from_aws_provider(*_):
] ]
def mock_recover_checks_from_aws_provider_lambda_service(*_):
return [
(
"awslambda_function_invoke_api_operations_cloudtrail_logging_enabled",
"/root_dir/fake_path/awslambda/awslambda_function_invoke_api_operations_cloudtrail_logging_enabled",
),
(
"awslambda_function_url_cors_policy",
"/root_dir/fake_path/awslambda/awslambda_function_url_cors_policy",
),
(
"awslambda_function_no_secrets_in_code",
"/root_dir/fake_path/awslambda/awslambda_function_no_secrets_in_code",
),
]
class Test_Check: class Test_Check:
def test_load_check_metadata(self): def test_load_check_metadata(self):
test_cases = [ test_cases = [
@@ -247,6 +266,36 @@ class Test_Check:
expected_modules = list_modules(provider, service) expected_modules = list_modules(provider, service)
assert expected_modules == expected_packages assert expected_modules == expected_packages
@patch(
"prowler.lib.check.check.recover_checks_from_provider",
new=mock_recover_checks_from_aws_provider,
)
def test_recover_checks_from_service(self):
service_list = ["accessanalyzer", "awslambda", "ec2"]
provider = "aws"
expected_checks = {
"accessanalyzer_enabled_without_findings",
"awslambda_function_url_cors_policy",
"ec2_securitygroup_allow_ingress_from_internet_to_any_port",
}
recovered_checks = recover_checks_from_service(service_list, provider)
assert recovered_checks == expected_checks
@patch(
"prowler.lib.check.check.recover_checks_from_provider",
new=mock_recover_checks_from_aws_provider_lambda_service,
)
def test_get_checks_from_input_arn(self):
audit_resources = ["arn:aws:lambda:us-east-1:123456789:function:test-lambda"]
provider = "aws"
expected_checks = {
"awslambda_function_url_cors_policy",
"awslambda_function_invoke_api_operations_cloudtrail_logging_enabled",
"awslambda_function_no_secrets_in_code",
}
recovered_checks = get_checks_from_input_arn(audit_resources, provider)
assert recovered_checks == expected_checks
# def test_parse_checks_from_compliance_framework_two(self): # def test_parse_checks_from_compliance_framework_two(self):
# test_case = { # test_case = {
# "input": {"compliance_frameworks": ["cis_v1.4_aws", "ens_v3_aws"]}, # "input": {"compliance_frameworks": ["cis_v1.4_aws", "ens_v3_aws"]},