From cc58e06b5ebe481ab1edef07cf9df4b93f8a83d4 Mon Sep 17 00:00:00 2001 From: Pepe Fagoaga Date: Thu, 16 Mar 2023 17:32:53 +0100 Subject: [PATCH] fix(providers): Move provider's logic outside main (#2043) Co-authored-by: Sergio Garcia --- prowler/__main__.py | 36 ++++------ prowler/lib/check/check.py | 70 ------------------ prowler/providers/aws/aws_provider.py | 71 +++++++++++++++++++ .../aws/lib/resource_api_tagging/__init__.py | 0 .../resource_api_tagging.py | 38 ++++++++++ prowler/providers/common/allowlist.py | 35 +++++++++ prowler/providers/common/audit_info.py | 51 +++++++------ prowler/providers/common/quick_inventory.py | 26 +++++++ tests/lib/check/check_test.py | 6 +- 9 files changed, 217 insertions(+), 116 deletions(-) create mode 100644 prowler/providers/aws/lib/resource_api_tagging/__init__.py create mode 100644 prowler/providers/aws/lib/resource_api_tagging/resource_api_tagging.py create mode 100644 prowler/providers/common/allowlist.py create mode 100644 prowler/providers/common/quick_inventory.py diff --git a/prowler/__main__.py b/prowler/__main__.py index 0e231773..ca5dc3b9 100644 --- a/prowler/__main__.py +++ b/prowler/__main__.py @@ -10,8 +10,6 @@ from prowler.lib.check.check import ( exclude_checks_to_run, exclude_services_to_run, execute_checks, - get_checks_from_input_arn, - get_regions_from_audit_resources, list_categories, list_services, print_categories, @@ -29,13 +27,16 @@ from prowler.lib.outputs.html import add_html_footer, fill_html_overview_statist from prowler.lib.outputs.json import close_json from prowler.lib.outputs.outputs import extract_findings_statistics, send_to_s3_bucket from prowler.lib.outputs.summary_table import display_summary_table -from prowler.providers.aws.lib.allowlist.allowlist import parse_allowlist_file -from prowler.providers.aws.lib.quick_inventory.quick_inventory import quick_inventory from prowler.providers.aws.lib.security_hub.security_hub import ( resolve_security_hub_previous_findings, ) -from prowler.providers.common.audit_info import set_provider_audit_info +from prowler.providers.common.allowlist import set_provider_allowlist +from prowler.providers.common.audit_info import ( + set_provider_audit_info, + set_provider_execution_parameters, +) from prowler.providers.common.outputs import set_provider_output_options +from prowler.providers.common.quick_inventory import run_provider_quick_inventory def prowler(): @@ -128,29 +129,22 @@ def prowler(): # Set the audit info based on the selected provider audit_info = set_provider_audit_info(provider, args.__dict__) - # Once the audit_info is set and we have the eventual checks from arn, it is time to exclude the others + # Once the audit_info is set and we have the eventual checks based on the resource identifier, + # it is time to check what Prowler's checks are going to be executed if audit_info.audit_resources: - audit_info.audited_regions = get_regions_from_audit_resources( - audit_info.audit_resources - ) - checks_to_execute = get_checks_from_input_arn( - audit_info.audit_resources, provider - ) + checks_to_execute = set_provider_execution_parameters(provider, audit_info) - # Parse content from Allowlist file and get it, if necessary, from S3 - if provider == "aws" and args.allowlist_file: - allowlist_file = parse_allowlist_file(audit_info, args.allowlist_file) - else: - allowlist_file = None + # Parse Allowlist + allowlist_file = set_provider_allowlist(provider, audit_info, args) - # Setting output options based on the selected provider + # Set output options based on the selected provider audit_output_options = set_provider_output_options( provider, args, audit_info, allowlist_file, bulk_checks_metadata ) - # Quick Inventory for AWS - if provider == "aws" and args.quick_inventory: - quick_inventory(audit_info, args.output_directory) + # Run the quick inventory for the provider if available + if hasattr(args, "quick_inventory") and args.quick_inventory: + run_provider_quick_inventory(provider, audit_info, args.output_directory) sys.exit() # Execute checks diff --git a/prowler/lib/check/check.py b/prowler/lib/check/check.py index 6df42580..7f29a73a 100644 --- a/prowler/lib/check/check.py +++ b/prowler/lib/check/check.py @@ -517,73 +517,3 @@ def recover_checks_from_service(service_list: list, provider: str) -> list: # if service_name in group_list: checks_from_arn.add(check_name) checks.add(check_name) return checks - - -def get_checks_from_input_arn(audit_resources: list, provider: str) -> set: - """get_checks_from_input_arn gets the list of checks from the input arns""" - checks_from_arn = set() - # Handle if there are audit resources so only their services are executed - if audit_resources: - services_without_subservices = ["guardduty", "kms", "s3", "elb"] - service_list = set() - sub_service_list = set() - for resource in audit_resources: - service = resource.split(":")[2] - sub_service = resource.split(":")[5].split("/")[0].replace("-", "_") - # WAF Services does not have checks - if service != "wafv2" and service != "waf": - # Parse services when they are different in the ARNs - if service == "lambda": - service = "awslambda" - if service == "elasticloadbalancing": - service = "elb" - elif service == "logs": - service = "cloudwatch" - # Check if Prowler has checks in service - try: - list_modules(provider, service) - except ModuleNotFoundError: - # Service is not supported - pass - else: - service_list.add(service) - - # Get subservices to execute only applicable checks - if service not in services_without_subservices: - # Parse some specific subservices - if service == "ec2": - if sub_service == "security_group": - sub_service = "securitygroup" - if sub_service == "network_acl": - sub_service = "networkacl" - if sub_service == "image": - sub_service = "ami" - if service == "rds": - if sub_service == "cluster_snapshot": - sub_service = "snapshot" - sub_service_list.add(sub_service) - else: - sub_service_list.add(service) - - checks = recover_checks_from_service(service_list, provider) - - # Filter only checks with audited subservices - for check in checks: - if any(sub_service in check for sub_service in sub_service_list): - if not (sub_service == "policy" and "password_policy" in check): - checks_from_arn.add(check) - - # Return final checks list - return sorted(checks_from_arn) - - -def get_regions_from_audit_resources(audit_resources: list) -> list: - """get_regions_from_audit_resources gets the regions from the audit resources arns""" - audited_regions = [] - for resource in audit_resources: - region = resource.split(":")[3] - if region and region not in audited_regions: # Check if arn has a region - audited_regions.append(region) - if audited_regions: - return audited_regions - return None diff --git a/prowler/providers/aws/aws_provider.py b/prowler/providers/aws/aws_provider.py index 6f4ca198..34f3d778 100644 --- a/prowler/providers/aws/aws_provider.py +++ b/prowler/providers/aws/aws_provider.py @@ -7,6 +7,7 @@ from botocore.credentials import RefreshableCredentials from botocore.session import get_session from prowler.config.config import aws_services_json_file +from prowler.lib.check.check import list_modules, recover_checks_from_service from prowler.lib.logger import logger from prowler.lib.utils.utils import open_file, parse_json_file from prowler.providers.aws.lib.audit_info.models import AWS_Assume_Role, AWS_Audit_Info @@ -156,3 +157,73 @@ def get_aws_available_regions(): except Exception as error: logger.error(f"{error.__class__.__name__}: {error}") return [] + + +def get_checks_from_input_arn(audit_resources: list, provider: str) -> set: + """get_checks_from_input_arn gets the list of checks from the input arns""" + checks_from_arn = set() + # Handle if there are audit resources so only their services are executed + if audit_resources: + services_without_subservices = ["guardduty", "kms", "s3", "elb"] + service_list = set() + sub_service_list = set() + for resource in audit_resources: + service = resource.split(":")[2] + sub_service = resource.split(":")[5].split("/")[0].replace("-", "_") + # WAF Services does not have checks + if service != "wafv2" and service != "waf": + # Parse services when they are different in the ARNs + if service == "lambda": + service = "awslambda" + if service == "elasticloadbalancing": + service = "elb" + elif service == "logs": + service = "cloudwatch" + # Check if Prowler has checks in service + try: + list_modules(provider, service) + except ModuleNotFoundError: + # Service is not supported + pass + else: + service_list.add(service) + + # Get subservices to execute only applicable checks + if service not in services_without_subservices: + # Parse some specific subservices + if service == "ec2": + if sub_service == "security_group": + sub_service = "securitygroup" + if sub_service == "network_acl": + sub_service = "networkacl" + if sub_service == "image": + sub_service = "ami" + if service == "rds": + if sub_service == "cluster_snapshot": + sub_service = "snapshot" + sub_service_list.add(sub_service) + else: + sub_service_list.add(service) + + checks = recover_checks_from_service(service_list, provider) + + # Filter only checks with audited subservices + for check in checks: + if any(sub_service in check for sub_service in sub_service_list): + if not (sub_service == "policy" and "password_policy" in check): + checks_from_arn.add(check) + + # Return final checks list + return sorted(checks_from_arn) + + +def get_regions_from_audit_resources(audit_resources: list) -> list: + """get_regions_from_audit_resources gets the regions from the audit resources arns""" + audited_regions = [] + for resource in audit_resources: + region = resource.split(":")[3] + if region and region not in audited_regions: # Check if arn has a region + audited_regions.append(region) + if audited_regions: + return audited_regions + return None diff --git a/prowler/providers/aws/lib/resource_api_tagging/__init__.py b/prowler/providers/aws/lib/resource_api_tagging/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/prowler/providers/aws/lib/resource_api_tagging/resource_api_tagging.py b/prowler/providers/aws/lib/resource_api_tagging/resource_api_tagging.py new file mode 100644 index 00000000..66310c46 --- /dev/null +++ b/prowler/providers/aws/lib/resource_api_tagging/resource_api_tagging.py @@ -0,0 +1,38 @@ +import sys + +from prowler.lib.logger import logger +from prowler.providers.aws.aws_provider import generate_regional_clients +from prowler.providers.aws.lib.audit_info.models import AWS_Audit_Info + + +def get_tagged_resources(input_resource_tags: list, current_audit_info: AWS_Audit_Info): + """ + get_tagged_resources returns a list of the resources that are going to be scanned based on the given input tags + """ + try: + resource_tags = [] + tagged_resources = [] + for tag in input_resource_tags: + key = tag.split("=")[0] + value = tag.split("=")[1] + resource_tags.append({"Key": key, "Values": [value]}) + # Get Resources with resource_tags for all regions + for regional_client in generate_regional_clients( + "resourcegroupstaggingapi", current_audit_info + ).values(): + try: + get_resources_paginator = regional_client.get_paginator("get_resources") + for page in get_resources_paginator.paginate(TagFilters=resource_tags): + for resource in page["ResourceTagMappingList"]: + tagged_resources.append(resource["ResourceARN"]) + except Exception as error: + logger.error( + f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" + ) + except Exception as error: + logger.critical( + f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" + ) + sys.exit(1) + else: + return tagged_resources diff --git a/prowler/providers/common/allowlist.py b/prowler/providers/common/allowlist.py new file mode 100644 index 00000000..0529ec83 --- /dev/null +++ b/prowler/providers/common/allowlist.py @@ -0,0 +1,35 @@ +import importlib +import sys + +from prowler.lib.logger import logger +from prowler.providers.aws.lib.allowlist.allowlist import parse_allowlist_file + + +def set_provider_allowlist(provider, audit_info, args): + """ + set_provider_allowlist configures the allowlist based on the selected provider. + """ + try: + # Check if the provider arguments has the allowlist_file + if hasattr(args, "allowlist_file"): + # Dynamically get the Provider allowlist handler + provider_allowlist_function = f"set_{provider}_allowlist" + allowlist_file = getattr( + importlib.import_module(__name__), provider_allowlist_function + )(audit_info, args.allowlist_file) + + return allowlist_file + except Exception as error: + logger.critical( + f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" + ) + sys.exit(1) + + +def set_aws_allowlist(audit_info, allowlist_file): + # Parse content from Allowlist file and get it, if necessary, from S3 + if allowlist_file: + allowlist_file = parse_allowlist_file(audit_info, allowlist_file) + else: + allowlist_file = None + return allowlist_file diff --git a/prowler/providers/common/audit_info.py b/prowler/providers/common/audit_info.py index 3e2a66f6..b51106b3 100644 --- a/prowler/providers/common/audit_info.py +++ b/prowler/providers/common/audit_info.py @@ -9,7 +9,8 @@ from prowler.lib.logger import logger from prowler.providers.aws.aws_provider import ( AWS_Provider, assume_role, - generate_regional_clients, + get_checks_from_input_arn, + get_regions_from_audit_resources, ) from prowler.providers.aws.lib.arn.arn import arn_parsing from prowler.providers.aws.lib.audit_info.audit_info import current_audit_info @@ -18,6 +19,9 @@ from prowler.providers.aws.lib.audit_info.models import ( AWS_Credentials, AWS_Organizations_Info, ) +from prowler.providers.aws.lib.resource_api_tagging.resource_api_tagging import ( + get_tagged_resources, +) from prowler.providers.azure.azure_provider import Azure_Provider from prowler.providers.azure.lib.audit_info.audit_info import azure_audit_info from prowler.providers.azure.lib.audit_info.models import Azure_Audit_Info @@ -268,6 +272,20 @@ Caller Identity ARN: {Fore.YELLOW}[{audit_info.audited_identity_arn}]{Style.RESE return current_audit_info + def set_aws_execution_parameters(self, provider, audit_info) -> list[str]: + # Once the audit_info is set and we have the eventual checks from arn, it is time to exclude the others + try: + if audit_info.audit_resources: + audit_info.audited_regions = get_regions_from_audit_resources( + audit_info.audit_resources + ) + return get_checks_from_input_arn(audit_info.audit_resources, provider) + except Exception as error: + logger.critical( + f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" + ) + sys.exit(1) + def set_azure_audit_info(self, arguments) -> Azure_Audit_Info: """ set_azure_audit_info returns the Azure_Audit_Info @@ -319,34 +337,21 @@ def set_provider_audit_info(provider: str, arguments: dict): return provider_audit_info -def get_tagged_resources(input_resource_tags: list, current_audit_info: AWS_Audit_Info): +def set_provider_execution_parameters(provider: str, audit_info): """ - get_tagged_resources returns a list of the resources that are going to be scanned based on the given input tags + set_provider_audit_info configures automatically the audit execution based on the selected provider and returns the checks that are going to be executed. """ try: - resource_tags = [] - tagged_resources = [] - for tag in input_resource_tags: - key = tag.split("=")[0] - value = tag.split("=")[1] - resource_tags.append({"Key": key, "Values": [value]}) - # Get Resources with resource_tags for all regions - for regional_client in generate_regional_clients( - "resourcegroupstaggingapi", current_audit_info - ).values(): - try: - get_resources_paginator = regional_client.get_paginator("get_resources") - for page in get_resources_paginator.paginate(TagFilters=resource_tags): - for resource in page["ResourceTagMappingList"]: - tagged_resources.append(resource["ResourceARN"]) - except Exception as error: - logger.error( - f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" - ) + set_provider_execution_parameters_function = ( + f"set_{provider}_execution_parameters" + ) + checks_to_execute = getattr( + Audit_Info(), set_provider_execution_parameters_function + )(provider, audit_info) except Exception as error: logger.critical( f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" ) sys.exit(1) else: - return tagged_resources + return checks_to_execute diff --git a/prowler/providers/common/quick_inventory.py b/prowler/providers/common/quick_inventory.py new file mode 100644 index 00000000..49437b9d --- /dev/null +++ b/prowler/providers/common/quick_inventory.py @@ -0,0 +1,26 @@ +import importlib +import sys + +from prowler.lib.logger import logger +from prowler.providers.aws.lib.quick_inventory.quick_inventory import quick_inventory + + +def run_provider_quick_inventory(provider, audit_info, output_directory): + """ + run_provider_quick_inventory executes the quick inventory for te provider + """ + try: + # Dynamically get the Provider quick inventory handler + provider_quick_inventory_function = f"{provider}_quick_inventory" + getattr(importlib.import_module(__name__), provider_quick_inventory_function)( + audit_info, output_directory + ) + except Exception as error: + logger.critical( + f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" + ) + sys.exit(1) + + +def aws_quick_inventory(audit_info, output_directory): + quick_inventory(audit_info, output_directory) diff --git a/tests/lib/check/check_test.py b/tests/lib/check/check_test.py index f48266fd..15fa4129 100644 --- a/tests/lib/check/check_test.py +++ b/tests/lib/check/check_test.py @@ -8,8 +8,6 @@ from mock import patch from prowler.lib.check.check import ( exclude_checks_to_run, exclude_services_to_run, - get_checks_from_input_arn, - get_regions_from_audit_resources, list_modules, list_services, parse_checks_from_file, @@ -18,6 +16,10 @@ from prowler.lib.check.check import ( update_audit_metadata, ) from prowler.lib.check.models import load_check_metadata +from prowler.providers.aws.aws_provider import ( + get_checks_from_input_arn, + get_regions_from_audit_resources, +) expected_packages = [ ModuleInfo(