diff --git a/prowler/providers/aws/services/cloudfront/cloudfront_service.py b/prowler/providers/aws/services/cloudfront/cloudfront_service.py index 667c693a..bbab5379 100644 --- a/prowler/providers/aws/services/cloudfront/cloudfront_service.py +++ b/prowler/providers/aws/services/cloudfront/cloudfront_service.py @@ -55,13 +55,11 @@ class CloudFront(AWSService): ]["Logging"]["Enabled"] distributions[ distribution_id - ].geo_restriction_type = distribution_config["DistributionConfig"][ - "Restrictions" - ][ - "GeoRestriction" - ][ - "RestrictionType" - ] + ].geo_restriction_type = GeoRestrictionType( + distribution_config["DistributionConfig"]["Restrictions"][ + "GeoRestriction" + ]["RestrictionType"] + ) distributions[distribution_id].web_acl_id = distribution_config[ "DistributionConfig" ]["WebACLId"] @@ -71,9 +69,11 @@ class CloudFront(AWSService): realtime_log_config_arn=distribution_config["DistributionConfig"][ "DefaultCacheBehavior" ].get("RealtimeLogConfigArn"), - viewer_protocol_policy=distribution_config["DistributionConfig"][ - "DefaultCacheBehavior" - ].get("ViewerProtocolPolicy"), + viewer_protocol_policy=ViewerProtocolPolicy( + distribution_config["DistributionConfig"][ + "DefaultCacheBehavior" + ].get("ViewerProtocolPolicy") + ), field_level_encryption_id=distribution_config["DistributionConfig"][ "DefaultCacheBehavior" ].get("FieldLevelEncryptionId"), @@ -131,7 +131,7 @@ class DefaultCacheConfigBehaviour(BaseModel): class Distribution(BaseModel): - """Distribution holds a CloudFront Distribution with the required information to run the rela""" + """Distribution holds a CloudFront Distribution resource""" arn: str id: str diff --git a/tests/providers/aws/services/cloudfront/cloudfront_service_test.py b/tests/providers/aws/services/cloudfront/cloudfront_service_test.py index aaeee822..725137c1 100644 --- a/tests/providers/aws/services/cloudfront/cloudfront_service_test.py +++ b/tests/providers/aws/services/cloudfront/cloudfront_service_test.py @@ -8,6 +8,7 @@ from moto.core import DEFAULT_ACCOUNT_ID from prowler.providers.aws.lib.audit_info.models import AWS_Audit_Info from prowler.providers.aws.services.cloudfront.cloudfront_service import ( CloudFront, + GeoRestrictionType, ViewerProtocolPolicy, ) from prowler.providers.common.models import Audit_Metadata @@ -243,7 +244,7 @@ class Test_CloudFront_Service: ) assert ( cloudfront.distributions[cloudfront_distribution_id].geo_restriction_type - == "blacklist" + == GeoRestrictionType.blacklist ) assert ( cloudfront.distributions[cloudfront_distribution_id].web_acl_id