feat(shield): Service and checks (#1504)

This commit is contained in:
Sergio Garcia
2022-11-21 10:18:54 +01:00
committed by GitHub
parent 1370e0dec4
commit 52a3e990c6
49 changed files with 3474 additions and 1663 deletions

2446
Pipfile.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -100,9 +100,6 @@ def print_services(service_list: set):
def print_checks(provider: str, check_list: set, bulk_checks_metadata: dict):
print(
f"There are {Fore.YELLOW}{len(check_list)}{Style.RESET_ALL} available checks: \n"
)
for check in check_list:
try:
print(
@@ -112,6 +109,9 @@ def print_checks(provider: str, check_list: set, bulk_checks_metadata: dict):
logger.error(
f"Check {error} was not found for the {provider.upper()} provider"
)
print(
f"\nThere are {Fore.YELLOW}{len(check_list)}{Style.RESET_ALL} available checks.\n"
)
# List available groups

View File

@@ -1,211 +1,211 @@
# from datetime import datetime
# from unittest import mock
from unittest import mock
# from boto3 import session
# from moto.core import DEFAULT_ACCOUNT_ID
from boto3 import client, session
from mock import patch
from moto import mock_cloudtrail, mock_s3
from moto.core import DEFAULT_ACCOUNT_ID
# from providers.aws.lib.audit_info.audit_info import AWS_Audit_Info
# from providers.aws.services.awslambda.awslambda_service import Function
# from providers.aws.services.cloudtrail.cloudtrail_service import Trail
from providers.aws.lib.audit_info.audit_info import AWS_Audit_Info
from providers.aws.services.awslambda.awslambda_service import Function
# AWS_REGION = "us-east-1"
AWS_REGION = "us-east-1"
# class Test_awslambda_function_invoke_api_operations_cloudtrail_logging_enabled:
# # Mocked Audit Info
# def set_mocked_audit_info(self):
# audit_info = AWS_Audit_Info(
# original_session=None,
# audit_session=session.Session(
# profile_name=None,
# botocore_session=None,
# ),
# audited_account=None,
# audited_user_id=None,
# audited_partition="aws",
# audited_identity_arn=None,
# profile=None,
# profile_region=None,
# credentials=None,
# assumed_role_info=None,
# audited_regions=None,
# organizations_metadata=None,
# )
# return audit_info
# Mock generate_regional_clients()
def mock_generate_regional_clients(service, audit_info):
regional_client = audit_info.audit_session.client(service, region_name=AWS_REGION)
regional_client.region = AWS_REGION
return {AWS_REGION: regional_client}
# def test_no_functions(self):
# lambda_client = mock.MagicMock
# lambda_client.functions = {}
# cloudtrail_client = mock.MagicMock
# cloudtrail_client.trails = []
# with mock.patch(
# "providers.aws.services.awslambda.awslambda_service.Lambda",
# new=lambda_client,
# ), mock.patch(
# "providers.aws.lib.audit_info.audit_info.current_audit_info",
# self.set_mocked_audit_info(),
# ), mock.patch(
# "providers.aws.services.cloudtrail.cloudtrail_service.Cloudtrail",
# new=cloudtrail_client,
# ):
# # Test Check
# from providers.aws.services.awslambda.awslambda_function_invoke_api_operations_cloudtrail_logging_enabled.awslambda_function_invoke_api_operations_cloudtrail_logging_enabled import (
# awslambda_function_invoke_api_operations_cloudtrail_logging_enabled,
# )
# Patch every AWS call using Boto3 and generate_regional_clients to have 1 client
@patch(
"providers.aws.services.accessanalyzer.accessanalyzer_service.generate_regional_clients",
new=mock_generate_regional_clients,
)
class Test_awslambda_function_invoke_api_operations_cloudtrail_logging_enabled:
# Mocked Audit Info
def set_mocked_audit_info(self):
audit_info = AWS_Audit_Info(
original_session=None,
audit_session=session.Session(
profile_name=None,
botocore_session=None,
),
audited_account=None,
audited_user_id=None,
audited_partition="aws",
audited_identity_arn=None,
profile=None,
profile_region=None,
credentials=None,
assumed_role_info=None,
audited_regions=None,
organizations_metadata=None,
)
return audit_info
# check = (
# awslambda_function_invoke_api_operations_cloudtrail_logging_enabled()
# )
# result = check.execute()
@mock_cloudtrail
def test_no_functions(self):
lambda_client = mock.MagicMock
lambda_client.functions = {}
# assert len(result) == 0
from providers.aws.services.cloudtrail.cloudtrail_service import Cloudtrail
# def test_lambda_not_recorded_by_cloudtrail(self):
# # Lambda Client
# lambda_client = mock.MagicMock
# function_name = "test-lambda"
# function_runtime = "python3.9"
# function_arn = (
# f"arn:aws:lambda:{AWS_REGION}:{DEFAULT_ACCOUNT_ID}:function/{function_name}"
# )
# lambda_client.functions = {
# function_name: Function(
# name=function_name,
# arn=function_arn,
# region=AWS_REGION,
# runtime=function_runtime,
# )
# }
# # CloudTrail Client
# cloudtrail_client = mock.MagicMock
# cloudtrail_client.trails = [
# Trail(
# name="test-trail",
# is_multiregion=False,
# home_region=AWS_REGION,
# arn="",
# region=AWS_REGION,
# is_logging=True,
# log_file_validation_enabled=True,
# latest_cloudwatch_delivery_time=datetime(2022, 1, 1),
# s3_bucket="",
# kms_key="",
# log_group_arn="",
# data_events=[
# {
# "ReadWriteType": "All",
# "IncludeManagementEvents": True,
# "DataResources": [],
# "ExcludeManagementEventSources": [],
# }
# ],
# )
# ]
with mock.patch(
"providers.aws.services.awslambda.awslambda_service.Lambda",
new=lambda_client,
), mock.patch(
"providers.aws.lib.audit_info.audit_info.current_audit_info",
self.set_mocked_audit_info(),
), mock.patch(
"providers.aws.services.awslambda.awslambda_function_invoke_api_operations_cloudtrail_logging_enabled.awslambda_function_invoke_api_operations_cloudtrail_logging_enabled.cloudtrail_client",
new=Cloudtrail(self.set_mocked_audit_info()),
):
# Test Check
from providers.aws.services.awslambda.awslambda_function_invoke_api_operations_cloudtrail_logging_enabled.awslambda_function_invoke_api_operations_cloudtrail_logging_enabled import (
awslambda_function_invoke_api_operations_cloudtrail_logging_enabled,
)
# with mock.patch(
# "providers.aws.services.awslambda.awslambda_service.Lambda",
# new=lambda_client,
# ), mock.patch(
# "providers.aws.services.cloudtrail.cloudtrail_service.Cloudtrail",
# new=cloudtrail_client,
# ):
check = (
awslambda_function_invoke_api_operations_cloudtrail_logging_enabled()
)
result = check.execute()
# # Test Check
# from providers.aws.services.awslambda.awslambda_function_invoke_api_operations_cloudtrail_logging_enabled.awslambda_function_invoke_api_operations_cloudtrail_logging_enabled import (
# awslambda_function_invoke_api_operations_cloudtrail_logging_enabled,
# )
assert len(result) == 0
# check = (
# awslambda_function_invoke_api_operations_cloudtrail_logging_enabled()
# )
# result = check.execute()
@mock_cloudtrail
@mock_s3
def test_lambda_not_recorded_by_cloudtrail(self):
# Lambda Client
lambda_client = mock.MagicMock
function_name = "test-lambda"
function_runtime = "python3.9"
function_arn = (
f"arn:aws:lambda:{AWS_REGION}:{DEFAULT_ACCOUNT_ID}:function/{function_name}"
)
lambda_client.functions = {
function_name: Function(
name=function_name,
arn=function_arn,
region=AWS_REGION,
runtime=function_runtime,
)
}
# assert len(result) == 1
# assert result[0].region == AWS_REGION
# assert result[0].resource_id == function_name
# assert result[0].resource_arn == function_arn
# assert result[0].status == "FAIL"
# assert (
# result[0].status_extended
# == f"Lambda function {function_name} is not recorded by CloudTrail"
# )
# CloudTrail Client
cloudtrail_client = client("cloudtrail", region_name=AWS_REGION)
s3_client = client("s3", region_name=AWS_REGION)
trail_name = "test-trail"
bucket_name = "test-bucket"
s3_client.create_bucket(Bucket=bucket_name)
cloudtrail_client.create_trail(
Name=trail_name, S3BucketName=bucket_name, IsMultiRegionTrail=False
)
# def test_lambda_recorded_by_cloudtrail(self):
# # Lambda Client
# lambda_client = mock.MagicMock
# function_name = "test-lambda"
# function_runtime = "python3.9"
# function_arn = (
# f"arn:aws:lambda:{AWS_REGION}:{DEFAULT_ACCOUNT_ID}:function/{function_name}"
# )
# lambda_client.functions = {
# function_name: Function(
# name=function_name,
# arn=function_arn,
# region=AWS_REGION,
# runtime=function_runtime,
# )
# }
# # CloudTrail Client
# cloudtrail_client = mock.MagicMock
# trail_name = "test-trail"
# cloudtrail_client.trails = [
# Trail(
# name=trail_name,
# is_multiregion=False,
# home_region=AWS_REGION,
# arn="",
# region=AWS_REGION,
# is_logging=True,
# log_file_validation_enabled=True,
# latest_cloudwatch_delivery_time=datetime(2022, 1, 1),
# s3_bucket="",
# kms_key="",
# log_group_arn="",
# data_events=[
# {
# "ReadWriteType": "All",
# "IncludeManagementEvents": True,
# "DataResources": [
# {
# "Type": "AWS::Lambda::Function",
# "Values": [
# function_arn,
# ],
# },
# ],
# "ExcludeManagementEventSources": [],
# }
# ],
# )
# ]
from providers.aws.services.cloudtrail.cloudtrail_service import Cloudtrail
# with mock.patch(
# "providers.aws.services.awslambda.awslambda_service.Lambda",
# new=lambda_client,
# ), mock.patch(
# "providers.aws.services.cloudtrail.cloudtrail_service.Cloudtrail",
# new=cloudtrail_client,
# ):
with mock.patch(
"providers.aws.services.awslambda.awslambda_service.Lambda",
new=lambda_client,
), mock.patch(
"providers.aws.lib.audit_info.audit_info.current_audit_info",
self.set_mocked_audit_info(),
), mock.patch(
"providers.aws.services.awslambda.awslambda_function_invoke_api_operations_cloudtrail_logging_enabled.awslambda_function_invoke_api_operations_cloudtrail_logging_enabled.cloudtrail_client",
new=Cloudtrail(self.set_mocked_audit_info()),
):
# #
# # Test Check
# from providers.aws.services.awslambda.awslambda_function_invoke_api_operations_cloudtrail_logging_enabled.awslambda_function_invoke_api_operations_cloudtrail_logging_enabled import (
# awslambda_function_invoke_api_operations_cloudtrail_logging_enabled,
# )
# Test Check
from providers.aws.services.awslambda.awslambda_function_invoke_api_operations_cloudtrail_logging_enabled.awslambda_function_invoke_api_operations_cloudtrail_logging_enabled import (
awslambda_function_invoke_api_operations_cloudtrail_logging_enabled,
)
# check = (
# awslambda_function_invoke_api_operations_cloudtrail_logging_enabled()
# )
# result = check.execute()
check = (
awslambda_function_invoke_api_operations_cloudtrail_logging_enabled()
)
result = check.execute()
# assert len(result) == 1
# assert result[0].region == AWS_REGION
# assert result[0].resource_id == function_name
# assert result[0].resource_arn == function_arn
# assert result[0].status == "PASS"
# assert (
# result[0].status_extended
# == f"Lambda function {function_name} is recorded by CloudTrail {trail_name}"
# )
assert len(result) == 1
assert result[0].region == AWS_REGION
assert result[0].resource_id == function_name
assert result[0].resource_arn == function_arn
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== f"Lambda function {function_name} is not recorded by CloudTrail"
)
@mock_cloudtrail
@mock_s3
def test_lambda_recorded_by_cloudtrail(self):
# Lambda Client
lambda_client = mock.MagicMock
function_name = "test-lambda"
function_runtime = "python3.9"
function_arn = (
f"arn:aws:lambda:{AWS_REGION}:{DEFAULT_ACCOUNT_ID}:function/{function_name}"
)
lambda_client.functions = {
function_name: Function(
name=function_name,
arn=function_arn,
region=AWS_REGION,
runtime=function_runtime,
)
}
# CloudTrail Client
cloudtrail_client = client("cloudtrail", region_name=AWS_REGION)
s3_client = client("s3", region_name=AWS_REGION)
trail_name = "test-trail"
bucket_name = "test-bucket"
s3_client.create_bucket(Bucket=bucket_name)
cloudtrail_client.create_trail(
Name=trail_name, S3BucketName=bucket_name, IsMultiRegionTrail=False
)
_ = cloudtrail_client.put_event_selectors(
TrailName=trail_name,
EventSelectors=[
{
"ReadWriteType": "All",
"IncludeManagementEvents": True,
"DataResources": [
{"Type": "AWS::Lambda::Function", "Values": [function_arn]}
],
}
],
)
from providers.aws.services.cloudtrail.cloudtrail_service import Cloudtrail
with mock.patch(
"providers.aws.services.awslambda.awslambda_service.Lambda",
new=lambda_client,
), mock.patch(
"providers.aws.lib.audit_info.audit_info.current_audit_info",
self.set_mocked_audit_info(),
), mock.patch(
"providers.aws.services.awslambda.awslambda_function_invoke_api_operations_cloudtrail_logging_enabled.awslambda_function_invoke_api_operations_cloudtrail_logging_enabled.cloudtrail_client",
new=Cloudtrail(self.set_mocked_audit_info()),
):
# Test Check
from providers.aws.services.awslambda.awslambda_function_invoke_api_operations_cloudtrail_logging_enabled.awslambda_function_invoke_api_operations_cloudtrail_logging_enabled import (
awslambda_function_invoke_api_operations_cloudtrail_logging_enabled,
)
check = (
awslambda_function_invoke_api_operations_cloudtrail_logging_enabled()
)
result = check.execute()
assert len(result) == 1
assert result[0].region == AWS_REGION
assert result[0].resource_id == function_name
assert result[0].resource_arn == function_arn
assert result[0].status == "PASS"
assert (
result[0].status_extended
== f"Lambda function {function_name} is recorded by CloudTrail {trail_name}"
)

View File

@@ -10,6 +10,7 @@ class EC2:
def __init__(self, audit_info):
self.service = "ec2"
self.session = audit_info.audit_session
self.audited_partition = audit_info.audited_partition
self.audited_account = audit_info.audited_account
self.regional_clients = generate_regional_clients(self.service, audit_info)
self.instances = []
@@ -264,11 +265,14 @@ class EC2:
association_id = address["AssociationId"]
if "AllocationId" in address:
allocation_id = address["AllocationId"]
elastic_ip_arn = f"arn:{self.audited_partition}:ec2:{regional_client.region}:{self.audited_account}:eip-allocation/{allocation_id}"
self.elastic_ips.append(
ElasticIP(
public_ip,
association_id,
allocation_id,
elastic_ip_arn,
regional_client.region,
)
)
@@ -403,13 +407,15 @@ class NetworkACL:
class ElasticIP:
public_ip: str
association_id: str
arn: str
allocation_id: str
region: str
def __init__(self, public_ip, association_id, allocation_id, region):
def __init__(self, public_ip, association_id, allocation_id, arn, region):
self.public_ip = public_ip
self.association_id = association_id
self.allocation_id = allocation_id
self.arn = arn
self.region = region

View File

@@ -205,8 +205,14 @@ class Test_EC2_Service:
def test__describe_addresses__(self):
# Generate EC2 Client
ec2_client = client("ec2", region_name=AWS_REGION)
ec2_client.allocate_address(Domain="vpc", Address="127.38.43.222")
allocation_id = ec2_client.allocate_address(
Domain="vpc", Address="127.38.43.222"
)["AllocationId"]
# EC2 client for this test class
audit_info = self.set_mocked_audit_info()
ec2 = EC2(audit_info)
assert "127.38.43.222" in str(ec2.elastic_ips)
assert (
ec2.elastic_ips[0].arn
== f"arn:aws:ec2:{AWS_REGION}:{AWS_ACCOUNT_NUMBER}:eip-allocation/{allocation_id}"
)

View File

@@ -12,6 +12,8 @@ class ELB:
def __init__(self, audit_info):
self.service = "elb"
self.session = audit_info.audit_session
self.audited_partition = audit_info.audited_partition
self.audited_account = audit_info.audited_account
self.regional_clients = generate_regional_clients(self.service, audit_info)
self.loadbalancers = []
self.__threading_call__(self.__describe_load_balancers__)
@@ -48,6 +50,7 @@ class ELB:
self.loadbalancers.append(
LoadBalancer(
name=elb["LoadBalancerName"],
arn=f"arn:{self.audited_partition}:elasticloadbalancing:{regional_client.region}:{self.audited_account}:loadbalancer/{elb['LoadBalancerName']}",
dns=elb["DNSName"],
region=regional_client.region,
scheme=elb["Scheme"],
@@ -85,6 +88,7 @@ class Listener(BaseModel):
class LoadBalancer(BaseModel):
name: str
dns: str
arn: str
region: str
scheme: str
access_logs: Optional[bool]

View File

@@ -83,6 +83,10 @@ class Test_ELB_Service:
assert elb.loadbalancers[0].name == "my-lb"
assert elb.loadbalancers[0].region == AWS_REGION
assert elb.loadbalancers[0].scheme == "internal"
assert (
elb.loadbalancers[0].arn
== f"arn:aws:elasticloadbalancing:{AWS_REGION}:{AWS_ACCOUNT_NUMBER}:loadbalancer/my-lb"
)
# Test ELB Describe Load Balancers Attributes
@mock_ec2
@@ -124,3 +128,7 @@ class Test_ELB_Service:
assert elb.loadbalancers[0].region == AWS_REGION
assert elb.loadbalancers[0].scheme == "internal"
assert elb.loadbalancers[0].access_logs
assert (
elb.loadbalancers[0].arn
== f"arn:aws:elasticloadbalancing:{AWS_REGION}:{AWS_ACCOUNT_NUMBER}:loadbalancer/my-lb"
)

View File

@@ -0,0 +1,6 @@
from providers.aws.lib.audit_info.audit_info import current_audit_info
from providers.aws.services.globalaccelerator.globalaccelerator_service import (
GlobalAccelerator,
)
globalaccelerator_client = GlobalAccelerator(current_audit_info)

View File

@@ -0,0 +1,46 @@
from pydantic import BaseModel
from lib.logger import logger
from providers.aws.aws_provider import get_region_global_service
################### GlobalAccelerator
class GlobalAccelerator:
def __init__(self, audit_info):
self.service = "globalaccelerator"
self.session = audit_info.audit_session
self.audited_account = audit_info.audited_account
self.region = get_region_global_service(audit_info)
self.client = self.session.client(self.service, self.region)
self.accelerators = {}
self.__list_accelerators__()
def __get_session__(self):
return self.session
def __list_accelerators__(self):
logger.info("GlobalAccelerator - Listing Accelerators...")
try:
list_accelerators_paginator = self.client.get_paginator("list_accelerators")
for page in list_accelerators_paginator.paginate():
for accelerator in page["Accelerators"]:
accelerator_arn = accelerator["AcceleratorArn"]
accelerator_name = accelerator["Name"]
enabled = accelerator["Enabled"]
self.accelerators[accelerator_name] = Accelerator(
name=accelerator_name,
arn=accelerator_arn,
region=self.region,
enabled=enabled,
)
except Exception as error:
logger.error(
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
class Accelerator(BaseModel):
arn: str
name: str
region: str
enabled: bool

View File

@@ -0,0 +1,105 @@
from providers.aws.lib.audit_info.models import AWS_Audit_Info
from providers.aws.services.globalaccelerator.globalaccelerator_service import (
GlobalAccelerator,
)
from mock import patch
from moto.core import DEFAULT_ACCOUNT_ID
import botocore
from boto3 import session
# Mock Test Region
AWS_REGION = "eu-west-1"
# Mocking Access Analyzer Calls
make_api_call = botocore.client.BaseClient._make_api_call
def mock_make_api_call(self, operation_name, kwarg):
"""We have to mock every AWS API call using Boto3"""
if operation_name == "ListAccelerators":
return {
"Accelerators": [
{
"AcceleratorArn": f"arn:aws:globalaccelerator::{DEFAULT_ACCOUNT_ID}:accelerator/5555abcd-abcd-5555-abcd-5555EXAMPLE1",
"Name": "TestAccelerator",
"IpAddressType": "IPV4",
"Enabled": True,
"IpSets": [
{
"IpFamily": "IPv4",
"IpAddresses": ["192.0.2.250", "198.51.100.52"],
}
],
"DnsName": "5a5a5a5a5a5a5a5a.awsglobalaccelerator.com",
"Status": "DEPLOYED",
"CreatedTime": 1552424416.0,
"LastModifiedTime": 1569375641.0,
}
]
}
if operation_name == "GetSubscriptionState":
return {"SubscriptionState": "ACTIVE"}
return make_api_call(self, operation_name, kwarg)
# Patch every AWS call using Boto3 and generate_regional_clients to have 1 client
@patch("botocore.client.BaseClient._make_api_call", new=mock_make_api_call)
class Test_GlobalAccelerator_Service:
# Mocked Audit Info
def set_mocked_audit_info(self):
audit_info = AWS_Audit_Info(
original_session=None,
audit_session=session.Session(
profile_name=None,
botocore_session=None,
),
audited_account=DEFAULT_ACCOUNT_ID,
audited_user_id=None,
audited_partition="aws",
audited_identity_arn=None,
profile=None,
profile_region=AWS_REGION,
credentials=None,
assumed_role_info=None,
audited_regions=None,
organizations_metadata=None,
)
return audit_info
# Test GlobalAccelerator Service
def test_service(self):
# GlobalAccelerator client for this test class
audit_info = self.set_mocked_audit_info()
globalaccelerator = GlobalAccelerator(audit_info)
assert globalaccelerator.service == "globalaccelerator"
# Test GlobalAccelerator Client
def test_client(self):
# GlobalAccelerator client for this test class
audit_info = self.set_mocked_audit_info()
globalaccelerator = GlobalAccelerator(audit_info)
assert globalaccelerator.client.__class__.__name__ == "GlobalAccelerator"
# Test GlobalAccelerator Session
def test__get_session__(self):
# GlobalAccelerator client for this test class
audit_info = self.set_mocked_audit_info()
globalaccelerator = GlobalAccelerator(audit_info)
assert globalaccelerator.session.__class__.__name__ == "Session"
def test__list_accelerators__(self):
# GlobalAccelerator client for this test class
audit_info = self.set_mocked_audit_info()
globalaccelerator = GlobalAccelerator(audit_info)
accelerator_arn = f"arn:aws:globalaccelerator::{DEFAULT_ACCOUNT_ID}:accelerator/5555abcd-abcd-5555-abcd-5555EXAMPLE1"
accelerator_name = "TestAccelerator"
assert globalaccelerator.accelerators
assert len(globalaccelerator.accelerators) == 1
assert globalaccelerator.accelerators[accelerator_name]
assert globalaccelerator.accelerators[accelerator_name].name == accelerator_name
assert globalaccelerator.accelerators[accelerator_name].arn == accelerator_arn
assert globalaccelerator.accelerators[accelerator_name].region == AWS_REGION
assert globalaccelerator.accelerators[accelerator_name].enabled

View File

@@ -22,7 +22,7 @@ class kms_key_not_publicly_accessible(Check):
report.status_extended = (
f"KMS key {key.id} may be publicly accessible!"
)
else:
elif "AWS" in statement["Principal"]:
if type(statement["Principal"]["AWS"]) == str:
principals = [statement["Principal"]["AWS"]]
else:

View File

@@ -37,6 +37,7 @@ class Test_route53_public_hosted_zones_cloudwatch_logging_enabled:
route53.hosted_zones = {
hosted_zone_name: HostedZone(
name=hosted_zone_name,
arn=f"arn:aws:route53:::{hosted_zone_id}",
id=hosted_zone_id,
private_zone=False,
region=AWS_REGION,
@@ -72,6 +73,7 @@ class Test_route53_public_hosted_zones_cloudwatch_logging_enabled:
route53.hosted_zones = {
hosted_zone_name: HostedZone(
name=hosted_zone_name,
arn=f"arn:aws:route53:::{hosted_zone_id}",
id=hosted_zone_id,
private_zone=False,
region=AWS_REGION,
@@ -106,6 +108,7 @@ class Test_route53_public_hosted_zones_cloudwatch_logging_enabled:
route53.hosted_zones = {
hosted_zone_name: HostedZone(
name=hosted_zone_name,
arn=f"arn:aws:route53:::{hosted_zone_id}",
id=hosted_zone_id,
private_zone=True,
region=AWS_REGION,

View File

@@ -9,6 +9,7 @@ class Route53:
def __init__(self, audit_info):
self.service = "route53"
self.session = audit_info.audit_session
self.audited_partition = audit_info.audited_partition
self.client = self.session.client(self.service)
self.region = get_region_global_service(audit_info)
self.hosted_zones = {}
@@ -32,6 +33,7 @@ class Route53:
id=hosted_zone_id,
name=hosted_zone_name,
private_zone=private_zone,
arn=f"arn:{self.audited_partition}:route53:::{hosted_zone_id}",
region=self.region,
)
@@ -69,6 +71,7 @@ class LoggingConfig(BaseModel):
class HostedZone(BaseModel):
id: str
arn: str
name: str
private_zone: bool
logging_config: LoggingConfig = None

View File

@@ -93,6 +93,10 @@ class Test_Route53_Service:
assert len(route53.hosted_zones) == 1
assert route53.hosted_zones[hosted_zone_id]
assert route53.hosted_zones[hosted_zone_id].id == hosted_zone_id
assert (
route53.hosted_zones[hosted_zone_id].arn
== f"arn:aws:route53:::{hosted_zone_id}"
)
assert route53.hosted_zones[hosted_zone_id].name == hosted_zone_name
assert route53.hosted_zones[hosted_zone_id].private_zone
assert route53.hosted_zones[hosted_zone_id].logging_config
@@ -131,6 +135,10 @@ class Test_Route53_Service:
assert len(route53.hosted_zones) == 1
assert route53.hosted_zones[hosted_zone_id]
assert route53.hosted_zones[hosted_zone_id].id == hosted_zone_id
assert (
route53.hosted_zones[hosted_zone_id].arn
== f"arn:aws:route53:::{hosted_zone_id}"
)
assert route53.hosted_zones[hosted_zone_id].name == hosted_zone_name
assert not route53.hosted_zones[hosted_zone_id].private_zone
assert route53.hosted_zones[hosted_zone_id].logging_config
@@ -159,6 +167,10 @@ class Test_Route53_Service:
assert len(route53.hosted_zones) == 1
assert route53.hosted_zones[hosted_zone_id]
assert route53.hosted_zones[hosted_zone_id].id == hosted_zone_id
assert (
route53.hosted_zones[hosted_zone_id].arn
== f"arn:aws:route53:::{hosted_zone_id}"
)
assert route53.hosted_zones[hosted_zone_id].name == hosted_zone_name
assert route53.hosted_zones[hosted_zone_id].private_zone
assert not route53.hosted_zones[hosted_zone_id].logging_config
@@ -183,6 +195,10 @@ class Test_Route53_Service:
assert len(route53.hosted_zones) == 1
assert route53.hosted_zones[hosted_zone_id]
assert route53.hosted_zones[hosted_zone_id].id == hosted_zone_id
assert (
route53.hosted_zones[hosted_zone_id].arn
== f"arn:aws:route53:::{hosted_zone_id}"
)
assert route53.hosted_zones[hosted_zone_id].name == hosted_zone_name
assert not route53.hosted_zones[hosted_zone_id].private_zone
assert not route53.hosted_zones[hosted_zone_id].logging_config

View File

@@ -182,9 +182,9 @@ class S3Control:
def __init__(self, audit_info):
self.service = "s3control"
self.session = audit_info.audit_session
self.client = self.session.client(self.service)
self.audited_account = audit_info.audited_account
self.region = get_region_global_service(audit_info)
self.client = self.session.client(self.service, self.region)
self.account_public_access_block = self.__get_public_access_block__()
def __get_session__(self):

View File

@@ -1,50 +0,0 @@
#!/usr/bin/env bash
# Prowler - the handy cloud security tool (copyright 2019) by Toni de la Fuente
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
CHECK_ID_extra7166="7.166"
CHECK_TITLE_extra7166="[extra7166] Check if Elastic IP addresses with associations are protected by AWS Shield Advanced"
CHECK_SCORED_extra7166="NOT_SCORED"
CHECK_CIS_LEVEL_extra7166="EXTRA"
CHECK_SEVERITY_extra7166="Medium"
CHECK_ASFF_RESOURCE_TYPE_extra7166="AwsEc2Eip"
CHECK_ALTERNATE_check7166="extra7166"
CHECK_SERVICENAME_extra7166="shield"
CHECK_RISK_extra7166='AWS Shield Advanced provides expanded DDoS attack protection for your resources'
CHECK_REMEDIATION_extra7166='Add as a protected resource in AWS Shield Advanced.'
CHECK_DOC_extra7166='https://docs.aws.amazon.com/waf/latest/developerguide/configure-new-protection.html'
CHECK_CAF_EPIC_extra7166='Infrastructure security'
extra7166() {
if [[ "$($AWSCLI $PROFILE_OPT shield get-subscription-state --output text)" == "ACTIVE" ]]; then
CALLER_IDENTITY=$($AWSCLI sts get-caller-identity $PROFILE_OPT --query Arn)
PARTITION=$(echo $CALLER_IDENTITY | cut -d: -f2)
ACCOUNT_ID=$(echo $CALLER_IDENTITY | cut -d: -f5)
for regx in $REGIONS; do
LIST_OF_ELASTIC_IPS_WITH_ASSOCIATIONS=$($AWSCLI ec2 describe-addresses $PROFILE_OPT --region $regx --query 'Addresses[?AssociationId].AllocationId' --output text)
if [[ $LIST_OF_ELASTIC_IPS_WITH_ASSOCIATIONS ]]; then
for elastic_ip in $LIST_OF_ELASTIC_IPS_WITH_ASSOCIATIONS; do
EIP_ARN="arn:${PARTITION}:ec2:${regx}:${ACCOUNT_ID}:eip-allocation/${elastic_ip}"
if $AWSCLI $PROFILE_OPT shield describe-protection --resource-arn $EIP_ARN >/dev/null 2>&1; then
textPass "$regx: EIP $elastic_ip is protected by AWS Shield Advanced" "$regx" "$elastic_ip"
else
textFail "$regx: EIP $elastic_ip is not protected by AWS Shield Advanced" "$regx" "$elastic_ip"
fi
done
else
textInfo "$regx: no elastic IP addresses with assocations found" "$regx"
fi
done
else
textInfo "No AWS Shield Advanced subscription found. Skipping check."
fi
}

View File

@@ -1,46 +0,0 @@
#!/usr/bin/env bash
# Prowler - the handy cloud security tool (copyright 2019) by Toni de la Fuente
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
CHECK_ID_extra7167="7.167"
CHECK_TITLE_extra7167="[extra7167] Check if Cloudfront distributions are protected by AWS Shield Advanced"
CHECK_SCORED_extra7167="NOT_SCORED"
CHECK_CIS_LEVEL_extra7167="EXTRA"
CHECK_SEVERITY_extra7167="Medium"
CHECK_ASFF_RESOURCE_TYPE_extra7167="AwsCloudFrontDistribution"
CHECK_ALTERNATE_check7167="extra7167"
CHECK_SERVICENAME_extra7167="shield"
CHECK_RISK_extra7167='AWS Shield Advanced provides expanded DDoS attack protection for your resources'
CHECK_REMEDIATION_extra7167='Add as a protected resource in AWS Shield Advanced.'
CHECK_DOC_extra7167='https://docs.aws.amazon.com/waf/latest/developerguide/configure-new-protection.html'
CHECK_CAF_EPIC_extra7167='Infrastructure security'
extra7167() {
if [[ "$($AWSCLI $PROFILE_OPT shield get-subscription-state --output text)" == "ACTIVE" ]]; then
LIST_OF_CLOUDFRONT_DISTRIBUTIONS=$($AWSCLI cloudfront list-distributions $PROFILE_OPT --query 'DistributionList.Items[*].[Id,ARN]' --output text | grep -v None)
if [[ $LIST_OF_CLOUDFRONT_DISTRIBUTIONS ]]; then
while read -r distribution; do
DISTRIBUTION_ID=$(echo $distribution | awk '{ print $1; }')
DISTRIBUTION_ARN=$(echo $distribution | awk '{ print $2; }')
if $AWSCLI $PROFILE_OPT shield describe-protection --resource-arn $DISTRIBUTION_ARN >/dev/null 2>&1; then
textPass "$REGION: Cloudfront distribution $DISTRIBUTION_ID is protected by AWS Shield Advanced" "$REGION" "$DISTRIBUTION_ID"
else
textFail "$REGION: Cloudfront distribution $DISTRIBUTION_ID is not protected by AWS Shield Advanced" "$REGION" "$DISTRIBUTION_ID"
fi
done <<<"$LIST_OF_CLOUDFRONT_DISTRIBUTIONS"
else
textInfo "$REGION: no Cloudfront distributions found" "$REGION"
fi
else
textInfo "No AWS Shield Advanced subscription found. Skipping check."
fi
}

View File

@@ -1,49 +0,0 @@
#!/usr/bin/env bash
# Prowler - the handy cloud security tool (copyright 2019) by Toni de la Fuente
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
CHECK_ID_extra7168="7.168"
CHECK_TITLE_extra7168="[extra7168] Check if Route53 hosted zones are protected by AWS Shield Advanced"
CHECK_SCORED_extra7168="NOT_SCORED"
CHECK_CIS_LEVEL_extra7168="EXTRA"
CHECK_SEVERITY_extra7168="Medium"
CHECK_ASFF_RESOURCE_TYPE_extra7168="AwsRoute53Domain"
CHECK_ALTERNATE_check7168="extra7168"
CHECK_SERVICENAME_extra7168="shield"
CHECK_RISK_extra7168='AWS Shield Advanced provides expanded DDoS attack protection for your resources'
CHECK_REMEDIATION_extra7168='Add as a protected resource in AWS Shield Advanced.'
CHECK_DOC_extra7168='https://docs.aws.amazon.com/waf/latest/developerguide/configure-new-protection.html'
CHECK_CAF_EPIC_extra7168='Infrastructure security'
extra7168() {
if [[ "$($AWSCLI $PROFILE_OPT shield get-subscription-state --output text)" == "ACTIVE" ]]; then
CALLER_IDENTITY=$($AWSCLI sts get-caller-identity $PROFILE_OPT --query Arn)
PARTITION=$(echo $CALLER_IDENTITY | cut -d: -f2)
LIST_OF_ROUTE53_HOSTED_ZONES=$($AWSCLI route53 list-hosted-zones $PROFILE_OPT --query 'HostedZones[*].[Id,Name]' --output text)
if [[ $LIST_OF_ROUTE53_HOSTED_ZONES ]]; then
while read -r hosted_zone; do
HOSTED_ZONE_ID=$(echo $hosted_zone | awk '{ print $1; }')
HOSTED_ZONE_NAME=$(echo $hosted_zone | awk '{ print $2; }')
HOSTED_ZONE_ARN="arn:${PARTITION}:route53:::${HOSTED_ZONE_ID:1}"
if $AWSCLI $PROFILE_OPT shield describe-protection --resource-arn $HOSTED_ZONE_ARN >/dev/null 2>&1; then
textPass "$REGION: Route53 Hosted Zone $HOSTED_ZONE_NAME is protected by AWS Shield Advanced" "$REGION" "$HOSTED_ZONE_NAME"
else
textFail "$REGION: Route53 Hosted Zone $HOSTED_ZONE_NAME is not protected by AWS Shield Advanced" "$REGION" "$HOSTED_ZONE_NAME"
fi
done <<<"$LIST_OF_ROUTE53_HOSTED_ZONES"
else
textInfo "$REGION: no Route53 hosted zones found" "$REGION"
fi
else
textInfo "No AWS Shield Advanced subscription found. Skipping check."
fi
}

View File

@@ -1,46 +0,0 @@
#!/usr/bin/env bash
# Prowler - the handy cloud security tool (copyright 2019) by Toni de la Fuente
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
CHECK_ID_extra7169="7.169"
CHECK_TITLE_extra7169="[extra7169] Check if global accelerators are protected by AWS Shield Advanced"
CHECK_SCORED_extra7169="NOT_SCORED"
CHECK_CIS_LEVEL_extra7169="EXTRA"
CHECK_SEVERITY_extra7169="Medium"
CHECK_ASFF_RESOURCE_TYPE_extra7169="AwsGlobalAccelerator"
CHECK_ALTERNATE_check7169="extra7169"
CHECK_SERVICENAME_extra7169="shield"
CHECK_RISK_extra7169='AWS Shield Advanced provides expanded DDoS attack protection for your resources'
CHECK_REMEDIATION_extra7169='Add as a protected resource in AWS Shield Advanced.'
CHECK_DOC_extra7169='https://docs.aws.amazon.com/waf/latest/developerguide/configure-new-protection.html'
CHECK_CAF_EPIC_extra7169='Infrastructure security'
extra7169() {
if [[ "$($AWSCLI $PROFILE_OPT shield get-subscription-state --output text)" == "ACTIVE" ]]; then
LIST_OF_GLOBAL_ACCELERATORS=$($AWSCLI globalaccelerator list-accelerators --region us-west-2 $PROFILE_OPT --query 'Accelerators[?Enabled].[Name,AcceleratorArn]' --output text)
if [[ $LIST_OF_GLOBAL_ACCELERATORS ]]; then
while read -r accelerator; do
ACCELERATOR_NAME=$(echo $accelerator | awk '{ print $1; }')
ACCELERATOR_ARN=$(echo $accelerator | awk '{ print $2; }')
if $AWSCLI $PROFILE_OPT shield describe-protection --resource-arn $ACCELERATOR_ARN >/dev/null 2>&1; then
textPass "$REGION: Global Accelerator $ACCELERATOR_NAME is protected by AWS Shield Advanced" "$REGION" "$ACCELERATOR_NAME"
else
textFail "$REGION: Global Accelerator $ACCELERATOR_NAME is not protected by AWS Shield Advanced" "$REGION" "$ACCELERATOR_NAME"
fi
done <<<"$LIST_OF_GLOBAL_ACCELERATORS"
else
textInfo "$REGION: no global accelerators found" "$REGION"
fi
else
textInfo "No AWS Shield Advanced subscription found. Skipping check."
fi
}

View File

@@ -1,50 +0,0 @@
#!/usr/bin/env bash
# Prowler - the handy cloud security tool (copyright 2019) by Toni de la Fuente
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
CHECK_ID_extra7171="7.171"
CHECK_TITLE_extra7171="[extra7171] Check if classic load balancers are protected by AWS Shield Advanced"
CHECK_SCORED_extra7171="NOT_SCORED"
CHECK_CIS_LEVEL_extra7171="EXTRA"
CHECK_SEVERITY_extra7171="Medium"
CHECK_ASFF_RESOURCE_TYPE_extra7171="AwsElasticLoadBalancingLoadBalancer"
CHECK_ALTERNATE_check7171="extra7171"
CHECK_SERVICENAME_extra7171="shield"
CHECK_RISK_extra7171='AWS Shield Advanced provides expanded DDoS attack protection for your resources'
CHECK_REMEDIATION_extra7171='Add as a protected resource in AWS Shield Advanced.'
CHECK_DOC_extra7171='https://docs.aws.amazon.com/waf/latest/developerguide/configure-new-protection.html'
CHECK_CAF_EPIC_extra7171='Infrastructure security'
extra7171() {
if [[ "$($AWSCLI $PROFILE_OPT shield get-subscription-state --output text)" == "ACTIVE" ]]; then
CALLER_IDENTITY=$($AWSCLI sts get-caller-identity $PROFILE_OPT --query Arn)
PARTITION=$(echo $CALLER_IDENTITY | cut -d: -f2)
ACCOUNT_ID=$(echo $CALLER_IDENTITY | cut -d: -f5)
for regx in $REGIONS; do
LIST_OF_CLASSIC_LOAD_BALANCERS=$($AWSCLI elb describe-load-balancers $PROFILE_OPT --region $regx --query 'LoadBalancerDescriptions[?Scheme == `internet-facing`].[LoadBalancerName]' --output text |grep -v '^None$')
if [[ $LIST_OF_CLASSIC_LOAD_BALANCERS ]]; then
for elb in $LIST_OF_CLASSIC_LOAD_BALANCERS; do
ELB_ARN="arn:${PARTITION}:elasticloadbalancing:${regx}:${ACCOUNT_ID}:loadbalancer/${elb}"
if $AWSCLI $PROFILE_OPT shield describe-protection --resource-arn $ELB_ARN >/dev/null 2>&1; then
textPass "$regx: ELB $elb is protected by AWS Shield Advanced" "$regx" "$elb"
else
textFail "$regx: ELB $elb is not protected by AWS Shield Advanced" "$regx" "$elb"
fi
done
else
textInfo "$regx: no classic load balancers found" "$regx"
fi
done
else
textInfo "No AWS Shield Advanced subscription found. Skipping check."
fi
}

View File

@@ -0,0 +1,35 @@
{
"Provider": "aws",
"CheckID": "shield_advanced_protection_in_associated_elastic_ips",
"CheckTitle": "Check if Elastic IP addresses with associations are protected by AWS Shield Advanced.",
"CheckType": [],
"ServiceName": "shield",
"SubServiceName": "",
"ResourceIdTemplate": "",
"Severity": "medium",
"ResourceType": "AwsEc2Eip",
"Description": "Check if Elastic IP addresses with associations are protected by AWS Shield Advanced.",
"Risk": "AWS Shield Advanced provides expanded DDoS attack protection for your resources.",
"RelatedUrl": "https://docs.aws.amazon.com/waf/latest/developerguide/configure-new-protection.html",
"Remediation": {
"Code": {
"CLI": "",
"NativeIaC": "",
"Other": "",
"Terraform": ""
},
"Recommendation": {
"Text": "Add as a protected resource in AWS Shield Advanced.",
"Url": "https://docs.aws.amazon.com/waf/latest/developerguide/configure-new-protection.html"
}
},
"Categories": [],
"Tags": {
"Tag1Key": "value",
"Tag2Key": "value"
},
"DependsOn": [],
"RelatedTo": [],
"Notes": "",
"Compliance": []
}

View File

@@ -0,0 +1,26 @@
from lib.check.models import Check, Check_Report
from providers.aws.services.shield.shield_client import shield_client
from providers.aws.services.ec2.ec2_client import ec2_client
class shield_advanced_protection_in_associated_elastic_ips(Check):
def execute(self):
findings = []
if shield_client.enabled:
for elastic_ip in ec2_client.elastic_ips:
report = Check_Report(self.metadata)
report.region = shield_client.region
report.resource_id = elastic_ip.allocation_id
report.resource_arn = elastic_ip.arn
report.status = "FAIL"
report.status_extended = f"Elastic IP {elastic_ip.allocation_id} is not protected by AWS Shield Advanced"
for protection in shield_client.protections.values():
if elastic_ip.arn == protection.resource_arn:
report.status = "PASS"
report.status_extended = f"Elastic IP {elastic_ip.allocation_id} is protected by AWS Shield Advanced"
break
findings.append(report)
return findings

View File

@@ -0,0 +1,207 @@
from unittest import mock
from boto3 import client, session
from mock import patch
from moto import mock_ec2
from moto.core import DEFAULT_ACCOUNT_ID
from providers.aws.lib.audit_info.models import AWS_Audit_Info
from providers.aws.services.shield.shield_service import Protection
AWS_REGION = "eu-west-1"
# Mock generate_regional_clients()
def mock_generate_regional_clients(service, audit_info):
regional_client = audit_info.audit_session.client(service, region_name=AWS_REGION)
regional_client.region = AWS_REGION
return {AWS_REGION: regional_client}
# Patch every AWS call using Boto3 and generate_regional_clients to have 1 client
@patch(
"providers.aws.services.accessanalyzer.accessanalyzer_service.generate_regional_clients",
new=mock_generate_regional_clients,
)
class Test_shield_advanced_protection_in_associated_elastic_ips:
# Mocked Audit Info
def set_mocked_audit_info(self):
audit_info = AWS_Audit_Info(
original_session=None,
audit_session=session.Session(
profile_name=None,
botocore_session=None,
),
audited_account=DEFAULT_ACCOUNT_ID,
audited_user_id=None,
audited_partition="aws",
audited_identity_arn=None,
profile=None,
profile_region=AWS_REGION,
credentials=None,
assumed_role_info=None,
audited_regions=None,
organizations_metadata=None,
)
return audit_info
@mock_ec2
def test_no_shield_not_active(self):
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = False
from providers.aws.services.ec2.ec2_service import EC2
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.lib.audit_info.audit_info.current_audit_info",
new=self.set_mocked_audit_info(),
), mock.patch(
"providers.aws.services.shield.shield_advanced_protection_in_associated_elastic_ips.shield_advanced_protection_in_associated_elastic_ips.ec2_client",
new=EC2(self.set_mocked_audit_info()),
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_associated_elastic_ips.shield_advanced_protection_in_associated_elastic_ips import (
shield_advanced_protection_in_associated_elastic_ips,
)
check = shield_advanced_protection_in_associated_elastic_ips()
result = check.execute()
assert len(result) == 0
@mock_ec2
def test_shield_enabled_ip_protected(self):
# EC2 Client
ec2_client = client("ec2", region_name=AWS_REGION)
resp = ec2_client.allocate_address(Domain="vpc", Address="127.38.43.222")
allocation_id = resp["AllocationId"]
elastic_ip_arn = f"arn:aws:ec2:{AWS_REGION}:{DEFAULT_ACCOUNT_ID}:eip-allocation/{allocation_id}"
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = True
shield_client.region = AWS_REGION
protection_id = "test-protection"
shield_client.protections = {
protection_id: Protection(
id=protection_id,
name="",
resource_arn=elastic_ip_arn,
protection_arn="",
region=AWS_REGION,
)
}
from providers.aws.services.ec2.ec2_service import EC2
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.lib.audit_info.audit_info.current_audit_info",
new=self.set_mocked_audit_info(),
), mock.patch(
"providers.aws.services.shield.shield_advanced_protection_in_associated_elastic_ips.shield_advanced_protection_in_associated_elastic_ips.ec2_client",
new=EC2(self.set_mocked_audit_info()),
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_associated_elastic_ips.shield_advanced_protection_in_associated_elastic_ips import (
shield_advanced_protection_in_associated_elastic_ips,
)
check = shield_advanced_protection_in_associated_elastic_ips()
result = check.execute()
assert len(result) == 1
assert result[0].region == AWS_REGION
assert result[0].resource_id == allocation_id
assert result[0].resource_arn == elastic_ip_arn
assert result[0].status == "PASS"
assert (
result[0].status_extended
== f"Elastic IP {allocation_id} is protected by AWS Shield Advanced"
)
@mock_ec2
def test_shield_enabled_ip_not_protected(self):
# EC2 Client
ec2_client = client("ec2", region_name=AWS_REGION)
resp = ec2_client.allocate_address(Domain="vpc", Address="127.38.43.222")
allocation_id = resp["AllocationId"]
elastic_ip_arn = f"arn:aws:ec2:{AWS_REGION}:{DEFAULT_ACCOUNT_ID}:eip-allocation/{allocation_id}"
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = True
shield_client.region = AWS_REGION
shield_client.protections = {}
from providers.aws.services.ec2.ec2_service import EC2
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.lib.audit_info.audit_info.current_audit_info",
new=self.set_mocked_audit_info(),
), mock.patch(
"providers.aws.services.shield.shield_advanced_protection_in_associated_elastic_ips.shield_advanced_protection_in_associated_elastic_ips.ec2_client",
new=EC2(self.set_mocked_audit_info()),
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_associated_elastic_ips.shield_advanced_protection_in_associated_elastic_ips import (
shield_advanced_protection_in_associated_elastic_ips,
)
check = shield_advanced_protection_in_associated_elastic_ips()
result = check.execute()
assert len(result) == 1
assert result[0].region == AWS_REGION
assert result[0].resource_id == allocation_id
assert result[0].resource_arn == elastic_ip_arn
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== f"Elastic IP {allocation_id} is not protected by AWS Shield Advanced"
)
@mock_ec2
def test_shield_disabled_ip_not_protected(self):
# EC2 Client
ec2_client = client("ec2", region_name=AWS_REGION)
resp = ec2_client.allocate_address(Domain="vpc", Address="127.38.43.222")
allocation_id = resp["AllocationId"]
_ = f"arn:aws:ec2:{AWS_REGION}:{DEFAULT_ACCOUNT_ID}:eip-allocation/{allocation_id}"
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = False
shield_client.region = AWS_REGION
shield_client.protections = {}
from providers.aws.services.ec2.ec2_service import EC2
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.lib.audit_info.audit_info.current_audit_info",
new=self.set_mocked_audit_info(),
), mock.patch(
"providers.aws.services.shield.shield_advanced_protection_in_associated_elastic_ips.shield_advanced_protection_in_associated_elastic_ips.ec2_client",
new=EC2(self.set_mocked_audit_info()),
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_associated_elastic_ips.shield_advanced_protection_in_associated_elastic_ips import (
shield_advanced_protection_in_associated_elastic_ips,
)
check = shield_advanced_protection_in_associated_elastic_ips()
result = check.execute()
assert len(result) == 0

View File

@@ -0,0 +1,35 @@
{
"Provider": "aws",
"CheckID": "shield_advanced_protection_in_associated_elastic_ips",
"CheckTitle": "Check if Classic Load Balancers are protected by AWS Shield Advanced.",
"CheckType": [],
"ServiceName": "shield",
"SubServiceName": "",
"ResourceIdTemplate": "",
"Severity": "medium",
"ResourceType": "AwsElasticLoadBalancingLoadBalancer",
"Description": "Check if Classic Load Balancers are protected by AWS Shield Advanced.",
"Risk": "AWS Shield Advanced provides expanded DDoS attack protection for your resources.",
"RelatedUrl": "https://docs.aws.amazon.com/waf/latest/developerguide/configure-new-protection.html",
"Remediation": {
"Code": {
"CLI": "",
"NativeIaC": "",
"Other": "",
"Terraform": ""
},
"Recommendation": {
"Text": "Add as a protected resource in AWS Shield Advanced.",
"Url": "https://docs.aws.amazon.com/waf/latest/developerguide/configure-new-protection.html"
}
},
"Categories": [],
"Tags": {
"Tag1Key": "value",
"Tag2Key": "value"
},
"DependsOn": [],
"RelatedTo": [],
"Notes": "",
"Compliance": []
}

View File

@@ -0,0 +1,30 @@
from lib.check.models import Check, Check_Report
from providers.aws.services.shield.shield_client import shield_client
from providers.aws.services.elb.elb_client import elb_client
class shield_advanced_protection_in_classic_load_balancers(Check):
def execute(self):
findings = []
if shield_client.enabled:
for elb in elb_client.loadbalancers:
report = Check_Report(self.metadata)
report.region = shield_client.region
report.resource_id = elb.name
report.resource_arn = elb.arn
report.status = "FAIL"
report.status_extended = (
f"ELB {elb.name} is not protected by AWS Shield Advanced"
)
for protection in shield_client.protections.values():
if elb.arn == protection.resource_arn:
report.status = "PASS"
report.status_extended = (
f"ELB {elb.name} is protected by AWS Shield Advanced"
)
break
findings.append(report)
return findings

View File

@@ -0,0 +1,240 @@
from unittest import mock
from boto3 import client, resource, session
from moto import mock_ec2, mock_elb
from moto.core import DEFAULT_ACCOUNT_ID
from providers.aws.lib.audit_info.models import AWS_Audit_Info
from providers.aws.services.shield.shield_service import Protection
AWS_REGION = "eu-west-1"
class Test_shield_advanced_protection_in_classic_load_balancers:
# Mocked Audit Info
def set_mocked_audit_info(self):
audit_info = AWS_Audit_Info(
original_session=None,
audit_session=session.Session(
profile_name=None,
botocore_session=None,
),
audited_account=DEFAULT_ACCOUNT_ID,
audited_user_id=None,
audited_partition="aws",
audited_identity_arn=None,
profile=None,
profile_region=AWS_REGION,
credentials=None,
assumed_role_info=None,
audited_regions=None,
organizations_metadata=None,
)
return audit_info
@mock_elb
@mock_ec2
def test_no_shield_not_active(self):
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = False
from providers.aws.services.elb.elb_service import ELB
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.lib.audit_info.audit_info.current_audit_info",
new=self.set_mocked_audit_info(),
), mock.patch(
"providers.aws.services.shield.shield_advanced_protection_in_classic_load_balancers.shield_advanced_protection_in_classic_load_balancers.elb_client",
new=ELB(self.set_mocked_audit_info()),
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_classic_load_balancers.shield_advanced_protection_in_classic_load_balancers import (
shield_advanced_protection_in_classic_load_balancers,
)
check = shield_advanced_protection_in_classic_load_balancers()
result = check.execute()
assert len(result) == 0
@mock_ec2
@mock_elb
def test_shield_enabled_elb_protected(self):
# ELB Client
elb = client("elb", region_name=AWS_REGION)
ec2 = resource("ec2", region_name=AWS_REGION)
security_group = ec2.create_security_group(
GroupName="sg01", Description="Test security group sg01"
)
elb_name = "my-lb"
elb.create_load_balancer(
LoadBalancerName=elb_name,
Listeners=[
{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080},
{"Protocol": "http", "LoadBalancerPort": 81, "InstancePort": 9000},
],
AvailabilityZones=[f"{AWS_REGION}a"],
Scheme="internet-facing",
SecurityGroups=[security_group.id],
)
elb_arn = f"arn:aws:elasticloadbalancing:{AWS_REGION}:{DEFAULT_ACCOUNT_ID}:loadbalancer/{elb_name}"
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = True
shield_client.region = AWS_REGION
protection_id = "test-protection"
shield_client.protections = {
protection_id: Protection(
id=protection_id,
name="",
resource_arn=elb_arn,
protection_arn="",
region=AWS_REGION,
)
}
from providers.aws.services.elb.elb_service import ELB
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.lib.audit_info.audit_info.current_audit_info",
new=self.set_mocked_audit_info(),
), mock.patch(
"providers.aws.services.shield.shield_advanced_protection_in_classic_load_balancers.shield_advanced_protection_in_classic_load_balancers.elb_client",
new=ELB(self.set_mocked_audit_info()),
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_classic_load_balancers.shield_advanced_protection_in_classic_load_balancers import (
shield_advanced_protection_in_classic_load_balancers,
)
check = shield_advanced_protection_in_classic_load_balancers()
result = check.execute()
assert len(result) == 1
assert result[0].region == AWS_REGION
assert result[0].resource_id == elb_name
assert result[0].resource_arn == elb_arn
assert result[0].status == "PASS"
assert (
result[0].status_extended
== f"ELB {elb_name} is protected by AWS Shield Advanced"
)
@mock_elb
@mock_ec2
def test_shield_enabled_elb_not_protected(self):
# ELB Client
elb = client("elb", region_name=AWS_REGION)
ec2 = resource("ec2", region_name=AWS_REGION)
security_group = ec2.create_security_group(
GroupName="sg01", Description="Test security group sg01"
)
elb_name = "my-lb"
elb.create_load_balancer(
LoadBalancerName=elb_name,
Listeners=[
{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080},
{"Protocol": "http", "LoadBalancerPort": 81, "InstancePort": 9000},
],
AvailabilityZones=[f"{AWS_REGION}a"],
Scheme="internet-facing",
SecurityGroups=[security_group.id],
)
elb_arn = f"arn:aws:elasticloadbalancing:{AWS_REGION}:{DEFAULT_ACCOUNT_ID}:loadbalancer/{elb_name}"
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = True
shield_client.region = AWS_REGION
shield_client.protections = {}
from providers.aws.services.elb.elb_service import ELB
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.lib.audit_info.audit_info.current_audit_info",
new=self.set_mocked_audit_info(),
), mock.patch(
"providers.aws.services.shield.shield_advanced_protection_in_classic_load_balancers.shield_advanced_protection_in_classic_load_balancers.elb_client",
new=ELB(self.set_mocked_audit_info()),
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_classic_load_balancers.shield_advanced_protection_in_classic_load_balancers import (
shield_advanced_protection_in_classic_load_balancers,
)
check = shield_advanced_protection_in_classic_load_balancers()
result = check.execute()
assert len(result) == 1
assert result[0].region == AWS_REGION
assert result[0].resource_id == elb_name
assert result[0].resource_arn == elb_arn
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== f"ELB {elb_name} is not protected by AWS Shield Advanced"
)
@mock_elb
@mock_ec2
def test_shield_disabled_elb_not_protected(self):
# ELB Client
elb = client("elb", region_name=AWS_REGION)
ec2 = resource("ec2", region_name=AWS_REGION)
security_group = ec2.create_security_group(
GroupName="sg01", Description="Test security group sg01"
)
elb_name = "my-lb"
elb.create_load_balancer(
LoadBalancerName=elb_name,
Listeners=[
{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080},
{"Protocol": "http", "LoadBalancerPort": 81, "InstancePort": 9000},
],
AvailabilityZones=[f"{AWS_REGION}a"],
Scheme="internet-facing",
SecurityGroups=[security_group.id],
)
_ = f"arn:aws:elasticloadbalancing:{AWS_REGION}:{DEFAULT_ACCOUNT_ID}:loadbalancer/{elb_name}"
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = False
shield_client.region = AWS_REGION
shield_client.protections = {}
from providers.aws.services.elb.elb_service import ELB
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.lib.audit_info.audit_info.current_audit_info",
new=self.set_mocked_audit_info(),
), mock.patch(
"providers.aws.services.shield.shield_advanced_protection_in_classic_load_balancers.shield_advanced_protection_in_classic_load_balancers.elb_client",
new=ELB(self.set_mocked_audit_info()),
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_classic_load_balancers.shield_advanced_protection_in_classic_load_balancers import (
shield_advanced_protection_in_classic_load_balancers,
)
check = shield_advanced_protection_in_classic_load_balancers()
result = check.execute()
assert len(result) == 0

View File

@@ -0,0 +1,35 @@
{
"Provider": "aws",
"CheckID": "shield_advanced_protection_in_cloudfront_distributions",
"CheckTitle": "Check if Cloudfront distributions are protected by AWS Shield Advanced.",
"CheckType": [],
"ServiceName": "shield",
"SubServiceName": "",
"ResourceIdTemplate": "",
"Severity": "medium",
"ResourceType": "AwsCloudFrontDistribution",
"Description": "Check if Cloudfront distributions are protected by AWS Shield Advanced.",
"Risk": "AWS Shield Advanced provides expanded DDoS attack protection for your resources.",
"RelatedUrl": "https://docs.aws.amazon.com/waf/latest/developerguide/configure-new-protection.html",
"Remediation": {
"Code": {
"CLI": "",
"NativeIaC": "",
"Other": "",
"Terraform": ""
},
"Recommendation": {
"Text": "Add as a protected resource in AWS Shield Advanced.",
"Url": "https://docs.aws.amazon.com/waf/latest/developerguide/configure-new-protection.html"
}
},
"Categories": [],
"Tags": {
"Tag1Key": "value",
"Tag2Key": "value"
},
"DependsOn": [],
"RelatedTo": [],
"Notes": "",
"Compliance": []
}

View File

@@ -0,0 +1,26 @@
from lib.check.models import Check, Check_Report
from providers.aws.services.shield.shield_client import shield_client
from providers.aws.services.cloudfront.cloudfront_client import cloudfront_client
class shield_advanced_protection_in_cloudfront_distributions(Check):
def execute(self):
findings = []
if shield_client.enabled:
for distribution in cloudfront_client.distributions.values():
report = Check_Report(self.metadata)
report.region = shield_client.region
report.resource_id = distribution.id
report.resource_arn = distribution.arn
report.status = "FAIL"
report.status_extended = f"CloudFront distribution {distribution.id} is not protected by AWS Shield Advanced"
for protection in shield_client.protections.values():
if distribution.arn == protection.resource_arn:
report.status = "PASS"
report.status_extended = f"CloudFront distribution {distribution.id} is protected by AWS Shield Advanced"
break
findings.append(report)
return findings

View File

@@ -0,0 +1,165 @@
from unittest import mock
from moto.core import DEFAULT_ACCOUNT_ID
from providers.aws.services.cloudfront.cloudfront_service import Distribution
from providers.aws.services.shield.shield_service import Protection
AWS_REGION = "eu-west-1"
class Test_shield_advanced_protection_in_cloudfront_distributions:
def test_no_shield_not_active(self):
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = False
# CloudFront Client
cloudfront_client = mock.MagicMock
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.services.cloudfront.cloudfront_service.CloudFront",
new=cloudfront_client,
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_cloudfront_distributions.shield_advanced_protection_in_cloudfront_distributions import (
shield_advanced_protection_in_cloudfront_distributions,
)
check = shield_advanced_protection_in_cloudfront_distributions()
result = check.execute()
assert len(result) == 0
def test_shield_enabled_cloudfront_protected(self):
# CloudFront Client
cloudfront_client = mock.MagicMock
distribution_id = "EDFDVBD632BHDS5"
distribution_arn = (
f"arn:aws:cloudfront::{DEFAULT_ACCOUNT_ID}:distribution/{distribution_id}"
)
cloudfront_client.distributions = {
distribution_id: Distribution(
arn=distribution_arn, id=distribution_id, region=AWS_REGION, origins=[]
)
}
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = True
shield_client.region = AWS_REGION
protection_id = "test-protection"
shield_client.protections = {
protection_id: Protection(
id=protection_id,
name="",
resource_arn=distribution_arn,
protection_arn="",
region=AWS_REGION,
)
}
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.services.cloudfront.cloudfront_service.CloudFront",
new=cloudfront_client,
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_cloudfront_distributions.shield_advanced_protection_in_cloudfront_distributions import (
shield_advanced_protection_in_cloudfront_distributions,
)
check = shield_advanced_protection_in_cloudfront_distributions()
result = check.execute()
assert len(result) == 1
assert result[0].region == AWS_REGION
assert result[0].resource_id == distribution_id
assert result[0].resource_arn == distribution_arn
assert result[0].status == "PASS"
assert (
result[0].status_extended
== f"CloudFront distribution {distribution_id} is protected by AWS Shield Advanced"
)
def test_shield_enabled_cloudfront_not_protected(self):
# CloudFront Client
cloudfront_client = mock.MagicMock
distribution_id = "EDFDVBD632BHDS5"
distribution_arn = (
f"arn:aws:cloudfront::{DEFAULT_ACCOUNT_ID}:distribution/{distribution_id}"
)
cloudfront_client.distributions = {
distribution_id: Distribution(
arn=distribution_arn, id=distribution_id, region=AWS_REGION, origins=[]
)
}
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = True
shield_client.region = AWS_REGION
shield_client.protections = {}
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.services.cloudfront.cloudfront_service.CloudFront",
new=cloudfront_client,
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_cloudfront_distributions.shield_advanced_protection_in_cloudfront_distributions import (
shield_advanced_protection_in_cloudfront_distributions,
)
check = shield_advanced_protection_in_cloudfront_distributions()
result = check.execute()
assert len(result) == 1
assert result[0].region == AWS_REGION
assert result[0].resource_id == distribution_id
assert result[0].resource_arn == distribution_arn
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== f"CloudFront distribution {distribution_id} is not protected by AWS Shield Advanced"
)
def test_shield_disabled_cloudfront_not_protected(self):
# CloudFront Client
cloudfront_client = mock.MagicMock
distribution_id = "EDFDVBD632BHDS5"
distribution_arn = (
f"arn:aws:cloudfront::{DEFAULT_ACCOUNT_ID}:distribution/{distribution_id}"
)
cloudfront_client.distributions = {
distribution_id: Distribution(
arn=distribution_arn, id=distribution_id, region=AWS_REGION, origins=[]
)
}
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = False
shield_client.region = AWS_REGION
shield_client.protections = {}
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.services.cloudfront.cloudfront_service.CloudFront",
new=cloudfront_client,
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_cloudfront_distributions.shield_advanced_protection_in_cloudfront_distributions import (
shield_advanced_protection_in_cloudfront_distributions,
)
check = shield_advanced_protection_in_cloudfront_distributions()
result = check.execute()
assert len(result) == 0

View File

@@ -0,0 +1,35 @@
{
"Provider": "aws",
"CheckID": "shield_advanced_protection_in_global_accelerators",
"CheckTitle": "Check if Global Accelerators are protected by AWS Shield Advanced.",
"CheckType": [],
"ServiceName": "shield",
"SubServiceName": "",
"ResourceIdTemplate": "",
"Severity": "medium",
"ResourceType": "AwsGlobalAccelerator",
"Description": "Check if Global Accelerators are protected by AWS Shield Advanced.",
"Risk": "AWS Shield Advanced provides expanded DDoS attack protection for your resources.",
"RelatedUrl": "https://docs.aws.amazon.com/waf/latest/developerguide/configure-new-protection.html",
"Remediation": {
"Code": {
"CLI": "",
"NativeIaC": "",
"Other": "",
"Terraform": ""
},
"Recommendation": {
"Text": "Add as a protected resource in AWS Shield Advanced.",
"Url": "https://docs.aws.amazon.com/waf/latest/developerguide/configure-new-protection.html"
}
},
"Categories": [],
"Tags": {
"Tag1Key": "value",
"Tag2Key": "value"
},
"DependsOn": [],
"RelatedTo": [],
"Notes": "",
"Compliance": []
}

View File

@@ -0,0 +1,28 @@
from lib.check.models import Check, Check_Report
from providers.aws.services.shield.shield_client import shield_client
from providers.aws.services.globalaccelerator.globalaccelerator_client import (
globalaccelerator_client,
)
class shield_advanced_protection_in_global_accelerators(Check):
def execute(self):
findings = []
if shield_client.enabled:
for accelerator in globalaccelerator_client.accelerators.values():
report = Check_Report(self.metadata)
report.region = shield_client.region
report.resource_id = accelerator.name
report.resource_arn = accelerator.arn
report.status = "FAIL"
report.status_extended = f"Global Accelerator {accelerator.name} is not protected by AWS Shield Advanced"
for protection in shield_client.protections.values():
if accelerator.arn == protection.resource_arn:
report.status = "PASS"
report.status_extended = f"Global Accelerator {accelerator.name} is protected by AWS Shield Advanced"
break
findings.append(report)
return findings

View File

@@ -0,0 +1,173 @@
from unittest import mock
from moto.core import DEFAULT_ACCOUNT_ID
from providers.aws.services.globalaccelerator.globalaccelerator_service import (
Accelerator,
)
from providers.aws.services.shield.shield_service import Protection
AWS_REGION = "eu-west-1"
class Test_shield_advanced_protection_in_global_accelerators:
def test_no_shield_not_active(self):
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = False
# GlobalAccelerator Client
globalaccelerator_client = mock.MagicMock
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.services.globalaccelerator.globalaccelerator_service.GlobalAccelerator",
new=globalaccelerator_client,
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_global_accelerators.shield_advanced_protection_in_global_accelerators import (
shield_advanced_protection_in_global_accelerators,
)
check = shield_advanced_protection_in_global_accelerators()
result = check.execute()
assert len(result) == 0
def test_shield_enabled_globalaccelerator_protected(self):
# GlobalAccelerator Client
globalaccelerator_client = mock.MagicMock
accelerator_name = "1234abcd-abcd-1234-abcd-1234abcdefgh"
accelerator_id = "1234abcd-abcd-1234-abcd-1234abcdefgh"
accelerator_arn = f"arn:aws:globalaccelerator::{DEFAULT_ACCOUNT_ID}:accelerator/{accelerator_id}"
globalaccelerator_client.accelerators = {
accelerator_name: Accelerator(
arn=accelerator_arn,
name=accelerator_name,
region=AWS_REGION,
enabled=True,
)
}
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = True
shield_client.region = AWS_REGION
protection_id = "test-protection"
shield_client.protections = {
protection_id: Protection(
id=protection_id,
name="",
resource_arn=accelerator_arn,
protection_arn="",
region=AWS_REGION,
)
}
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.services.globalaccelerator.globalaccelerator_service.GlobalAccelerator",
new=globalaccelerator_client,
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_global_accelerators.shield_advanced_protection_in_global_accelerators import (
shield_advanced_protection_in_global_accelerators,
)
check = shield_advanced_protection_in_global_accelerators()
result = check.execute()
assert len(result) == 1
assert result[0].region == AWS_REGION
assert result[0].resource_id == accelerator_id
assert result[0].resource_arn == accelerator_arn
assert result[0].status == "PASS"
assert (
result[0].status_extended
== f"Global Accelerator {accelerator_id} is protected by AWS Shield Advanced"
)
def test_shield_enabled_globalaccelerator_not_protected(self):
# GlobalAccelerator Client
globalaccelerator_client = mock.MagicMock
accelerator_name = "1234abcd-abcd-1234-abcd-1234abcdefgh"
accelerator_id = "1234abcd-abcd-1234-abcd-1234abcdefgh"
accelerator_arn = f"arn:aws:globalaccelerator::{DEFAULT_ACCOUNT_ID}:accelerator/{accelerator_id}"
globalaccelerator_client.accelerators = {
accelerator_name: Accelerator(
arn=accelerator_arn,
name=accelerator_name,
region=AWS_REGION,
enabled=True,
)
}
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = True
shield_client.region = AWS_REGION
shield_client.protections = {}
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.services.globalaccelerator.globalaccelerator_service.GlobalAccelerator",
new=globalaccelerator_client,
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_global_accelerators.shield_advanced_protection_in_global_accelerators import (
shield_advanced_protection_in_global_accelerators,
)
check = shield_advanced_protection_in_global_accelerators()
result = check.execute()
assert len(result) == 1
assert result[0].region == AWS_REGION
assert result[0].resource_id == accelerator_id
assert result[0].resource_arn == accelerator_arn
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== f"Global Accelerator {accelerator_id} is not protected by AWS Shield Advanced"
)
def test_shield_disabled_globalaccelerator_not_protected(self):
# GlobalAccelerator Client
globalaccelerator_client = mock.MagicMock
accelerator_name = "1234abcd-abcd-1234-abcd-1234abcdefgh"
accelerator_id = "1234abcd-abcd-1234-abcd-1234abcdefgh"
accelerator_arn = f"arn:aws:globalaccelerator::{DEFAULT_ACCOUNT_ID}:accelerator/{accelerator_id}"
globalaccelerator_client.accelerators = {
accelerator_name: Accelerator(
arn=accelerator_arn,
name=accelerator_name,
region=AWS_REGION,
enabled=True,
)
}
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = False
shield_client.region = AWS_REGION
shield_client.protections = {}
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.services.globalaccelerator.globalaccelerator_service.GlobalAccelerator",
new=globalaccelerator_client,
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_global_accelerators.shield_advanced_protection_in_global_accelerators import (
shield_advanced_protection_in_global_accelerators,
)
check = shield_advanced_protection_in_global_accelerators()
result = check.execute()
assert len(result) == 0

View File

@@ -0,0 +1,35 @@
{
"Provider": "aws",
"CheckID": "shield_advanced_protection_in_internet_facing_load_balancers",
"CheckTitle": "Check if internet-facing Application Load Balancers are protected by AWS Shield Advanced.",
"CheckType": [],
"ServiceName": "shield",
"SubServiceName": "",
"ResourceIdTemplate": "",
"Severity": "medium",
"ResourceType": "AwsElasticLoadBalancingV2LoadBalancer",
"Description": "Check if internet-facing Application Load Balancers are protected by AWS Shield Advanced.",
"Risk": "AWS Shield Advanced provides expanded DDoS attack protection for your resources.",
"RelatedUrl": "https://docs.aws.amazon.com/waf/latest/developerguide/configure-new-protection.html",
"Remediation": {
"Code": {
"CLI": "",
"NativeIaC": "",
"Other": "",
"Terraform": ""
},
"Recommendation": {
"Text": "Add as a protected resource in AWS Shield Advanced.",
"Url": "https://docs.aws.amazon.com/waf/latest/developerguide/configure-new-protection.html"
}
},
"Categories": [],
"Tags": {
"Tag1Key": "value",
"Tag2Key": "value"
},
"DependsOn": [],
"RelatedTo": [],
"Notes": "",
"Compliance": []
}

View File

@@ -0,0 +1,29 @@
from lib.check.models import Check, Check_Report
from providers.aws.services.shield.shield_client import shield_client
from providers.aws.services.elbv2.elbv2_client import (
elbv2_client,
)
class shield_advanced_protection_in_internet_facing_load_balancers(Check):
def execute(self):
findings = []
if shield_client.enabled:
for elbv2 in elbv2_client.loadbalancersv2:
if elbv2.type == "application" and elbv2.scheme == "internet-facing":
report = Check_Report(self.metadata)
report.region = shield_client.region
report.resource_id = elbv2.name
report.resource_arn = elbv2.arn
report.status = "FAIL"
report.status_extended = f"ELBv2 ALB {elbv2.name} is not protected by AWS Shield Advanced"
for protection in shield_client.protections.values():
if elbv2.arn == protection.resource_arn:
report.status = "PASS"
report.status_extended = f"ELBv2 ALB {elbv2.name} is protected by AWS Shield Advanced"
break
findings.append(report)
return findings

View File

@@ -0,0 +1,337 @@
from unittest import mock
from boto3 import client, resource, session
from mock import patch
from moto import mock_ec2, mock_elbv2
from moto.core import DEFAULT_ACCOUNT_ID as AWS_ACCOUNT_NUMBER
from providers.aws.lib.audit_info.models import AWS_Audit_Info
from providers.aws.services.shield.shield_service import Protection
AWS_REGION = "eu-west-1"
# Mock generate_regional_clients()
def mock_generate_regional_clients(service, audit_info):
regional_client = audit_info.audit_session.client(service, region_name=AWS_REGION)
regional_client.region = AWS_REGION
return {AWS_REGION: regional_client}
# Patch every AWS call using Boto3 and generate_regional_clients to have 1 client
@patch(
"providers.aws.services.accessanalyzer.accessanalyzer_service.generate_regional_clients",
new=mock_generate_regional_clients,
)
class Test_shield_advanced_protection_in_internet_facing_load_balancers:
# Mocked Audit Info
def set_mocked_audit_info(self):
audit_info = AWS_Audit_Info(
original_session=None,
audit_session=session.Session(
profile_name=None,
botocore_session=None,
),
audited_account=AWS_ACCOUNT_NUMBER,
audited_user_id=None,
audited_partition="aws",
audited_identity_arn=None,
profile=None,
profile_region=AWS_REGION,
credentials=None,
assumed_role_info=None,
audited_regions=None,
organizations_metadata=None,
)
return audit_info
@mock_ec2
@mock_elbv2
def test_no_shield_not_active(self):
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = False
from providers.aws.services.elbv2.elbv2_service import ELBv2
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.lib.audit_info.audit_info.current_audit_info",
new=self.set_mocked_audit_info(),
), mock.patch(
"providers.aws.services.shield.shield_advanced_protection_in_internet_facing_load_balancers.shield_advanced_protection_in_internet_facing_load_balancers.elbv2_client",
new=ELBv2(self.set_mocked_audit_info()),
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_internet_facing_load_balancers.shield_advanced_protection_in_internet_facing_load_balancers import (
shield_advanced_protection_in_internet_facing_load_balancers,
)
check = shield_advanced_protection_in_internet_facing_load_balancers()
result = check.execute()
assert len(result) == 0
@mock_ec2
@mock_elbv2
def test_shield_enabled_elbv2_internet_facing_protected(self):
# ELBv2 Client
conn = client("elbv2", region_name=AWS_REGION)
ec2 = resource("ec2", region_name=AWS_REGION)
security_group = ec2.create_security_group(
GroupName="a-security-group", Description="First One"
)
vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default")
subnet1 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock="172.28.7.192/26",
AvailabilityZone=f"{AWS_REGION}a",
)
subnet2 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock="172.28.7.0/26",
AvailabilityZone=f"{AWS_REGION}b",
)
lb_name = "my-lb"
lb = conn.create_load_balancer(
Name=lb_name,
Subnets=[subnet1.id, subnet2.id],
SecurityGroups=[security_group.id],
Scheme="internet-facing",
Type="application",
)["LoadBalancers"][0]
lb_arn = lb["LoadBalancerArn"]
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = True
shield_client.region = AWS_REGION
protection_id = "test-protection"
shield_client.protections = {
protection_id: Protection(
id=protection_id,
name="",
resource_arn=lb_arn,
protection_arn="",
region=AWS_REGION,
)
}
from providers.aws.services.elbv2.elbv2_service import ELBv2
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.lib.audit_info.audit_info.current_audit_info",
new=self.set_mocked_audit_info(),
), mock.patch(
"providers.aws.services.shield.shield_advanced_protection_in_internet_facing_load_balancers.shield_advanced_protection_in_internet_facing_load_balancers.elbv2_client",
new=ELBv2(self.set_mocked_audit_info()),
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_internet_facing_load_balancers.shield_advanced_protection_in_internet_facing_load_balancers import (
shield_advanced_protection_in_internet_facing_load_balancers,
)
check = shield_advanced_protection_in_internet_facing_load_balancers()
result = check.execute()
assert len(result) == 1
assert result[0].region == AWS_REGION
assert result[0].resource_id == lb_name
assert result[0].resource_arn == lb["LoadBalancerArn"]
assert result[0].status == "PASS"
assert (
result[0].status_extended
== f"ELBv2 ALB {lb_name} is protected by AWS Shield Advanced"
)
@mock_ec2
@mock_elbv2
def test_shield_enabled_elbv2_internal_protected(self):
# ELBv2 Client
conn = client("elbv2", region_name=AWS_REGION)
ec2 = resource("ec2", region_name=AWS_REGION)
security_group = ec2.create_security_group(
GroupName="a-security-group", Description="First One"
)
vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default")
subnet1 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock="172.28.7.192/26",
AvailabilityZone=f"{AWS_REGION}a",
)
subnet2 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock="172.28.7.0/26",
AvailabilityZone=f"{AWS_REGION}b",
)
lb_name = "my-lb"
lb = conn.create_load_balancer(
Name=lb_name,
Subnets=[subnet1.id, subnet2.id],
SecurityGroups=[security_group.id],
Scheme="internal",
Type="application",
)["LoadBalancers"][0]
lb_arn = lb["LoadBalancerArn"]
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = True
shield_client.region = AWS_REGION
protection_id = "test-protection"
shield_client.protections = {
protection_id: Protection(
id=protection_id,
name="",
resource_arn=lb_arn,
protection_arn="",
region=AWS_REGION,
)
}
from providers.aws.services.elbv2.elbv2_service import ELBv2
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.lib.audit_info.audit_info.current_audit_info",
new=self.set_mocked_audit_info(),
), mock.patch(
"providers.aws.services.shield.shield_advanced_protection_in_internet_facing_load_balancers.shield_advanced_protection_in_internet_facing_load_balancers.elbv2_client",
new=ELBv2(self.set_mocked_audit_info()),
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_internet_facing_load_balancers.shield_advanced_protection_in_internet_facing_load_balancers import (
shield_advanced_protection_in_internet_facing_load_balancers,
)
check = shield_advanced_protection_in_internet_facing_load_balancers()
result = check.execute()
assert len(result) == 0
@mock_ec2
@mock_elbv2
def test_shield_enabled_elbv2_internet_facing_not_protected(self):
# ELBv2 Client
conn = client("elbv2", region_name=AWS_REGION)
ec2 = resource("ec2", region_name=AWS_REGION)
security_group = ec2.create_security_group(
GroupName="a-security-group", Description="First One"
)
vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default")
subnet1 = ec2.create_subnet(
VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone=f"{AWS_REGION}a"
)
subnet2 = ec2.create_subnet(
VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone=f"{AWS_REGION}b"
)
lb_name = "my-lb"
lb = conn.create_load_balancer(
Name=lb_name,
Subnets=[subnet1.id, subnet2.id],
SecurityGroups=[security_group.id],
Scheme="internet-facing",
Type="application",
)["LoadBalancers"][0]
lb_arn = lb["LoadBalancerArn"]
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = True
shield_client.region = AWS_REGION
shield_client.protections = {}
from providers.aws.services.elbv2.elbv2_service import ELBv2
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.lib.audit_info.audit_info.current_audit_info",
new=self.set_mocked_audit_info(),
), mock.patch(
"providers.aws.services.shield.shield_advanced_protection_in_internet_facing_load_balancers.shield_advanced_protection_in_internet_facing_load_balancers.elbv2_client",
new=ELBv2(self.set_mocked_audit_info()),
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_internet_facing_load_balancers.shield_advanced_protection_in_internet_facing_load_balancers import (
shield_advanced_protection_in_internet_facing_load_balancers,
)
check = shield_advanced_protection_in_internet_facing_load_balancers()
result = check.execute()
assert len(result) == 1
assert result[0].region == AWS_REGION
assert result[0].resource_id == lb_name
assert result[0].resource_arn == lb_arn
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== f"ELBv2 ALB {lb_name} is not protected by AWS Shield Advanced"
)
@mock_ec2
@mock_elbv2
def test_shield_disabled_elbv2_internet_facing_not_protected(self):
# ELBv2 Client
conn = client("elbv2", region_name=AWS_REGION)
ec2 = resource("ec2", region_name=AWS_REGION)
security_group = ec2.create_security_group(
GroupName="a-security-group", Description="First One"
)
vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default")
subnet1 = ec2.create_subnet(
VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone=f"{AWS_REGION}a"
)
subnet2 = ec2.create_subnet(
VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone=f"{AWS_REGION}b"
)
lb_name = "my-lb"
lb = conn.create_load_balancer(
Name=lb_name,
Subnets=[subnet1.id, subnet2.id],
SecurityGroups=[security_group.id],
Scheme="internal",
Type="application",
)["LoadBalancers"][0]
_ = lb["LoadBalancerArn"]
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = False
shield_client.region = AWS_REGION
shield_client.protections = {}
from providers.aws.services.elbv2.elbv2_service import ELBv2
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.lib.audit_info.audit_info.current_audit_info",
new=self.set_mocked_audit_info(),
), mock.patch(
"providers.aws.services.shield.shield_advanced_protection_in_internet_facing_load_balancers.shield_advanced_protection_in_internet_facing_load_balancers.elbv2_client",
new=ELBv2(self.set_mocked_audit_info()),
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_internet_facing_load_balancers.shield_advanced_protection_in_internet_facing_load_balancers import (
shield_advanced_protection_in_internet_facing_load_balancers,
)
check = shield_advanced_protection_in_internet_facing_load_balancers()
result = check.execute()
assert len(result) == 0

View File

@@ -0,0 +1,35 @@
{
"Provider": "aws",
"CheckID": "shield_advanced_protection_in_route53_hosted_zones",
"CheckTitle": "Check if Route53 hosted zones are protected by AWS Shield Advanced.",
"CheckType": [],
"ServiceName": "shield",
"SubServiceName": "",
"ResourceIdTemplate": "",
"Severity": "medium",
"ResourceType": "AwsRoute53Domain",
"Description": "Check if Route53 hosted zones are protected by AWS Shield Advanced.",
"Risk": "AWS Shield Advanced provides expanded DDoS attack protection for your resources.",
"RelatedUrl": "https://docs.aws.amazon.com/waf/latest/developerguide/configure-new-protection.html",
"Remediation": {
"Code": {
"CLI": "",
"NativeIaC": "",
"Other": "",
"Terraform": ""
},
"Recommendation": {
"Text": "Add as a protected resource in AWS Shield Advanced.",
"Url": "https://docs.aws.amazon.com/waf/latest/developerguide/configure-new-protection.html"
}
},
"Categories": [],
"Tags": {
"Tag1Key": "value",
"Tag2Key": "value"
},
"DependsOn": [],
"RelatedTo": [],
"Notes": "",
"Compliance": []
}

View File

@@ -0,0 +1,28 @@
from lib.check.models import Check, Check_Report
from providers.aws.services.shield.shield_client import shield_client
from providers.aws.services.route53.route53_client import (
route53_client,
)
class shield_advanced_protection_in_route53_hosted_zones(Check):
def execute(self):
findings = []
if shield_client.enabled:
for hosted_zone in route53_client.hosted_zones.values():
report = Check_Report(self.metadata)
report.region = shield_client.region
report.resource_id = hosted_zone.id
report.resource_arn = hosted_zone.arn
report.status = "FAIL"
report.status_extended = f"Route53 Hosted Zone {hosted_zone.id} is not protected by AWS Shield Advanced"
for protection in shield_client.protections.values():
if hosted_zone.arn == protection.resource_arn:
report.status = "PASS"
report.status_extended = f"Route53 Hosted Zone {hosted_zone.id} is protected by AWS Shield Advanced"
break
findings.append(report)
return findings

View File

@@ -0,0 +1,181 @@
from unittest import mock
from providers.aws.services.route53.route53_service import (
HostedZone,
)
from providers.aws.services.shield.shield_service import Protection
AWS_REGION = "eu-west-1"
class Test_shield_advanced_protection_in_route53_hosted_zones:
def test_no_shield_not_active(self):
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = False
# Route53 Client
route53_client = mock.MagicMock
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.services.route53.route53_service.Route53",
new=route53_client,
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_route53_hosted_zones.shield_advanced_protection_in_route53_hosted_zones import (
shield_advanced_protection_in_route53_hosted_zones,
)
check = shield_advanced_protection_in_route53_hosted_zones()
result = check.execute()
assert len(result) == 0
def test_shield_enabled_route53_hosted_zone_protected(self):
# Route53 Client
route53_client = mock.MagicMock
hosted_zone_id = "ABCDEF12345678"
hosted_zone_arn = f"arn:aws:route53:::hostedzone/{hosted_zone_id}"
hosted_zone_name = "test-hosted-zone"
route53_client.hosted_zones = {
hosted_zone_id: HostedZone(
id=hosted_zone_id,
arn=hosted_zone_arn,
name=hosted_zone_name,
hosted_zone_name=hosted_zone_name,
private_zone=False,
region=AWS_REGION,
)
}
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = True
shield_client.region = AWS_REGION
protection_id = "test-protection"
shield_client.protections = {
protection_id: Protection(
id=protection_id,
name="",
resource_arn=hosted_zone_arn,
protection_arn="",
region=AWS_REGION,
)
}
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.services.route53.route53_service.Route53",
new=route53_client,
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_route53_hosted_zones.shield_advanced_protection_in_route53_hosted_zones import (
shield_advanced_protection_in_route53_hosted_zones,
)
check = shield_advanced_protection_in_route53_hosted_zones()
result = check.execute()
assert len(result) == 1
assert result[0].region == AWS_REGION
assert result[0].resource_id == hosted_zone_id
assert result[0].resource_arn == hosted_zone_arn
assert result[0].status == "PASS"
assert (
result[0].status_extended
== f"Route53 Hosted Zone {hosted_zone_id} is protected by AWS Shield Advanced"
)
def test_shield_enabled_route53_hosted_zone_not_protected(self):
# Route53 Client
route53_client = mock.MagicMock
hosted_zone_id = "ABCDEF12345678"
hosted_zone_arn = f"arn:aws:route53:::hostedzone/{hosted_zone_id}"
hosted_zone_name = "test-hosted-zone"
route53_client.hosted_zones = {
hosted_zone_id: HostedZone(
id=hosted_zone_id,
arn=hosted_zone_arn,
name=hosted_zone_name,
hosted_zone_name=hosted_zone_name,
private_zone=False,
region=AWS_REGION,
)
}
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = True
shield_client.region = AWS_REGION
shield_client.protections = {}
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.services.route53.route53_service.Route53",
new=route53_client,
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_route53_hosted_zones.shield_advanced_protection_in_route53_hosted_zones import (
shield_advanced_protection_in_route53_hosted_zones,
)
check = shield_advanced_protection_in_route53_hosted_zones()
result = check.execute()
assert len(result) == 1
assert result[0].region == AWS_REGION
assert result[0].resource_id == hosted_zone_id
assert result[0].resource_arn == hosted_zone_arn
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== f"Route53 Hosted Zone {hosted_zone_id} is not protected by AWS Shield Advanced"
)
def test_shield_disabled_route53_hosted_zone_not_protected(self):
# Route53 Client
route53_client = mock.MagicMock
hosted_zone_id = "ABCDEF12345678"
hosted_zone_arn = f"arn:aws:route53:::hostedzone/{hosted_zone_id}"
hosted_zone_name = "test-hosted-zone"
route53_client.hosted_zones = {
hosted_zone_id: HostedZone(
id=hosted_zone_id,
arn=hosted_zone_arn,
name=hosted_zone_name,
hosted_zone_name=hosted_zone_name,
private_zone=False,
region=AWS_REGION,
)
}
# Shield Client
shield_client = mock.MagicMock
shield_client.enabled = False
shield_client.region = AWS_REGION
shield_client.protections = {}
with mock.patch(
"providers.aws.services.shield.shield_service.Shield",
new=shield_client,
), mock.patch(
"providers.aws.services.route53.route53_service.Route53",
new=route53_client,
):
# Test Check
from providers.aws.services.shield.shield_advanced_protection_in_route53_hosted_zones.shield_advanced_protection_in_route53_hosted_zones import (
shield_advanced_protection_in_route53_hosted_zones,
)
check = shield_advanced_protection_in_route53_hosted_zones()
result = check.execute()
assert len(result) == 0

