diff --git a/prowler/providers/aws/aws_provider.py b/prowler/providers/aws/aws_provider.py index 6aae6a42..9d034b58 100644 --- a/prowler/providers/aws/aws_provider.py +++ b/prowler/providers/aws/aws_provider.py @@ -256,16 +256,14 @@ def get_checks_from_input_arn(audit_resources: list, provider: str) -> set: return sorted(checks_from_arn) -def get_regions_from_audit_resources(audit_resources: list) -> list: +def get_regions_from_audit_resources(audit_resources: list) -> set: """get_regions_from_audit_resources gets the regions from the audit resources arns""" - audited_regions = [] + audited_regions = set() for resource in audit_resources: region = resource.split(":")[3] - if region and region not in audited_regions: # Check if arn has a region - audited_regions.append(region) - if audited_regions: - return audited_regions - return None + if region: + audited_regions.add(region) + return audited_regions def get_available_aws_service_regions(service: str, audit_info: AWS_Audit_Info) -> list: diff --git a/tests/lib/check/check_test.py b/tests/lib/check/check_test.py index 64a80379..d19f1d02 100644 --- a/tests/lib/check/check_test.py +++ b/tests/lib/check/check_test.py @@ -646,7 +646,7 @@ class Test_Check: recovered_checks = get_checks_from_input_arn(audit_resources, provider) assert recovered_checks == expected_checks - def test_get_regions_from_audit_resources(self): + def test_get_regions_from_audit_resources_with_regions(self): audit_resources = [ f"arn:aws:lambda:us-east-1:{AWS_ACCOUNT_NUMBER}:function:test-lambda", f"arn:aws:iam::{AWS_ACCOUNT_NUMBER}:policy/test", @@ -654,10 +654,15 @@ class Test_Check: "arn:aws:s3:::bucket-name", "arn:aws:apigateway:us-east-2::/restapis/api-id/stages/stage-name", ] - expected_regions = ["us-east-1", "eu-west-1", "us-east-2"] + expected_regions = {"us-east-1", "eu-west-1", "us-east-2"} recovered_regions = get_regions_from_audit_resources(audit_resources) assert recovered_regions == expected_regions + def test_get_regions_from_audit_resources_without_regions(self): + audit_resources = ["arn:aws:s3:::bucket-name"] + recovered_regions = get_regions_from_audit_resources(audit_resources) + assert not recovered_regions + # def test_parse_checks_from_compliance_framework_two(self): # test_case = { # "input": {"compliance_frameworks": ["cis_v1.4_aws", "ens_v3_aws"]},