From c9cb9774c696683a4acec0082142ba4b4ae6e87a Mon Sep 17 00:00:00 2001 From: Pepe Fagoaga Date: Mon, 11 Dec 2023 14:09:39 +0100 Subject: [PATCH] fix(aws_regions): Get enabled regions (#3095) --- prowler/providers/aws/aws_provider.py | 40 +++++++++++++++---- .../aws/lib/audit_info/audit_info.py | 1 + .../providers/aws/lib/audit_info/models.py | 3 +- prowler/providers/common/audit_info.py | 4 ++ tests/providers/aws/audit_info_utils.py | 4 ++ tests/providers/aws/aws_provider_test.py | 12 +++--- tests/providers/common/audit_info_test.py | 1 + 7 files changed, 51 insertions(+), 14 deletions(-) diff --git a/prowler/providers/aws/aws_provider.py b/prowler/providers/aws/aws_provider.py index 725213d1..a5c33d6d 100644 --- a/prowler/providers/aws/aws_provider.py +++ b/prowler/providers/aws/aws_provider.py @@ -164,12 +164,19 @@ def generate_regional_clients( regional_clients = {} service_regions = get_available_aws_service_regions(service, audit_info) - for region in service_regions: + # Get the regions enabled for the account and get the intersection with the service available regions + if audit_info.enabled_regions: + enabled_regions = service_regions.intersection(audit_info.enabled_regions) + else: + enabled_regions = service_regions + + for region in enabled_regions: regional_client = audit_info.audit_session.client( service, region_name=region, config=audit_info.session_config ) regional_client.region = region regional_clients[region] = regional_client + return regional_clients except Exception as error: logger.error( @@ -177,6 +184,22 @@ def generate_regional_clients( ) +def get_aws_enabled_regions(audit_info: AWS_Audit_Info) -> set: + """get_aws_enabled_regions returns a set of enabled AWS regions""" + + # EC2 Client to check enabled regions + service = "ec2" + default_region = get_default_region(service, audit_info) + ec2_client = audit_info.audit_session.client(service, region_name=default_region) + + enabled_regions = set() + # With AllRegions=False we only get the enabled regions for the account + for region in ec2_client.describe_regions(AllRegions=False).get("Regions", []): + enabled_regions.add(region.get("RegionName")) + + return enabled_regions + + def get_aws_available_regions(): try: actual_directory = pathlib.Path(os.path.dirname(os.path.realpath(__file__))) @@ -268,17 +291,18 @@ def get_regions_from_audit_resources(audit_resources: list) -> set: return audited_regions -def get_available_aws_service_regions(service: str, audit_info: AWS_Audit_Info) -> list: +def get_available_aws_service_regions(service: str, audit_info: AWS_Audit_Info) -> set: # Get json locally actual_directory = pathlib.Path(os.path.dirname(os.path.realpath(__file__))) with open_file(f"{actual_directory}/{aws_services_json_file}") as f: data = parse_json_file(f) - # Check if it is a subservice - json_regions = data["services"][service]["regions"][audit_info.audited_partition] - if audit_info.audited_regions: # Check for input aws audit_info.audited_regions - regions = list( - set(json_regions).intersection(audit_info.audited_regions) - ) # Get common regions between input and json + json_regions = set( + data["services"][service]["regions"][audit_info.audited_partition] + ) + # Check for input aws audit_info.audited_regions + if audit_info.audited_regions: + # Get common regions between input and json + regions = json_regions.intersection(audit_info.audited_regions) else: # Get all regions from json of the service and partition regions = json_regions return regions diff --git a/prowler/providers/aws/lib/audit_info/audit_info.py b/prowler/providers/aws/lib/audit_info/audit_info.py index 908936c0..65f884c0 100644 --- a/prowler/providers/aws/lib/audit_info/audit_info.py +++ b/prowler/providers/aws/lib/audit_info/audit_info.py @@ -38,4 +38,5 @@ current_audit_info = AWS_Audit_Info( audit_metadata=None, audit_config=None, ignore_unused_services=False, + enabled_regions=set(), ) diff --git a/prowler/providers/aws/lib/audit_info/models.py b/prowler/providers/aws/lib/audit_info/models.py index 838982e3..59a2b15a 100644 --- a/prowler/providers/aws/lib/audit_info/models.py +++ b/prowler/providers/aws/lib/audit_info/models.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime from typing import Any, Optional @@ -53,3 +53,4 @@ class AWS_Audit_Info: audit_metadata: Optional[Any] = None audit_config: Optional[dict] = None ignore_unused_services: bool = False + enabled_regions: set = field(default_factory=set) diff --git a/prowler/providers/common/audit_info.py b/prowler/providers/common/audit_info.py index ec8b302b..39a7f0ce 100644 --- a/prowler/providers/common/audit_info.py +++ b/prowler/providers/common/audit_info.py @@ -8,6 +8,7 @@ from prowler.lib.logger import logger from prowler.providers.aws.aws_provider import ( AWS_Provider, assume_role, + get_aws_enabled_regions, get_checks_from_input_arn, get_regions_from_audit_resources, ) @@ -257,6 +258,9 @@ Azure Identity Type: {Fore.YELLOW}[{audit_info.identity.identity_type}]{Style.RE if arguments.get("resource_arn"): current_audit_info.audit_resources = arguments.get("resource_arn") + # Get Enabled Regions + current_audit_info.enabled_regions = get_aws_enabled_regions(current_audit_info) + return current_audit_info def set_aws_execution_parameters(self, provider, audit_info) -> list[str]: diff --git a/tests/providers/aws/audit_info_utils.py b/tests/providers/aws/audit_info_utils.py index d7b2be1e..821ef32c 100644 --- a/tests/providers/aws/audit_info_utils.py +++ b/tests/providers/aws/audit_info_utils.py @@ -15,6 +15,8 @@ AWS_REGION_EU_WEST_1 = "eu-west-1" AWS_REGION_EU_WEST_1_AZA = "eu-west-1a" AWS_REGION_EU_WEST_1_AZB = "eu-west-1b" AWS_REGION_EU_WEST_2 = "eu-west-2" +AWS_REGION_CN_NORTHWEST_1 = "cn-northwest-1" +AWS_REGION_CN_NORTH_1 = "cn-north-1" AWS_REGION_EU_SOUTH_2 = "eu-south-2" AWS_REGION_US_WEST_2 = "us-west-2" AWS_REGION_US_EAST_2 = "us-east-2" @@ -51,6 +53,7 @@ def set_mocked_aws_audit_info( botocore_session=None, ), original_session: session.Session = None, + enabled_regions: set = None, ): audit_info = AWS_Audit_Info( session_config=None, @@ -77,5 +80,6 @@ def set_mocked_aws_audit_info( ), audit_config=audit_config, ignore_unused_services=ignore_unused_services, + enabled_regions=enabled_regions if enabled_regions else set(audited_regions), ) return audit_info diff --git a/tests/providers/aws/aws_provider_test.py b/tests/providers/aws/aws_provider_test.py index b375382b..69d59728 100644 --- a/tests/providers/aws/aws_provider_test.py +++ b/tests/providers/aws/aws_provider_test.py @@ -280,7 +280,7 @@ class Test_AWS_Provider: role_name = "test-role" role_arn = f"arn:aws:iam::{AWS_ACCOUNT_NUMBER}:role/{role_name}" session_duration_seconds = 900 - AWS_REGION_US_EAST_1 = "eu-west-1" + AWS_REGION_US_EAST_1 = AWS_REGION_EU_WEST_1 sts_endpoint_region = AWS_REGION_US_EAST_1 sessionName = "ProwlerAsessmentSession" @@ -352,6 +352,7 @@ class Test_AWS_Provider: audit_session=boto3.session.Session( region_name=AWS_REGION_US_EAST_1, ), + enabled_regions=audited_regions, ) generate_regional_clients_response = generate_regional_clients( @@ -367,6 +368,7 @@ class Test_AWS_Provider: audit_session=boto3.session.Session( region_name=AWS_REGION_US_EAST_1, ), + enabled_regions=audited_regions, ) generate_regional_clients_response = generate_regional_clients( "shield", audit_info @@ -430,7 +432,7 @@ class Test_AWS_Provider: "eu-north-1", "eu-south-1", "eu-south-2", - "eu-west-1", + AWS_REGION_EU_WEST_1, "eu-west-2", "eu-west-3", "me-central-1", @@ -446,9 +448,9 @@ class Test_AWS_Provider: } }, ): - assert get_available_aws_service_regions("ec2", audit_info) == [ + assert get_available_aws_service_regions("ec2", audit_info) == { AWS_REGION_US_EAST_1 - ] + } def test_get_available_aws_service_regions_with_all_regions_audited(self): audit_info = set_mocked_aws_audit_info() @@ -467,7 +469,7 @@ class Test_AWS_Provider: "eu-north-1", "eu-south-1", "eu-south-2", - "eu-west-1", + AWS_REGION_EU_WEST_1, "eu-west-2", "eu-west-3", "me-central-1", diff --git a/tests/providers/common/audit_info_test.py b/tests/providers/common/audit_info_test.py index 2f09c986..57021e62 100644 --- a/tests/providers/common/audit_info_test.py +++ b/tests/providers/common/audit_info_test.py @@ -327,6 +327,7 @@ class Test_Set_Audit_Info: get_tagged_resources(["MY_TAG1=MY_VALUE1"], mock_audit_info) ) + @mock_ec2 @patch( "prowler.providers.common.audit_info.validate_aws_credentials", new=mock_validate_credentials,