diff --git a/prowler/providers/aws/aws_provider.py b/prowler/providers/aws/aws_provider.py index 99b0525d..616a30ec 100644 --- a/prowler/providers/aws/aws_provider.py +++ b/prowler/providers/aws/aws_provider.py @@ -146,27 +146,14 @@ def generate_regional_clients( ) -> dict: try: regional_clients = {} - # Get json locally - actual_directory = pathlib.Path(os.path.dirname(os.path.realpath(__file__))) - with open_file(f"{actual_directory}/{aws_services_json_file}") as f: - data = parse_json_file(f) - # Check if it is a subservice - json_regions = data["services"][service]["regions"][ - audit_info.audited_partition - ] - if audit_info.audited_regions: # Check for input aws audit_info.audited_regions - regions = list( - set(json_regions).intersection(audit_info.audited_regions) - ) # Get common regions between input and json - else: # Get all regions from json of the service and partition - regions = json_regions + service_regions = get_available_aws_service_regions(service, audit_info) # Check if it is global service to gather only one region if global_service: - if regions: - if audit_info.profile_region in regions: - regions = [audit_info.profile_region] - regions = regions[:1] - for region in regions: + if service_regions: + if audit_info.profile_region in service_regions: + service_regions = [audit_info.profile_region] + service_regions = service_regions[:1] + for region in service_regions: regional_client = audit_info.audit_session.client( service, region_name=region, config=audit_info.session_config ) @@ -265,3 +252,46 @@ def get_regions_from_audit_resources(audit_resources: list) -> list: if audited_regions: return audited_regions return None + + +def get_available_aws_service_regions(service: str, audit_info: AWS_Audit_Info) -> list: + # Get json locally + actual_directory = pathlib.Path(os.path.dirname(os.path.realpath(__file__))) + with open_file(f"{actual_directory}/{aws_services_json_file}") as f: + data = parse_json_file(f) + # Check if it is a subservice + json_regions = data["services"][service]["regions"][audit_info.audited_partition] + if audit_info.audited_regions: # Check for input aws audit_info.audited_regions + regions = list( + set(json_regions).intersection(audit_info.audited_regions) + ) # Get common regions between input and json + else: # Get all regions from json of the service and partition + regions = json_regions + return regions + + +def get_default_region(service: str, audit_info: AWS_Audit_Info) -> str: + """get_default_region gets the default region based on the profile and audited service regions""" + service_regions = get_available_aws_service_regions(service, audit_info) + default_region = get_global_region( + audit_info + ) # global region of the partition when all regions are audited and there is no profile region + if audit_info.profile_region in service_regions: + # return profile region only if it is audited + default_region = audit_info.profile_region + # return first audited region if specific regions are audited + elif audit_info.audited_regions: + default_region = audit_info.audited_regions[0] + return default_region + + +def get_global_region(audit_info: AWS_Audit_Info) -> str: + """get_global_region gets the global region based on the audited partition""" + global_region = "us-east-1" + if audit_info.audited_partition == "aws-cn": + global_region = "cn-north-1" + elif audit_info.audited_partition == "aws-us-gov": + global_region = "us-gov-east-1" + elif "aws-iso" in audit_info.audited_partition: + global_region = "aws-iso-global" + return global_region diff --git a/prowler/providers/aws/services/account/account_service.py b/prowler/providers/aws/services/account/account_service.py index 70778206..a3567ac8 100644 --- a/prowler/providers/aws/services/account/account_service.py +++ b/prowler/providers/aws/services/account/account_service.py @@ -1,4 +1,7 @@ -from prowler.providers.aws.aws_provider import generate_regional_clients +from prowler.providers.aws.aws_provider import ( + generate_regional_clients, + get_default_region, +) ################## Account @@ -10,13 +13,7 @@ class Account: self.audited_partition = audit_info.audited_partition self.audited_account_arn = audit_info.audited_account_arn self.regional_clients = generate_regional_clients(self.service, audit_info) - # If the region is not set in the audit profile, - # we pick the first region from the regional clients list - self.region = ( - audit_info.profile_region - if audit_info.profile_region - else list(self.regional_clients.keys())[0] - ) + self.region = get_default_region(self.service, audit_info) def __get_session__(self): return self.session diff --git a/prowler/providers/aws/services/backup/backup_service.py b/prowler/providers/aws/services/backup/backup_service.py index 48b4b31f..6e534fb1 100644 --- a/prowler/providers/aws/services/backup/backup_service.py +++ b/prowler/providers/aws/services/backup/backup_service.py @@ -6,7 +6,10 @@ from pydantic import BaseModel from prowler.lib.logger import logger from prowler.lib.scan_filters.scan_filters import is_resource_filtered -from prowler.providers.aws.aws_provider import generate_regional_clients +from prowler.providers.aws.aws_provider import ( + generate_regional_clients, + get_default_region, +) ################## Backup @@ -19,13 +22,7 @@ class Backup: self.audited_account_arn = audit_info.audited_account_arn self.audit_resources = audit_info.audit_resources self.regional_clients = generate_regional_clients(self.service, audit_info) - # If the region is not set in the audit profile, - # we pick the first region from the regional clients list - self.region = ( - audit_info.profile_region - if audit_info.profile_region - else list(self.regional_clients.keys())[0] - ) + self.region = get_default_region(self.service, audit_info) self.backup_vaults = [] self.__threading_call__(self.__list_backup_vaults__) self.backup_plans = [] diff --git a/prowler/providers/aws/services/cloudtrail/cloudtrail_service.py b/prowler/providers/aws/services/cloudtrail/cloudtrail_service.py index 2a879449..09a2b278 100644 --- a/prowler/providers/aws/services/cloudtrail/cloudtrail_service.py +++ b/prowler/providers/aws/services/cloudtrail/cloudtrail_service.py @@ -7,7 +7,10 @@ from pydantic import BaseModel from prowler.lib.logger import logger from prowler.lib.scan_filters.scan_filters import is_resource_filtered -from prowler.providers.aws.aws_provider import generate_regional_clients +from prowler.providers.aws.aws_provider import ( + generate_regional_clients, + get_default_region, +) ################### CLOUDTRAIL @@ -20,13 +23,7 @@ class Cloudtrail: self.audited_account_arn = audit_info.audited_account_arn self.audit_resources = audit_info.audit_resources self.regional_clients = generate_regional_clients(self.service, audit_info) - # If the region is not set in the audit profile, - # we pick the first region from the regional clients list - self.region = ( - audit_info.profile_region - if audit_info.profile_region - else list(self.regional_clients.keys())[0] - ) + self.region = get_default_region(self.service, audit_info) self.trails = [] self.__threading_call__(self.__get_trails__) self.__get_trail_status__() diff --git a/prowler/providers/aws/services/drs/drs_service.py b/prowler/providers/aws/services/drs/drs_service.py index 6b8bf341..8f238f82 100644 --- a/prowler/providers/aws/services/drs/drs_service.py +++ b/prowler/providers/aws/services/drs/drs_service.py @@ -5,7 +5,10 @@ from pydantic import BaseModel from prowler.lib.logger import logger from prowler.lib.scan_filters.scan_filters import is_resource_filtered -from prowler.providers.aws.aws_provider import generate_regional_clients +from prowler.providers.aws.aws_provider import ( + generate_regional_clients, + get_default_region, +) ################## DRS (Elastic Disaster Recovery Service) @@ -19,13 +22,7 @@ class DRS: self.audited_account_arn = audit_info.audited_account_arn self.audit_resources = audit_info.audit_resources self.regional_clients = generate_regional_clients(self.service, audit_info) - # If the region is not set in the audit profile, - # we pick the first region from the regional clients list - self.region = ( - audit_info.profile_region - if audit_info.profile_region - else list(self.regional_clients.keys())[0] - ) + self.region = get_default_region(self.service, audit_info) self.drs_services = [] self.__threading_call__(self.__describe_jobs__) diff --git a/prowler/providers/aws/services/inspector2/inspector2_service.py b/prowler/providers/aws/services/inspector2/inspector2_service.py index bd59eb01..c61958e7 100644 --- a/prowler/providers/aws/services/inspector2/inspector2_service.py +++ b/prowler/providers/aws/services/inspector2/inspector2_service.py @@ -4,7 +4,10 @@ from pydantic import BaseModel from prowler.lib.logger import logger from prowler.lib.scan_filters.scan_filters import is_resource_filtered -from prowler.providers.aws.aws_provider import generate_regional_clients +from prowler.providers.aws.aws_provider import ( + generate_regional_clients, + get_default_region, +) ################################ Inspector2 @@ -17,13 +20,7 @@ class Inspector2: self.audited_account_arn = audit_info.audited_account_arn self.audit_resources = audit_info.audit_resources self.regional_clients = generate_regional_clients(self.service, audit_info) - # If the region is not set in the audit profile, - # we pick the first region from the regional clients list - self.region = ( - audit_info.profile_region - if audit_info.profile_region - else list(self.regional_clients.keys())[0] - ) + self.region = get_default_region(self.service, audit_info) self.inspectors = [] self.__threading_call__(self.__batch_get_account_status__) self.__list_findings__() diff --git a/prowler/providers/aws/services/networkfirewall/networkfirewall_service.py b/prowler/providers/aws/services/networkfirewall/networkfirewall_service.py index c6da6c64..ca441462 100644 --- a/prowler/providers/aws/services/networkfirewall/networkfirewall_service.py +++ b/prowler/providers/aws/services/networkfirewall/networkfirewall_service.py @@ -4,7 +4,10 @@ from pydantic import BaseModel from prowler.lib.logger import logger from prowler.lib.scan_filters.scan_filters import is_resource_filtered -from prowler.providers.aws.aws_provider import generate_regional_clients +from prowler.providers.aws.aws_provider import ( + generate_regional_clients, + get_default_region, +) ################## NetworkFirewall @@ -16,13 +19,7 @@ class NetworkFirewall: self.audited_partition = audit_info.audited_partition self.audit_resources = audit_info.audit_resources self.regional_clients = generate_regional_clients(self.service, audit_info) - # If the region is not set in the audit profile, - # we pick the first region from the regional clients list - self.region = ( - audit_info.profile_region - if audit_info.profile_region - else list(self.regional_clients.keys())[0] - ) + self.region = get_default_region(self.service, audit_info) self.network_firewalls = [] self.__threading_call__(self.__list_firewalls__) self.__describe_firewall__() diff --git a/prowler/providers/aws/services/resourceexplorer2/resourceexplorer2_service.py b/prowler/providers/aws/services/resourceexplorer2/resourceexplorer2_service.py index dbcb48cc..af6cc530 100644 --- a/prowler/providers/aws/services/resourceexplorer2/resourceexplorer2_service.py +++ b/prowler/providers/aws/services/resourceexplorer2/resourceexplorer2_service.py @@ -4,7 +4,10 @@ from pydantic import BaseModel from prowler.lib.logger import logger from prowler.lib.scan_filters.scan_filters import is_resource_filtered -from prowler.providers.aws.aws_provider import generate_regional_clients +from prowler.providers.aws.aws_provider import ( + generate_regional_clients, + get_default_region, +) ################################ ResourceExplorer2 @@ -17,13 +20,7 @@ class ResourceExplorer2: self.audited_partition = audit_info.audited_partition self.audited_account_arn = audit_info.audited_account_arn self.regional_clients = generate_regional_clients(self.service, audit_info) - # If the region is not set in the audit profile, - # we pick the first region from the regional clients list - self.region = ( - audit_info.profile_region - if audit_info.profile_region - else list(self.regional_clients.keys())[0] - ) + self.region = get_default_region(self.service, audit_info) self.indexes = [] self.__threading_call__(self.__list_indexes__) diff --git a/prowler/providers/aws/services/ssmincidents/ssmincidents_service.py b/prowler/providers/aws/services/ssmincidents/ssmincidents_service.py index e2d02ada..047ddddd 100644 --- a/prowler/providers/aws/services/ssmincidents/ssmincidents_service.py +++ b/prowler/providers/aws/services/ssmincidents/ssmincidents_service.py @@ -5,7 +5,10 @@ from pydantic import BaseModel from prowler.lib.logger import logger from prowler.lib.scan_filters.scan_filters import is_resource_filtered -from prowler.providers.aws.aws_provider import generate_regional_clients +from prowler.providers.aws.aws_provider import ( + generate_regional_clients, + get_default_region, +) # Note: # This service is a bit special because it creates a resource (Replication Set) in one region, but you can list it in from any region using list_replication_sets @@ -24,13 +27,7 @@ class SSMIncidents: self.audited_account_arn = audit_info.audited_account_arn self.audit_resources = audit_info.audit_resources self.regional_clients = generate_regional_clients(self.service, audit_info) - # If the region is not set in the audit profile, - # we pick the first region from the regional clients list - self.region = ( - audit_info.profile_region - if audit_info.profile_region - else list(self.regional_clients.keys())[0] - ) + self.region = get_default_region(self.service, audit_info) self.replication_set = [] self.__list_replication_sets__() self.__get_replication_set__() diff --git a/prowler/providers/aws/services/trustedadvisor/trustedadvisor_service.py b/prowler/providers/aws/services/trustedadvisor/trustedadvisor_service.py index 05e96b72..932b99c6 100644 --- a/prowler/providers/aws/services/trustedadvisor/trustedadvisor_service.py +++ b/prowler/providers/aws/services/trustedadvisor/trustedadvisor_service.py @@ -4,6 +4,7 @@ from botocore.client import ClientError from pydantic import BaseModel from prowler.lib.logger import logger +from prowler.providers.aws.aws_provider import get_default_region ################################ TrustedAdvisor @@ -18,13 +19,14 @@ class TrustedAdvisor: # But only in us-east-1 or us-gov-west-1 https://docs.aws.amazon.com/general/latest/gr/awssupport.html if audit_info.audited_partition != "aws-cn": if audit_info.audited_partition == "aws": + self.region = get_default_region(self.service, audit_info) support_region = "us-east-1" else: support_region = "us-gov-west-1" self.client = audit_info.audit_session.client( self.service, region_name=support_region ) - self.client.region = self.region = support_region + self.client.region = support_region self.__describe_trusted_advisor_checks__() self.__describe_trusted_advisor_check_result__() diff --git a/prowler/providers/aws/services/vpc/vpc_service.py b/prowler/providers/aws/services/vpc/vpc_service.py index 9676f8f6..403d5ca8 100644 --- a/prowler/providers/aws/services/vpc/vpc_service.py +++ b/prowler/providers/aws/services/vpc/vpc_service.py @@ -7,7 +7,10 @@ from pydantic import BaseModel from prowler.lib.logger import logger from prowler.lib.scan_filters.scan_filters import is_resource_filtered -from prowler.providers.aws.aws_provider import generate_regional_clients +from prowler.providers.aws.aws_provider import ( + generate_regional_clients, + get_default_region, +) ################## VPC @@ -33,11 +36,7 @@ class VPC: self.__describe_vpc_endpoint_service_permissions__() self.vpc_subnets = {} self.__threading_call__(self.__describe_vpc_subnets__) - self.region = ( - audit_info.profile_region - if audit_info.profile_region - else list(self.regional_clients.keys())[0] - ) + self.region = get_default_region(self.service, audit_info) def __get_session__(self): return self.session diff --git a/tests/providers/aws/aws_provider_test.py b/tests/providers/aws/aws_provider_test.py index 985f8007..65c4804a 100644 --- a/tests/providers/aws/aws_provider_test.py +++ b/tests/providers/aws/aws_provider_test.py @@ -7,6 +7,9 @@ from prowler.providers.aws.aws_provider import ( AWS_Provider, assume_role, generate_regional_clients, + get_available_aws_service_regions, + get_default_region, + get_global_region, ) from prowler.providers.aws.lib.audit_info.models import AWS_Assume_Role, AWS_Audit_Info @@ -275,3 +278,282 @@ class Test_AWS_Provider: # Shield does not exist in China assert generate_regional_clients_response == {} + + def test_get_default_region(self): + audited_regions = ["eu-west-1"] + profile_region = "eu-west-1" + audit_info = AWS_Audit_Info( + session_config=None, + original_session=None, + audit_session=None, + audited_account=None, + audited_account_arn=None, + audited_partition="aws", + audited_identity_arn=None, + audited_user_id=None, + profile=None, + profile_region=profile_region, + credentials=None, + assumed_role_info=None, + audited_regions=audited_regions, + organizations_metadata=None, + audit_resources=None, + mfa_enabled=False, + ) + assert get_default_region("ec2", audit_info) == "eu-west-1" + + def test_get_default_region_profile_region_not_audited(self): + audited_regions = ["eu-west-1"] + profile_region = "us-east-2" + audit_info = AWS_Audit_Info( + session_config=None, + original_session=None, + audit_session=None, + audited_account=None, + audited_account_arn=None, + audited_partition="aws", + audited_identity_arn=None, + audited_user_id=None, + profile=None, + profile_region=profile_region, + credentials=None, + assumed_role_info=None, + audited_regions=audited_regions, + organizations_metadata=None, + audit_resources=None, + mfa_enabled=False, + ) + assert get_default_region("ec2", audit_info) == "eu-west-1" + + def test_get_default_region_non_profile_region(self): + audited_regions = ["eu-west-1"] + profile_region = None + audit_info = AWS_Audit_Info( + session_config=None, + original_session=None, + audit_session=None, + audited_account=None, + audited_account_arn=None, + audited_partition="aws", + audited_identity_arn=None, + audited_user_id=None, + profile=None, + profile_region=profile_region, + credentials=None, + assumed_role_info=None, + audited_regions=audited_regions, + organizations_metadata=None, + audit_resources=None, + mfa_enabled=False, + ) + assert get_default_region("ec2", audit_info) == "eu-west-1" + + def test_get_default_region_non_profile_or_audited_region(self): + audited_regions = None + profile_region = None + audit_info = AWS_Audit_Info( + session_config=None, + original_session=None, + audit_session=None, + audited_account=None, + audited_account_arn=None, + audited_partition="aws", + audited_identity_arn=None, + audited_user_id=None, + profile=None, + profile_region=profile_region, + credentials=None, + assumed_role_info=None, + audited_regions=audited_regions, + organizations_metadata=None, + audit_resources=None, + mfa_enabled=False, + ) + assert get_default_region("ec2", audit_info) == "us-east-1" + + def test_aws_get_global_region(self): + audit_info = AWS_Audit_Info( + session_config=None, + original_session=None, + audit_session=None, + audited_account=None, + audited_account_arn=None, + audited_partition="aws", + audited_identity_arn=None, + audited_user_id=None, + profile=None, + profile_region=None, + credentials=None, + assumed_role_info=None, + audited_regions=None, + organizations_metadata=None, + audit_resources=None, + mfa_enabled=False, + ) + assert get_default_region("ec2", audit_info) == "us-east-1" + + def test_aws_gov_get_global_region(self): + audit_info = AWS_Audit_Info( + session_config=None, + original_session=None, + audit_session=None, + audited_account=None, + audited_account_arn=None, + audited_partition="aws-us-gov", + audited_identity_arn=None, + audited_user_id=None, + profile=None, + profile_region=None, + credentials=None, + assumed_role_info=None, + audited_regions=None, + organizations_metadata=None, + audit_resources=None, + mfa_enabled=False, + ) + assert get_global_region(audit_info) == "us-gov-east-1" + + def test_aws_cn_get_global_region(self): + audit_info = AWS_Audit_Info( + session_config=None, + original_session=None, + audit_session=None, + audited_account=None, + audited_account_arn=None, + audited_partition="aws-cn", + audited_identity_arn=None, + audited_user_id=None, + profile=None, + profile_region=None, + credentials=None, + assumed_role_info=None, + audited_regions=None, + organizations_metadata=None, + audit_resources=None, + mfa_enabled=False, + ) + assert get_global_region(audit_info) == "cn-north-1" + + def test_aws_iso_get_global_region(self): + audit_info = AWS_Audit_Info( + session_config=None, + original_session=None, + audit_session=None, + audited_account=None, + audited_account_arn=None, + audited_partition="aws-iso", + audited_identity_arn=None, + audited_user_id=None, + profile=None, + profile_region=None, + credentials=None, + assumed_role_info=None, + audited_regions=None, + organizations_metadata=None, + audit_resources=None, + mfa_enabled=False, + ) + assert get_global_region(audit_info) == "aws-iso-global" + + def test_get_available_aws_service_regions_with_us_east_1_audited(self): + audited_regions = ["us-east-1"] + audit_info = AWS_Audit_Info( + session_config=None, + original_session=None, + audit_session=None, + audited_account=None, + audited_account_arn=None, + audited_partition="aws", + audited_identity_arn=None, + audited_user_id=None, + profile=None, + profile_region=None, + credentials=None, + assumed_role_info=None, + audited_regions=audited_regions, + organizations_metadata=None, + audit_resources=None, + mfa_enabled=False, + ) + with patch( + "prowler.providers.aws.aws_provider.parse_json_file", + return_value={ + "services": { + "ec2": { + "regions": { + "aws": [ + "af-south-1", + "ca-central-1", + "eu-central-1", + "eu-central-2", + "eu-north-1", + "eu-south-1", + "eu-south-2", + "eu-west-1", + "eu-west-2", + "eu-west-3", + "me-central-1", + "me-south-1", + "sa-east-1", + "us-east-1", + "us-east-2", + "us-west-1", + "us-west-2", + ], + } + } + } + }, + ): + assert get_available_aws_service_regions("ec2", audit_info) == ["us-east-1"] + + def test_get_available_aws_service_regions_with_all_regions_audited(self): + audit_info = AWS_Audit_Info( + session_config=None, + original_session=None, + audit_session=None, + audited_account=None, + audited_account_arn=None, + audited_partition="aws", + audited_identity_arn=None, + audited_user_id=None, + profile=None, + profile_region=None, + credentials=None, + assumed_role_info=None, + audited_regions=None, + organizations_metadata=None, + audit_resources=None, + mfa_enabled=False, + ) + with patch( + "prowler.providers.aws.aws_provider.parse_json_file", + return_value={ + "services": { + "ec2": { + "regions": { + "aws": [ + "af-south-1", + "ca-central-1", + "eu-central-1", + "eu-central-2", + "eu-north-1", + "eu-south-1", + "eu-south-2", + "eu-west-1", + "eu-west-2", + "eu-west-3", + "me-central-1", + "me-south-1", + "sa-east-1", + "us-east-1", + "us-east-2", + "us-west-1", + "us-west-2", + ], + } + } + } + }, + ): + assert len(get_available_aws_service_regions("ec2", audit_info)) == 17 diff --git a/tests/providers/aws/services/resourceexplorer2/resourceexplorer2_service_test.py b/tests/providers/aws/services/resourceexplorer2/resourceexplorer2_service_test.py index a77e93bf..0c9392ef 100644 --- a/tests/providers/aws/services/resourceexplorer2/resourceexplorer2_service_test.py +++ b/tests/providers/aws/services/resourceexplorer2/resourceexplorer2_service_test.py @@ -60,7 +60,7 @@ class Test_ResourceExplorer2_Service: profile_region=None, credentials=None, assumed_role_info=None, - audited_regions="us-east-1", + audited_regions=["us-east-1"], organizations_metadata=None, audit_resources=None, mfa_enabled=False,