View File

@@ -0,0 +1,4 @@
from providers.aws.lib.audit_info.audit_info import current_audit_info
from providers.aws.services.shield.shield_service import Shield
shield_client = Shield(current_audit_info)

View File

@@ -0,0 +1,64 @@
from pydantic import BaseModel
from lib.logger import logger
from providers.aws.aws_provider import get_region_global_service
################### Shield
class Shield:
def __init__(self, audit_info):
self.service = "shield"
self.session = audit_info.audit_session
self.audited_account = audit_info.audited_account
self.client = self.session.client(self.service)
self.region = get_region_global_service(audit_info)
self.enabled = self.__get_subscription_state__()
self.protections = {}
self.__list_protections__()
def __get_session__(self):
return self.session
def __get_subscription_state__(self):
logger.info("Shield - Getting Subscription State...")
try:
return (
True
if self.client.get_subscription_state()["SubscriptionState"] == "ACTIVE"
else False
)
except Exception as error:
logger.error(
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def __list_protections__(self):
logger.info("Shield - Listing Protections...")
try:
list_protections_paginator = self.client.get_paginator("list_protections")
for page in list_protections_paginator.paginate():
for protection in page["Protections"]:
protection_arn = protection.get("ProtectionArn")
protection_id = protection.get("Id")
protection_name = protection.get("Name")
resource_arn = protection.get("ResourceArn")
self.protections[protection_id] = Protection(
id=protection_id,
name=protection_name,
resource_arn=resource_arn,
protection_arn=protection_arn,
region=self.region,
)
except Exception as error:
logger.error(
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
class Protection(BaseModel):
id: str
name: str
resource_arn: str
protection_arn: str = None
region: str

View File

@@ -0,0 +1,101 @@
from providers.aws.lib.audit_info.models import AWS_Audit_Info
from providers.aws.services.shield.shield_service import Shield
from mock import patch
from moto.core import DEFAULT_ACCOUNT_ID
import botocore
from boto3 import session
# Mock Test Region
AWS_REGION = "eu-west-1"
# Mocking Access Analyzer Calls
make_api_call = botocore.client.BaseClient._make_api_call
def mock_make_api_call(self, operation_name, kwarg):
"""We have to mock every AWS API call using Boto3"""
if operation_name == "ListProtections":
return {
"Protections": [
{
"Id": "a1b2c3d4-5678-90ab-cdef-EXAMPLE11111",
"Name": "Protection for CloudFront distribution",
"ResourceArn": f"arn:aws:cloudfront::{DEFAULT_ACCOUNT_ID}:distribution/E198WC25FXOWY8",
}
]
}
if operation_name == "GetSubscriptionState":
return {"SubscriptionState": "ACTIVE"}
return make_api_call(self, operation_name, kwarg)
# Patch every AWS call using Boto3 and generate_regional_clients to have 1 client
@patch("botocore.client.BaseClient._make_api_call", new=mock_make_api_call)
class Test_Shield_Service:
# Mocked Audit Info
def set_mocked_audit_info(self):
audit_info = AWS_Audit_Info(
original_session=None,
audit_session=session.Session(
profile_name=None,
botocore_session=None,
),
audited_account=DEFAULT_ACCOUNT_ID,
audited_user_id=None,
audited_partition="aws",
audited_identity_arn=None,
profile=None,
profile_region=AWS_REGION,
credentials=None,
assumed_role_info=None,
audited_regions=None,
organizations_metadata=None,
)
return audit_info
# Test Shield Service
def test_service(self):
# Shield client for this test class
audit_info = self.set_mocked_audit_info()
shield = Shield(audit_info)
assert shield.service == "shield"
# Test Shield Client
def test_client(self):
# Shield client for this test class
audit_info = self.set_mocked_audit_info()
shield = Shield(audit_info)
assert shield.client.__class__.__name__ == "Shield"
# Test Shield Session
def test__get_session__(self):
# Shield client for this test class
audit_info = self.set_mocked_audit_info()
shield = Shield(audit_info)
assert shield.session.__class__.__name__ == "Session"
def test__get_subscription_state__(self):
# Shield client for this test class
audit_info = self.set_mocked_audit_info()
shield = Shield(audit_info)
assert shield.enabled
def test__list_protections__(self):
# Shield client for this test class
audit_info = self.set_mocked_audit_info()
shield = Shield(audit_info)
protection_id = "a1b2c3d4-5678-90ab-cdef-EXAMPLE11111"
protection_name = "Protection for CloudFront distribution"
cloudfront_distribution_id = "E198WC25FXOWY8"
resource_arn = (
f"arn:aws:cloudfront::{DEFAULT_ACCOUNT_ID}:distribution/{cloudfront_distribution_id}",
)
assert shield.protections
assert len(shield.protections) == 1
assert shield.protections[protection_id]
assert shield.protections[protection_id].id == protection_id
assert shield.protections[protection_id].name == protection_name
assert not shield.protections[protection_id].protection_arn
assert not shield.protections[protection_id].resource_arn == resource_arn