fix(providers): Move provider's logic outside main (#2043)

Co-authored-by: Sergio Garcia <sergargar1@gmail.com>
This commit is contained in:
Pepe Fagoaga
2023-03-16 17:32:53 +01:00
committed by GitHub
parent 0d6ca606ea
commit cc58e06b5e
9 changed files with 217 additions and 116 deletions

View File

@@ -10,8 +10,6 @@ from prowler.lib.check.check import (
exclude_checks_to_run,
exclude_services_to_run,
execute_checks,
get_checks_from_input_arn,
get_regions_from_audit_resources,
list_categories,
list_services,
print_categories,
@@ -29,13 +27,16 @@ from prowler.lib.outputs.html import add_html_footer, fill_html_overview_statist
from prowler.lib.outputs.json import close_json
from prowler.lib.outputs.outputs import extract_findings_statistics, send_to_s3_bucket
from prowler.lib.outputs.summary_table import display_summary_table
from prowler.providers.aws.lib.allowlist.allowlist import parse_allowlist_file
from prowler.providers.aws.lib.quick_inventory.quick_inventory import quick_inventory
from prowler.providers.aws.lib.security_hub.security_hub import (
resolve_security_hub_previous_findings,
)
from prowler.providers.common.audit_info import set_provider_audit_info
from prowler.providers.common.allowlist import set_provider_allowlist
from prowler.providers.common.audit_info import (
set_provider_audit_info,
set_provider_execution_parameters,
)
from prowler.providers.common.outputs import set_provider_output_options
from prowler.providers.common.quick_inventory import run_provider_quick_inventory
def prowler():
@@ -128,29 +129,22 @@ def prowler():
# Set the audit info based on the selected provider
audit_info = set_provider_audit_info(provider, args.__dict__)
# Once the audit_info is set and we have the eventual checks from arn, it is time to exclude the others
# Once the audit_info is set and we have the eventual checks based on the resource identifier,
# it is time to check what Prowler's checks are going to be executed
if audit_info.audit_resources:
audit_info.audited_regions = get_regions_from_audit_resources(
audit_info.audit_resources
)
checks_to_execute = get_checks_from_input_arn(
audit_info.audit_resources, provider
)
checks_to_execute = set_provider_execution_parameters(provider, audit_info)
# Parse content from Allowlist file and get it, if necessary, from S3
if provider == "aws" and args.allowlist_file:
allowlist_file = parse_allowlist_file(audit_info, args.allowlist_file)
else:
allowlist_file = None
# Parse Allowlist
allowlist_file = set_provider_allowlist(provider, audit_info, args)
# Setting output options based on the selected provider
# Set output options based on the selected provider
audit_output_options = set_provider_output_options(
provider, args, audit_info, allowlist_file, bulk_checks_metadata
)
# Quick Inventory for AWS
if provider == "aws" and args.quick_inventory:
quick_inventory(audit_info, args.output_directory)
# Run the quick inventory for the provider if available
if hasattr(args, "quick_inventory") and args.quick_inventory:
run_provider_quick_inventory(provider, audit_info, args.output_directory)
sys.exit()
# Execute checks

View File

@@ -517,73 +517,3 @@ def recover_checks_from_service(service_list: list, provider: str) -> list:
# if service_name in group_list: checks_from_arn.add(check_name)
checks.add(check_name)
return checks
def get_checks_from_input_arn(audit_resources: list, provider: str) -> set:
"""get_checks_from_input_arn gets the list of checks from the input arns"""
checks_from_arn = set()
# Handle if there are audit resources so only their services are executed
if audit_resources:
services_without_subservices = ["guardduty", "kms", "s3", "elb"]
service_list = set()
sub_service_list = set()
for resource in audit_resources:
service = resource.split(":")[2]
sub_service = resource.split(":")[5].split("/")[0].replace("-", "_")
# WAF Services does not have checks
if service != "wafv2" and service != "waf":
# Parse services when they are different in the ARNs
if service == "lambda":
service = "awslambda"
if service == "elasticloadbalancing":
service = "elb"
elif service == "logs":
service = "cloudwatch"
# Check if Prowler has checks in service
try:
list_modules(provider, service)
except ModuleNotFoundError:
# Service is not supported
pass
else:
service_list.add(service)
# Get subservices to execute only applicable checks
if service not in services_without_subservices:
# Parse some specific subservices
if service == "ec2":
if sub_service == "security_group":
sub_service = "securitygroup"
if sub_service == "network_acl":
sub_service = "networkacl"
if sub_service == "image":
sub_service = "ami"
if service == "rds":
if sub_service == "cluster_snapshot":
sub_service = "snapshot"
sub_service_list.add(sub_service)
else:
sub_service_list.add(service)
checks = recover_checks_from_service(service_list, provider)
# Filter only checks with audited subservices
for check in checks:
if any(sub_service in check for sub_service in sub_service_list):
if not (sub_service == "policy" and "password_policy" in check):
checks_from_arn.add(check)
# Return final checks list
return sorted(checks_from_arn)
def get_regions_from_audit_resources(audit_resources: list) -> list:
"""get_regions_from_audit_resources gets the regions from the audit resources arns"""
audited_regions = []
for resource in audit_resources:
region = resource.split(":")[3]
if region and region not in audited_regions: # Check if arn has a region
audited_regions.append(region)
if audited_regions:
return audited_regions
return None

View File

@@ -7,6 +7,7 @@ from botocore.credentials import RefreshableCredentials
from botocore.session import get_session
from prowler.config.config import aws_services_json_file
from prowler.lib.check.check import list_modules, recover_checks_from_service
from prowler.lib.logger import logger
from prowler.lib.utils.utils import open_file, parse_json_file
from prowler.providers.aws.lib.audit_info.models import AWS_Assume_Role, AWS_Audit_Info
@@ -156,3 +157,73 @@ def get_aws_available_regions():
except Exception as error:
logger.error(f"{error.__class__.__name__}: {error}")
return []
def get_checks_from_input_arn(audit_resources: list, provider: str) -> set:
"""get_checks_from_input_arn gets the list of checks from the input arns"""
checks_from_arn = set()
# Handle if there are audit resources so only their services are executed
if audit_resources:
services_without_subservices = ["guardduty", "kms", "s3", "elb"]
service_list = set()
sub_service_list = set()
for resource in audit_resources:
service = resource.split(":")[2]
sub_service = resource.split(":")[5].split("/")[0].replace("-", "_")
# WAF Services does not have checks
if service != "wafv2" and service != "waf":
# Parse services when they are different in the ARNs
if service == "lambda":
service = "awslambda"
if service == "elasticloadbalancing":
service = "elb"
elif service == "logs":
service = "cloudwatch"
# Check if Prowler has checks in service
try:
list_modules(provider, service)
except ModuleNotFoundError:
# Service is not supported
pass
else:
service_list.add(service)
# Get subservices to execute only applicable checks
if service not in services_without_subservices:
# Parse some specific subservices
if service == "ec2":
if sub_service == "security_group":
sub_service = "securitygroup"
if sub_service == "network_acl":
sub_service = "networkacl"
if sub_service == "image":
sub_service = "ami"
if service == "rds":
if sub_service == "cluster_snapshot":
sub_service = "snapshot"
sub_service_list.add(sub_service)
else:
sub_service_list.add(service)
checks = recover_checks_from_service(service_list, provider)
# Filter only checks with audited subservices
for check in checks:
if any(sub_service in check for sub_service in sub_service_list):
if not (sub_service == "policy" and "password_policy" in check):
checks_from_arn.add(check)
# Return final checks list
return sorted(checks_from_arn)
def get_regions_from_audit_resources(audit_resources: list) -> list:
"""get_regions_from_audit_resources gets the regions from the audit resources arns"""
audited_regions = []
for resource in audit_resources:
region = resource.split(":")[3]
if region and region not in audited_regions: # Check if arn has a region
audited_regions.append(region)
if audited_regions:
return audited_regions
return None

View File

@@ -0,0 +1,38 @@
import sys
from prowler.lib.logger import logger
from prowler.providers.aws.aws_provider import generate_regional_clients
from prowler.providers.aws.lib.audit_info.models import AWS_Audit_Info
def get_tagged_resources(input_resource_tags: list, current_audit_info: AWS_Audit_Info):
"""
get_tagged_resources returns a list of the resources that are going to be scanned based on the given input tags
"""
try:
resource_tags = []
tagged_resources = []
for tag in input_resource_tags:
key = tag.split("=")[0]
value = tag.split("=")[1]
resource_tags.append({"Key": key, "Values": [value]})
# Get Resources with resource_tags for all regions
for regional_client in generate_regional_clients(
"resourcegroupstaggingapi", current_audit_info
).values():
try:
get_resources_paginator = regional_client.get_paginator("get_resources")
for page in get_resources_paginator.paginate(TagFilters=resource_tags):
for resource in page["ResourceTagMappingList"]:
tagged_resources.append(resource["ResourceARN"])
except Exception as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
except Exception as error:
logger.critical(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
sys.exit(1)
else:
return tagged_resources

View File

@@ -0,0 +1,35 @@
import importlib
import sys
from prowler.lib.logger import logger
from prowler.providers.aws.lib.allowlist.allowlist import parse_allowlist_file
def set_provider_allowlist(provider, audit_info, args):
"""
set_provider_allowlist configures the allowlist based on the selected provider.
"""
try:
# Check if the provider arguments has the allowlist_file
if hasattr(args, "allowlist_file"):
# Dynamically get the Provider allowlist handler
provider_allowlist_function = f"set_{provider}_allowlist"
allowlist_file = getattr(
importlib.import_module(__name__), provider_allowlist_function
)(audit_info, args.allowlist_file)
return allowlist_file
except Exception as error:
logger.critical(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
sys.exit(1)
def set_aws_allowlist(audit_info, allowlist_file):
# Parse content from Allowlist file and get it, if necessary, from S3
if allowlist_file:
allowlist_file = parse_allowlist_file(audit_info, allowlist_file)
else:
allowlist_file = None
return allowlist_file

View File

@@ -9,7 +9,8 @@ from prowler.lib.logger import logger
from prowler.providers.aws.aws_provider import (
AWS_Provider,
assume_role,
generate_regional_clients,
get_checks_from_input_arn,
get_regions_from_audit_resources,
)
from prowler.providers.aws.lib.arn.arn import arn_parsing
from prowler.providers.aws.lib.audit_info.audit_info import current_audit_info
@@ -18,6 +19,9 @@ from prowler.providers.aws.lib.audit_info.models import (
AWS_Credentials,
AWS_Organizations_Info,
)
from prowler.providers.aws.lib.resource_api_tagging.resource_api_tagging import (
get_tagged_resources,
)
from prowler.providers.azure.azure_provider import Azure_Provider
from prowler.providers.azure.lib.audit_info.audit_info import azure_audit_info
from prowler.providers.azure.lib.audit_info.models import Azure_Audit_Info
@@ -268,6 +272,20 @@ Caller Identity ARN: {Fore.YELLOW}[{audit_info.audited_identity_arn}]{Style.RESE
return current_audit_info
def set_aws_execution_parameters(self, provider, audit_info) -> list[str]:
# Once the audit_info is set and we have the eventual checks from arn, it is time to exclude the others
try:
if audit_info.audit_resources:
audit_info.audited_regions = get_regions_from_audit_resources(
audit_info.audit_resources
)
return get_checks_from_input_arn(audit_info.audit_resources, provider)
except Exception as error:
logger.critical(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
sys.exit(1)
def set_azure_audit_info(self, arguments) -> Azure_Audit_Info:
"""
set_azure_audit_info returns the Azure_Audit_Info
@@ -319,34 +337,21 @@ def set_provider_audit_info(provider: str, arguments: dict):
return provider_audit_info
def get_tagged_resources(input_resource_tags: list, current_audit_info: AWS_Audit_Info):
def set_provider_execution_parameters(provider: str, audit_info):
"""
get_tagged_resources returns a list of the resources that are going to be scanned based on the given input tags
set_provider_audit_info configures automatically the audit execution based on the selected provider and returns the checks that are going to be executed.
"""
try:
resource_tags = []
tagged_resources = []
for tag in input_resource_tags:
key = tag.split("=")[0]
value = tag.split("=")[1]
resource_tags.append({"Key": key, "Values": [value]})
# Get Resources with resource_tags for all regions
for regional_client in generate_regional_clients(
"resourcegroupstaggingapi", current_audit_info
).values():
try:
get_resources_paginator = regional_client.get_paginator("get_resources")
for page in get_resources_paginator.paginate(TagFilters=resource_tags):
for resource in page["ResourceTagMappingList"]:
tagged_resources.append(resource["ResourceARN"])
except Exception as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
set_provider_execution_parameters_function = (
f"set_{provider}_execution_parameters"
)
checks_to_execute = getattr(
Audit_Info(), set_provider_execution_parameters_function
)(provider, audit_info)
except Exception as error:
logger.critical(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
sys.exit(1)
else:
return tagged_resources
return checks_to_execute

View File

@@ -0,0 +1,26 @@
import importlib
import sys
from prowler.lib.logger import logger
from prowler.providers.aws.lib.quick_inventory.quick_inventory import quick_inventory
def run_provider_quick_inventory(provider, audit_info, output_directory):
"""
run_provider_quick_inventory executes the quick inventory for te provider
"""
try:
# Dynamically get the Provider quick inventory handler
provider_quick_inventory_function = f"{provider}_quick_inventory"
getattr(importlib.import_module(__name__), provider_quick_inventory_function)(
audit_info, output_directory
)
except Exception as error:
logger.critical(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
sys.exit(1)
def aws_quick_inventory(audit_info, output_directory):
quick_inventory(audit_info, output_directory)

View File

@@ -8,8 +8,6 @@ from mock import patch
from prowler.lib.check.check import (
exclude_checks_to_run,
exclude_services_to_run,
get_checks_from_input_arn,
get_regions_from_audit_resources,
list_modules,
list_services,
parse_checks_from_file,
@@ -18,6 +16,10 @@ from prowler.lib.check.check import (
update_audit_metadata,
)
from prowler.lib.check.models import load_check_metadata
from prowler.providers.aws.aws_provider import (
get_checks_from_input_arn,
get_regions_from_audit_resources,
)
expected_packages = [
ModuleInfo(