From 6da45b5c2b4d4df2bea0ca6059602f64d49e4340 Mon Sep 17 00:00:00 2001 From: Nacho Rivera Date: Mon, 13 Feb 2023 14:57:42 +0100 Subject: [PATCH] fix(list_checks): arn filtering checks after audit_info set (#1887) --- prowler/__main__.py | 14 ++++++--- prowler/lib/check/check.py | 42 ++++++++++++++++++++++++- prowler/lib/check/checks_loader.py | 31 ++----------------- tests/lib/check/check_test.py | 49 ++++++++++++++++++++++++++++++ 4 files changed, 102 insertions(+), 34 deletions(-) diff --git a/prowler/__main__.py b/prowler/__main__.py index c470128a..5d073eb6 100644 --- a/prowler/__main__.py +++ b/prowler/__main__.py @@ -10,6 +10,7 @@ from prowler.lib.check.check import ( exclude_checks_to_run, exclude_services_to_run, execute_checks, + get_checks_from_input_arn, list_categories, list_services, print_categories, @@ -99,9 +100,6 @@ def prowler(): ) sys.exit() - # Set the audit info based on the selected provider - audit_info = set_provider_audit_info(provider, args.__dict__) - # Load checks to execute checks_to_execute = load_checks_to_execute( bulk_checks_metadata, @@ -113,7 +111,6 @@ def prowler(): compliance_framework, categories, provider, - audit_info, ) # Exclude checks if -e/--excluded-checks @@ -134,6 +131,15 @@ def prowler(): print_checks(provider, checks_to_execute, bulk_checks_metadata) sys.exit() + # Set the audit info based on the selected provider + audit_info = set_provider_audit_info(provider, args.__dict__) + + # Once the audit_info is set and we have the eventual checks from arn, it is time to exclude the others + if audit_info.audit_resources: + checks_to_execute = get_checks_from_input_arn( + audit_info.audit_resources, provider + ) + # Parse content from Allowlist file and get it, if necessary, from S3 if provider == "aws" and args.allowlist_file: allowlist_file = parse_allowlist_file(audit_info, args.allowlist_file) diff --git a/prowler/lib/check/check.py b/prowler/lib/check/check.py index 4ca67c77..430d33fc 100644 --- a/prowler/lib/check/check.py +++ b/prowler/lib/check/check.py @@ -371,7 +371,6 @@ def execute_checks( # If check does not exists in the provider or is from another provider except ModuleNotFoundError: - logger.critical( f"Check '{check_name}' was not found for the {provider.upper()} provider" ) @@ -486,3 +485,44 @@ def update_audit_metadata( logger.error( f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" ) + + +def recover_checks_from_service(service_list: list, provider: str) -> list: + checks = set() + for service in service_list: + modules = recover_checks_from_provider(provider, service) + if not modules: + logger.error(f"Service '{service}' does not have checks.") + + else: + for check_module in modules: + # Recover check name and module name from import path + # Format: "providers.{provider}.services.{service}.{check_name}.{check_name}" + check_name = check_module[0].split(".")[-1] + # If the service is present in the group list passed as parameters + # if service_name in group_list: checks_from_arn.add(check_name) + checks.add(check_name) + return checks + + +def get_checks_from_input_arn(audit_resources: list, provider: str) -> set: + """get_checks_from_input_arn gets the list of checks from the input arns""" + checks_from_arn = set() + # Handle if there are audit resources so only their services are executed + if audit_resources: + service_list = [] + for resource in audit_resources: + service = resource.split(":")[2] + # Parse services when they are different in the ARNs + if service == "lambda": + service = "awslambda" + if service == "elasticloadbalancing": + service = "elb" + elif service == "logs": + service = "cloudwatch" + service_list.append(service) + + checks_from_arn = recover_checks_from_service(service_list, provider) + + # Return final checks list + return checks_from_arn diff --git a/prowler/lib/check/checks_loader.py b/prowler/lib/check/checks_loader.py index c1d0337a..1fd2fc70 100644 --- a/prowler/lib/check/checks_loader.py +++ b/prowler/lib/check/checks_loader.py @@ -2,9 +2,9 @@ from prowler.lib.check.check import ( parse_checks_from_compliance_framework, parse_checks_from_file, recover_checks_from_provider, + recover_checks_from_service, ) from prowler.lib.logger import logger -from prowler.providers.aws.lib.audit_info.models import AWS_Audit_Info # Generate the list of checks to execute @@ -19,25 +19,10 @@ def load_checks_to_execute( compliance_frameworks: list, categories: set, provider: str, - audit_info: AWS_Audit_Info, ) -> set: """Generate the list of checks to execute based on the cloud provider and input arguments specified""" checks_to_execute = set() - # Handle if there are audit resources so only their services are executed - if audit_info.audit_resources: - service_list = [] - for resource in audit_info.audit_resources: - service = resource.split(":")[2] - # Parse services when they are different in the ARNs - if service == "lambda": - service = "awslambda" - if service == "elasticloadbalancing": - service = "elb" - elif service == "logs": - service = "cloudwatch" - service_list.append(service) - # Handle if there are checks passed using -c/--checks if check_list: for check_name in check_list: @@ -59,19 +44,7 @@ def load_checks_to_execute( # Handle if there are services passed using -s/--services elif service_list: - # Loaded dynamically from modules within provider/services - for service in service_list: - modules = recover_checks_from_provider(provider, service) - if not modules: - logger.error(f"Service '{service}' does not have checks.") - else: - for check_module in modules: - # Recover check name and module name from import path - # Format: "providers.{provider}.services.{service}.{check_name}.{check_name}" - check_name = check_module[0].split(".")[-1] - # If the service is present in the group list passed as parameters - # if service_name in group_list: checks_to_execute.add(check_name) - checks_to_execute.add(check_name) + checks_to_execute = recover_checks_from_service(service_list, provider) # Handle if there are compliance frameworks passed using --compliance elif compliance_frameworks: diff --git a/tests/lib/check/check_test.py b/tests/lib/check/check_test.py index f34f4681..6e927d56 100644 --- a/tests/lib/check/check_test.py +++ b/tests/lib/check/check_test.py @@ -8,10 +8,12 @@ from mock import patch from prowler.lib.check.check import ( exclude_checks_to_run, exclude_services_to_run, + get_checks_from_input_arn, list_modules, list_services, parse_checks_from_file, recover_checks_from_provider, + recover_checks_from_service, update_audit_metadata, ) from prowler.lib.check.models import load_check_metadata @@ -104,6 +106,23 @@ def mock_recover_checks_from_aws_provider(*_): ] +def mock_recover_checks_from_aws_provider_lambda_service(*_): + return [ + ( + "awslambda_function_invoke_api_operations_cloudtrail_logging_enabled", + "/root_dir/fake_path/awslambda/awslambda_function_invoke_api_operations_cloudtrail_logging_enabled", + ), + ( + "awslambda_function_url_cors_policy", + "/root_dir/fake_path/awslambda/awslambda_function_url_cors_policy", + ), + ( + "awslambda_function_no_secrets_in_code", + "/root_dir/fake_path/awslambda/awslambda_function_no_secrets_in_code", + ), + ] + + class Test_Check: def test_load_check_metadata(self): test_cases = [ @@ -247,6 +266,36 @@ class Test_Check: expected_modules = list_modules(provider, service) assert expected_modules == expected_packages + @patch( + "prowler.lib.check.check.recover_checks_from_provider", + new=mock_recover_checks_from_aws_provider, + ) + def test_recover_checks_from_service(self): + service_list = ["accessanalyzer", "awslambda", "ec2"] + provider = "aws" + expected_checks = { + "accessanalyzer_enabled_without_findings", + "awslambda_function_url_cors_policy", + "ec2_securitygroup_allow_ingress_from_internet_to_any_port", + } + recovered_checks = recover_checks_from_service(service_list, provider) + assert recovered_checks == expected_checks + + @patch( + "prowler.lib.check.check.recover_checks_from_provider", + new=mock_recover_checks_from_aws_provider_lambda_service, + ) + def test_get_checks_from_input_arn(self): + audit_resources = ["arn:aws:lambda:us-east-1:123456789:function:test-lambda"] + provider = "aws" + expected_checks = { + "awslambda_function_url_cors_policy", + "awslambda_function_invoke_api_operations_cloudtrail_logging_enabled", + "awslambda_function_no_secrets_in_code", + } + recovered_checks = get_checks_from_input_arn(audit_resources, provider) + assert recovered_checks == expected_checks + # def test_parse_checks_from_compliance_framework_two(self): # test_case = { # "input": {"compliance_frameworks": ["cis_v1.4_aws", "ens_v3_aws"]},