feat(tags): add resource tags to G-R services (#2009)

This commit is contained in:
Sergio Garcia
2023-03-02 13:56:22 +01:00
committed by GitHub
parent 76bb418ea9
commit e8a1378ad0
58 changed files with 287 additions and 38 deletions

View File

@@ -1,5 +1,6 @@
import json
import threading
from typing import Optional
from botocore.client import ClientError
from pydantic import BaseModel
@@ -20,6 +21,7 @@ class Glacier:
self.vaults = {}
self.__threading_call__(self.__list_vaults__)
self.__threading_call__(self.__get_vault_access_policy__)
self.__list_tags_for_vault__()
def __get_session__(self):
return self.session
@@ -79,9 +81,24 @@ class Glacier:
f" {error}"
)
def __list_tags_for_vault__(self):
logger.info("Glacier - List Tags...")
try:
for vault in self.vaults.values():
regional_client = self.regional_clients[vault.region]
response = regional_client.list_tags_for_vault(vaultName=vault.name)[
"Tags"
]
vault.tags = [response]
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
class Vault(BaseModel):
name: str
arn: str
region: str
access_policy: dict = {}
tags: Optional[list] = []

View File

@@ -10,7 +10,7 @@ class glacier_vaults_policy_public_access(Check):
report.region = vault.region
report.resource_id = vault.name
report.resource_arn = vault.arn
report.resource_tags = vault.tags
report.status = "PASS"
report.status_extended = (
f"Vault {vault.name} has policy which does not allow access to everyone"
@@ -19,10 +19,8 @@ class glacier_vaults_policy_public_access(Check):
public_access = False
if vault.access_policy:
for statement in vault.access_policy["Statement"]:
# Only check allow statements
if statement["Effect"] == "Allow":
if (
"*" in statement["Principal"]
or (

View File

@@ -10,6 +10,7 @@ class guardduty_is_enabled(Check):
report.region = detector.region
report.resource_id = detector.id
report.resource_arn = detector.arn
report.resource_tags = detector.tags
report.status = "PASS"
report.status_extended = f"GuardDuty detector {detector.id} enabled"
if detector.status is None:

View File

@@ -10,6 +10,7 @@ class guardduty_no_high_severity_findings(Check):
report.region = detector.region
report.resource_id = detector.id
report.resource_arn = detector.arn
report.resource_tags = detector.tags
report.status = "PASS"
report.status_extended = f"GuardDuty detector {detector.id} does not have high severity findings."
if len(detector.findings) > 0:

View File

@@ -1,4 +1,5 @@
import threading
from typing import Optional
from pydantic import BaseModel
@@ -12,12 +13,15 @@ class GuardDuty:
def __init__(self, audit_info):
self.service = "guardduty"
self.session = audit_info.audit_session
self.audited_account = audit_info.audited_account
self.audit_resources = audit_info.audit_resources
self.audited_partition = audit_info.audited_partition
self.regional_clients = generate_regional_clients(self.service, audit_info)
self.detectors = []
self.__threading_call__(self.__list_detectors__)
self.__get_detector__(self.regional_clients)
self.__list_findings__(self.regional_clients)
self.__list_tags_for_resource__()
def __get_session__(self):
return self.session
@@ -40,8 +44,11 @@ class GuardDuty:
if not self.audit_resources or (
is_resource_filtered(detector, self.audit_resources)
):
arn = f"arn:{self.audited_partition}:guardduty:{regional_client.region}:{self.audited_account}:detector/{detector}"
self.detectors.append(
Detector(id=detector, region=regional_client.region)
Detector(
id=detector, arn=arn, region=regional_client.region
)
)
except Exception as error:
logger.error(
@@ -93,11 +100,25 @@ class GuardDuty:
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def __list_tags_for_resource__(self):
logger.info("Guardduty - List Tags...")
try:
for detector in self.detectors:
regional_client = self.regional_clients[detector.region]
response = regional_client.list_tags_for_resource(
ResourceArn=detector.arn
)["Tags"]
detector.tags = [response]
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
class Detector(BaseModel):
id: str
# there is no arn for a guardduty detector but we want it filled for the reports
arn: str = ""
arn: str
region: str
status: bool = None
findings: list = []
tags: Optional[list] = []

View File

@@ -13,6 +13,7 @@ class iam_disable_30_days_credentials(Check):
report = Check_Report_AWS(self.metadata())
report.resource_id = user.name
report.resource_arn = user.arn
report.resource_tags = user.tags
report.region = iam_client.region
if user.password_last_used:
time_since_insertion = (

View File

@@ -13,6 +13,7 @@ class iam_disable_45_days_credentials(Check):
report = Check_Report_AWS(self.metadata())
report.resource_id = user.name
report.resource_arn = user.arn
report.resource_tags = user.tags
report.region = iam_client.region
if user.password_last_used:
time_since_insertion = (

View File

@@ -13,6 +13,7 @@ class iam_disable_90_days_credentials(Check):
report = Check_Report_AWS(self.metadata())
report.resource_id = user.name
report.resource_arn = user.arn
report.resource_tags = user.tags
report.region = iam_client.region
if user.password_last_used:
time_since_insertion = (

View File

@@ -12,6 +12,7 @@ class iam_role_cross_service_confused_deputy_prevention(Check):
report.region = iam_client.region
report.resource_arn = role.arn
report.resource_id = role.name
report.resource_tags = role.tags
report.status = "FAIL"
report.status_extended = f"IAM Service Role {role.name} prevents against a cross-service confused deputy attack"
for statement in role.assume_role_policy["Statement"]:

View File

@@ -59,6 +59,7 @@ class IAM:
self.__list_policies_version__(self.policies)
self.saml_providers = self.__list_saml_providers__()
self.server_certificates = self.__list_server_certificates__()
self.__list_tags_for_resource__()
def __get_client__(self):
return self.client
@@ -408,6 +409,25 @@ class IAM:
finally:
return server_certificates
def __list_tags_for_resource__(self):
logger.info("IAM - List Tags...")
try:
for role in self.roles:
response = self.client.list_role_tags(RoleName=role.name)["Tags"]
role.tags = response
except Exception as error:
logger.error(
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
try:
for user in self.users:
response = self.client.list_user_tags(UserName=user.name)["Tags"]
user.tags = response
except Exception as error:
logger.error(
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
class MFADevice(BaseModel):
serial_number: str
@@ -421,6 +441,7 @@ class User(BaseModel):
password_last_used: Optional[datetime]
attached_policies: list[dict] = []
inline_policies: list[str] = []
tags: Optional[list] = []
class Role(BaseModel):
@@ -428,6 +449,7 @@ class Role(BaseModel):
arn: str
assume_role_policy: dict
is_service_role: bool
tags: Optional[list] = []
class Group(BaseModel):

View File

@@ -11,6 +11,7 @@ class iam_user_hardware_mfa_enabled(Check):
report = Check_Report_AWS(self.metadata())
report.resource_id = user.name
report.resource_arn = user.arn
report.resource_tags = user.tags
report.region = iam_client.region
if user.mfa_devices:
report.status = "PASS"

View File

@@ -12,6 +12,7 @@ class kms_cmk_are_used(Check):
report.region = key.region
report.resource_id = key.id
report.resource_arn = key.arn
report.resource_tags = key.tags
if key.state != "Enabled":
if key.state == "PendingDeletion":
report.status = "PASS"

View File

@@ -8,6 +8,7 @@ class kms_cmk_rotation_enabled(Check):
for key in kms_client.keys:
report = Check_Report_AWS(self.metadata())
report.region = key.region
report.resource_tags = key.tags
# Only check enabled CMKs keys
if (
key.manager == "CUSTOMER"

View File

@@ -14,6 +14,7 @@ class kms_key_not_publicly_accessible(Check):
report.status_extended = f"KMS key {key.id} is not exposed to Public."
report.resource_id = key.id
report.resource_arn = key.arn
report.resource_tags = key.tags
report.region = key.region
# If the "Principal" element value is set to { "AWS": "*" } and the policy statement is not using any Condition clauses to filter the access, the selected AWS KMS master key is publicly accessible.
if key.policy and "Statement" in key.policy:

View File

@@ -23,6 +23,7 @@ class KMS:
self.__describe_key__()
self.__get_key_rotation_status__()
self.__get_key_policy__()
self.__list_resource_tags__()
def __get_session__(self):
return self.session
@@ -109,6 +110,20 @@ class KMS:
f"{regional_client.region} -- {error.__class__.__name__}:{error.__traceback__.tb_lineno} -- {error}"
)
def __list_resource_tags__(self):
logger.info("KMS - List Tags...")
for key in self.keys:
try:
regional_client = self.regional_clients[key.region]
response = regional_client.list_resource_tags(
KeyId=key.id,
)["Tags"]
key.tags = response
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
class Key(BaseModel):
id: str
@@ -120,3 +135,4 @@ class Key(BaseModel):
policy: Optional[dict]
spec: Optional[str]
region: str
tags: Optional[list] = []

View File

@@ -1,5 +1,6 @@
import threading
from dataclasses import dataclass
from pydantic import BaseModel
from prowler.lib.logger import logger
from prowler.providers.aws.aws_provider import generate_regional_clients
@@ -32,8 +33,8 @@ class Macie:
try:
self.sessions.append(
Session(
regional_client.get_macie_session()["status"],
regional_client.region,
status=regional_client.get_macie_session()["status"],
region=regional_client.region,
)
)
@@ -41,8 +42,8 @@ class Macie:
if "Macie is not enabled" in str(error):
self.sessions.append(
Session(
"DISABLED",
regional_client.region,
status="DISABLED",
region=regional_client.region,
)
)
else:
@@ -51,15 +52,6 @@ class Macie:
)
@dataclass
class Session:
class Session(BaseModel):
status: str
region: str
def __init__(
self,
status,
region,
):
self.status = status
self.region = region

View File

@@ -1,5 +1,6 @@
import threading
from json import loads
from typing import Optional
from pydantic import BaseModel
@@ -19,6 +20,7 @@ class OpenSearchService:
self.__threading_call__(self.__list_domain_names__)
self.__describe_domain_config__(self.regional_clients)
self.__describe_domain__(self.regional_clients)
self.__list_tags__()
def __get_session__(self):
return self.session
@@ -129,6 +131,20 @@ class OpenSearchService:
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def __list_tags__(self):
logger.info("OpenSearch - List Tags...")
for domain in self.opensearch_domains:
try:
regional_client = self.regional_clients[domain.region]
response = regional_client.list_tags(
ARN=domain.arn,
)["TagList"]
domain.tags = response
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
class PublishingLoggingOption(BaseModel):
name: str
@@ -150,3 +166,4 @@ class OpenSearchDomain(BaseModel):
internal_user_database: bool = None
update_available: bool = None
version: str = None
tags: Optional[list] = []

View File

@@ -12,6 +12,7 @@ class opensearch_service_domains_audit_logging_enabled(Check):
report.region = domain.region
report.resource_id = domain.name
report.resource_arn = domain.arn
report.resource_tags = domain.tags
report.status = "FAIL"
report.status_extended = (
f"Opensearch domain {domain.name} AUDIT_LOGS disabled"

View File

@@ -12,6 +12,7 @@ class opensearch_service_domains_cloudwatch_logging_enabled(Check):
report.region = domain.region
report.resource_id = domain.name
report.resource_arn = domain.arn
report.resource_tags = domain.tags
report.status = "FAIL"
report.status_extended = f"Opensearch domain {domain.name} SEARCH_SLOW_LOGS and INDEX_SLOW_LOGS disabled"
has_SEARCH_SLOW_LOGS = False

View File

@@ -12,6 +12,7 @@ class opensearch_service_domains_encryption_at_rest_enabled(Check):
report.region = domain.region
report.resource_id = domain.name
report.resource_arn = domain.arn
report.resource_tags = domain.tags
report.status = "PASS"
report.status_extended = (
f"Opensearch domain {domain.name} has encryption at-rest enabled"

View File

@@ -12,6 +12,7 @@ class opensearch_service_domains_https_communications_enforced(Check):
report.region = domain.region
report.resource_id = domain.name
report.resource_arn = domain.arn
report.resource_tags = domain.tags
report.status = "PASS"
report.status_extended = (
f"Opensearch domain {domain.name} has enforce HTTPS enabled"

View File

@@ -12,6 +12,7 @@ class opensearch_service_domains_internal_user_database_enabled(Check):
report.region = domain.region
report.resource_id = domain.name
report.resource_arn = domain.arn
report.resource_tags = domain.tags
report.status = "PASS"
report.status_extended = f"Opensearch domain {domain.name} does not have internal user database enabled"
if domain.internal_user_database:

View File

@@ -12,6 +12,7 @@ class opensearch_service_domains_node_to_node_encryption_enabled(Check):
report.region = domain.region
report.resource_id = domain.name
report.resource_arn = domain.arn
report.resource_tags = domain.tags
report.status = "PASS"
report.status_extended = (
f"Opensearch domain {domain.name} has node-to-node encryption enabled"

View File

@@ -12,6 +12,7 @@ class opensearch_service_domains_not_publicly_accessible(Check):
report.region = domain.region
report.resource_id = domain.name
report.resource_arn = domain.arn
report.resource_tags = domain.tags
report.status = "PASS"
report.status_extended = (
f"Opensearch domain {domain.name} does not allow anonymous access"

View File

@@ -12,6 +12,7 @@ class opensearch_service_domains_updated_to_the_latest_service_software_version(
report.region = domain.region
report.resource_id = domain.name
report.resource_arn = domain.arn
report.resource_tags = domain.tags
report.status = "PASS"
report.status_extended = f"Opensearch domain {domain.name} with version {domain.version} does not have internal updates available"
if domain.update_available:

View File

@@ -12,6 +12,7 @@ class opensearch_service_domains_use_cognito_authentication_for_kibana(Check):
report.region = domain.region
report.resource_id = domain.name
report.resource_arn = domain.arn
report.resource_tags = domain.tags
report.status = "PASS"
report.status_extended = f"Opensearch domain {domain.name} has Amazon Cognito authentication for Kibana enabled"
if not domain.cognito_options:

View File

@@ -9,6 +9,7 @@ class rds_instance_backup_enabled(Check):
report = Check_Report_AWS(self.metadata())
report.region = db_instance.region
report.resource_id = db_instance.id
report.resource_tags = db_instance.tags
if db_instance.backup_retention_period > 0:
report.status = "PASS"
report.status_extended = f"RDS Instance {db_instance.id} has backup enabled with retention period {db_instance.backup_retention_period} days."

View File

@@ -9,6 +9,7 @@ class rds_instance_deletion_protection(Check):
report = Check_Report_AWS(self.metadata())
report.region = db_instance.region
report.resource_id = db_instance.id
report.resource_tags = db_instance.tags
if db_instance.deletion_protection:
report.status = "PASS"
report.status_extended = (

View File

@@ -9,6 +9,7 @@ class rds_instance_enhanced_monitoring_enabled(Check):
report = Check_Report_AWS(self.metadata())
report.region = db_instance.region
report.resource_id = db_instance.id
report.resource_tags = db_instance.tags
if db_instance.enhanced_monitoring_arn:
report.status = "PASS"
report.status_extended = (

View File

@@ -9,6 +9,7 @@ class rds_instance_integration_cloudwatch_logs(Check):
report = Check_Report_AWS(self.metadata())
report.region = db_instance.region
report.resource_id = db_instance.id
report.resource_tags = db_instance.tags
if db_instance.cloudwatch_logs:
report.status = "PASS"
report.status_extended = f"RDS Instance {db_instance.id} is shipping {' '.join(db_instance.cloudwatch_logs)} to CloudWatch Logs."

View File

@@ -9,6 +9,7 @@ class rds_instance_minor_version_upgrade_enabled(Check):
report = Check_Report_AWS(self.metadata())
report.region = db_instance.region
report.resource_id = db_instance.id
report.resource_tags = db_instance.tags
if db_instance.auto_minor_version_upgrade:
report.status = "PASS"
report.status_extended = (

View File

@@ -9,6 +9,7 @@ class rds_instance_multi_az(Check):
report = Check_Report_AWS(self.metadata())
report.region = db_instance.region
report.resource_id = db_instance.id
report.resource_tags = db_instance.tags
if db_instance.multi_az:
report.status = "PASS"
report.status_extended = (

View File

@@ -9,6 +9,7 @@ class rds_instance_no_public_access(Check):
report = Check_Report_AWS(self.metadata())
report.region = db_instance.region
report.resource_id = db_instance.id
report.resource_tags = db_instance.tags
if not db_instance.public:
report.status = "PASS"
report.status_extended = (

View File

@@ -9,6 +9,7 @@ class rds_instance_storage_encrypted(Check):
report = Check_Report_AWS(self.metadata())
report.region = db_instance.region
report.resource_id = db_instance.id
report.resource_tags = db_instance.tags
if db_instance.encrypted:
report.status = "PASS"
report.status_extended = f"RDS Instance {db_instance.id} is encrypted."

View File

@@ -74,6 +74,7 @@ class RDS:
),
multi_az=instance["MultiAZ"],
region=regional_client.region,
tags=instance.get("TagList"),
)
)
except Exception as error:
@@ -100,6 +101,7 @@ class RDS:
id=snapshot["DBSnapshotIdentifier"],
instance_id=snapshot["DBInstanceIdentifier"],
region=regional_client.region,
tags=snapshot.get("TagList"),
)
)
except Exception as error:
@@ -144,6 +146,7 @@ class RDS:
id=snapshot["DBClusterSnapshotIdentifier"],
cluster_id=snapshot["DBClusterIdentifier"],
region=regional_client.region,
tags=snapshot.get("TagList"),
)
)
except Exception as error:
@@ -183,6 +186,7 @@ class DBInstance(BaseModel):
enhanced_monitoring_arn: Optional[str]
multi_az: bool
region: str
tags: Optional[list] = []
class DBSnapshot(BaseModel):
@@ -190,6 +194,7 @@ class DBSnapshot(BaseModel):
instance_id: str
public: bool = False
region: str
tags: Optional[list] = []
class ClusterSnapshot(BaseModel):
@@ -197,3 +202,4 @@ class ClusterSnapshot(BaseModel):
cluster_id: str
public: bool = False
region: str
tags: Optional[list] = []

View File

@@ -9,6 +9,7 @@ class rds_snapshots_public_access(Check):
report = Check_Report_AWS(self.metadata())
report.region = db_snap.region
report.resource_id = db_snap.id
report.resource_tags = db_snap.tags
if db_snap.public:
report.status = "FAIL"
report.status_extended = (
@@ -26,6 +27,7 @@ class rds_snapshots_public_access(Check):
report = Check_Report_AWS(self.metadata())
report.region = db_snap.region
report.resource_id = db_snap.id
report.resource_tags = db_snap.tags
if db_snap.public:
report.status = "FAIL"
report.status_extended = f"RDS Cluster Snapshot {db_snap.id} is public."

View File

@@ -10,6 +10,7 @@ class redshift_cluster_audit_logging(Check):
report.region = cluster.region
report.resource_id = cluster.id
report.resource_arn = cluster.arn
report.resource_tags = cluster.tags
report.status = "PASS"
report.status_extended = (
f"Redshift Cluster {cluster.arn} has audit logging enabled"

View File

@@ -10,6 +10,7 @@ class redshift_cluster_automated_snapshot(Check):
report.region = cluster.region
report.resource_id = cluster.id
report.resource_arn = cluster.arn
report.resource_tags = cluster.tags
report.status = "PASS"
report.status_extended = (
f"Redshift Cluster {cluster.arn} has automated snapshots"

View File

@@ -10,6 +10,7 @@ class redshift_cluster_automatic_upgrades(Check):
report.region = cluster.region
report.resource_id = cluster.id
report.resource_arn = cluster.arn
report.resource_tags = cluster.tags
report.status = "PASS"
report.status_extended = (
f"Redshift Cluster {cluster.arn} has AllowVersionUpgrade enabled"

View File

@@ -10,6 +10,7 @@ class redshift_cluster_public_access(Check):
report.region = cluster.region
report.resource_id = cluster.id
report.resource_arn = cluster.arn
report.resource_tags = cluster.tags
report.status = "PASS"
report.status_extended = (
f"Redshift Cluster {cluster.arn} is not publicly accessible"

View File

@@ -1,4 +1,5 @@
import threading
from typing import Optional
from pydantic import BaseModel
@@ -45,6 +46,7 @@ class Redshift:
cluster_to_append = Cluster(
id=cluster["ClusterIdentifier"],
region=regional_client.region,
tags=cluster.get("Tags"),
)
if (
"PubliclyAccessible" in cluster
@@ -114,3 +116,4 @@ class Cluster(BaseModel):
logging_enabled: bool = None
bucket: str = None
cluster_snapshots: bool = None
tags: Optional[list] = []

View File

@@ -12,7 +12,7 @@ class route53_domains_privacy_protection_enabled(Check):
report = Check_Report_AWS(self.metadata())
report.resource_id = domain.name
report.region = domain.region
report.resource_tags = domain.tags
if domain.admin_privacy:
report.status = "PASS"
report.status_extended = (

View File

@@ -12,7 +12,7 @@ class route53_domains_transferlock_enabled(Check):
report = Check_Report_AWS(self.metadata())
report.resource_id = domain.name
report.region = domain.region
report.resource_tags = domain.tags
if domain.status_list and "clientTransferProhibited" in domain.status_list:
report.status = "PASS"
report.status_extended = (

View File

@@ -10,6 +10,7 @@ class route53_public_hosted_zones_cloudwatch_logging_enabled(Check):
if not hosted_zone.private_zone:
report = Check_Report_AWS(self.metadata())
report.resource_id = hosted_zone.id
report.resource_tags = hosted_zone.tags
report.region = hosted_zone.region
if (
hosted_zone.logging_config

View File

@@ -1,3 +1,5 @@
from typing import Optional
from pydantic import BaseModel
from prowler.lib.logger import logger
@@ -21,6 +23,7 @@ class Route53:
self.region = self.client.region
self.__list_hosted_zones__()
self.__list_query_logging_configs__()
self.__list_tags_for_resource__()
def __get_session__(self):
return self.session
@@ -74,6 +77,19 @@ class Route53:
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def __list_tags_for_resource__(self):
logger.info("Route53Domains - List Tags...")
for hosted_zone in self.hosted_zones.values():
try:
response = self.client.list_tags_for_resource(
ResourceType="hostedzone", ResourceId=hosted_zone.id
)["ResourceTagSet"]
hosted_zone.tags = response.get("Tags")
except Exception as error:
logger.error(
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
class LoggingConfig(BaseModel):
cloudwatch_log_group_arn: str
@@ -86,6 +102,7 @@ class HostedZone(BaseModel):
private_zone: bool
logging_config: LoggingConfig = None
region: str
tags: Optional[list] = []
################## Route53Domains
@@ -102,6 +119,7 @@ class Route53Domains:
self.client = self.session.client(self.service, self.region)
self.__list_domains__()
self.__get_domain_detail__()
self.__list_tags_for_domain__()
def __get_session__(self):
return self.session
@@ -136,9 +154,23 @@ class Route53Domains:
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def __list_tags_for_domain__(self):
logger.info("Route53Domains - List Tags...")
for domain in self.domains.values():
try:
response = self.client.list_tags_for_domain(
DomainName=domain.name,
)["TagList"]
domain.tags = response
except Exception as error:
logger.error(
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
class Domain(BaseModel):
name: str
region: str
admin_privacy: bool = False
status_list: list[str] = None
tags: Optional[list] = []

View File

@@ -54,6 +54,9 @@ def mock_make_api_call(self, operation_name, kwarg):
if operation_name == "GetVaultAccessPolicy":
return {"policy": {"Policy": json.dumps(vault_json_policy)}}
if operation_name == "ListTagsForVault":
return {"Tags": {"test": "test"}}
return make_api_call(self, operation_name, kwarg)
@@ -99,6 +102,7 @@ class Test_Glacier_Service:
== f"arn:aws:glacier:{AWS_REGION}:{DEFAULT_ACCOUNT_ID}:vaults/examplevault"
)
assert glacier.vaults[vault_name].region == AWS_REGION
assert glacier.vaults[vault_name].tags == [{"test": "test"}]
def test__get_vault_access_policy__(self):
# Set partition for the service

View File

@@ -8,6 +8,9 @@ AWS_REGION = "eu-west-1"
AWS_ACCOUNT_NUMBER = "123456789012"
detector_id = str(uuid4())
detector_arn = (
f"arn:aws:guardduty:{AWS_REGION}:{AWS_ACCOUNT_NUMBER}:detector/{detector_id}"
)
class Test_guardduty_is_enabled:
@@ -33,6 +36,7 @@ class Test_guardduty_is_enabled:
Detector(
id=detector_id,
region=AWS_REGION,
arn=detector_arn,
status=True,
)
)
@@ -50,7 +54,7 @@ class Test_guardduty_is_enabled:
assert result[0].status == "PASS"
assert search("enabled", result[0].status_extended)
assert result[0].resource_id == detector_id
assert result[0].resource_arn == ""
assert result[0].resource_arn == detector_arn
def test_guardduty_configured_but_suspended(self):
guardduty_client = mock.MagicMock
@@ -58,6 +62,7 @@ class Test_guardduty_is_enabled:
guardduty_client.detectors.append(
Detector(
id=detector_id,
arn=detector_arn,
region=AWS_REGION,
status=False,
)
@@ -76,7 +81,7 @@ class Test_guardduty_is_enabled:
assert result[0].status == "FAIL"
assert search("configured but suspended", result[0].status_extended)
assert result[0].resource_id == detector_id
assert result[0].resource_arn == ""
assert result[0].resource_arn == detector_arn
def test_guardduty_not_configured(self):
guardduty_client = mock.MagicMock
@@ -84,6 +89,7 @@ class Test_guardduty_is_enabled:
guardduty_client.detectors.append(
Detector(
id=detector_id,
arn=detector_arn,
region=AWS_REGION,
)
)
@@ -101,4 +107,4 @@ class Test_guardduty_is_enabled:
assert result[0].status == "FAIL"
assert search("not configured", result[0].status_extended)
assert result[0].resource_id == detector_id
assert result[0].resource_arn == ""
assert result[0].resource_arn == detector_arn

View File

@@ -32,6 +32,7 @@ class Test_guardduty_no_high_severity_findings:
guardduty_client.detectors.append(
Detector(
id=detector_id,
arn="",
region=AWS_REGION,
)
)
@@ -58,7 +59,11 @@ class Test_guardduty_no_high_severity_findings:
guardduty_client.detectors = []
guardduty_client.detectors.append(
Detector(
id=detector_id, region=AWS_REGION, status=False, findings=[str(uuid4())]
id=detector_id,
region=AWS_REGION,
arn="",
status=False,
findings=[str(uuid4())],
)
)
with mock.patch(

View File

@@ -16,6 +16,8 @@ make_api_call = botocore.client.BaseClient._make_api_call
def mock_make_api_call(self, operation_name, kwarg):
if operation_name == "ListFindings":
return {"FindingIds": ["86c1d16c9ec63f634ccd087ae0d427ba1"]}
if operation_name == "ListTagsForResource":
return {"Tags": {"test": "test"}}
return make_api_call(self, operation_name, kwarg)
@@ -77,7 +79,7 @@ class Test_GuardDuty_Service:
# Test GuardDuty session
def test__list_detectors__(self):
guardduty_client = client("guardduty", region_name=AWS_REGION)
response = guardduty_client.create_detector(Enable=True)
response = guardduty_client.create_detector(Enable=True, Tags={"test": "test"})
audit_info = self.set_mocked_audit_info()
guardduty = GuardDuty(audit_info)
@@ -85,6 +87,7 @@ class Test_GuardDuty_Service:
assert len(guardduty.detectors) == 1
assert guardduty.detectors[0].id == response["DetectorId"]
assert guardduty.detectors[0].region == AWS_REGION
assert guardduty.detectors[0].tags == [{"test": "test"}]
@mock_guardduty
# Test GuardDuty session

View File

@@ -247,10 +247,16 @@ class Test_IAM_Service:
service_role = iam_client.create_role(
RoleName="test-1",
AssumeRolePolicyDocument=dumps(service_policy_document),
Tags=[
{"Key": "test", "Value": "test"},
],
)["Role"]
role = iam_client.create_role(
RoleName="test-2",
AssumeRolePolicyDocument=dumps(policy_document),
Tags=[
{"Key": "test", "Value": "test"},
],
)["Role"]
# IAM client for this test class
@@ -258,6 +264,12 @@ class Test_IAM_Service:
iam = IAM(audit_info)
assert len(iam.roles) == len(iam_client.list_roles()["Roles"])
assert iam.roles[0].tags == [
{"Key": "test", "Value": "test"},
]
assert iam.roles[1].tags == [
{"Key": "test", "Value": "test"},
]
assert is_service_role(service_role)
assert not is_service_role(role)
@@ -287,15 +299,27 @@ class Test_IAM_Service:
# Create 2 IAM Users
iam_client.create_user(
UserName="user1",
Tags=[
{"Key": "test", "Value": "test"},
],
)
iam_client.create_user(
UserName="user2",
Tags=[
{"Key": "test", "Value": "test"},
],
)
# IAM client for this test class
audit_info = self.set_mocked_audit_info()
iam = IAM(audit_info)
assert len(iam.users) == len(iam_client.list_users()["Users"])
assert iam.users[0].tags == [
{"Key": "test", "Value": "test"},
]
assert iam.users[1].tags == [
{"Key": "test", "Value": "test"},
]
# Test IAM Get Account Summary
@mock_iam

View File

@@ -88,7 +88,11 @@ class Test_ACM_Service:
# Generate KMS Client
kms_client = client("kms", region_name=AWS_REGION)
# Create KMS keys
key1 = kms_client.create_key()["KeyMetadata"]
key1 = kms_client.create_key(
Tags=[
{"TagKey": "test", "TagValue": "test"},
],
)["KeyMetadata"]
# KMS client for this test class
audit_info = self.set_mocked_audit_info()
kms = KMS(audit_info)
@@ -97,6 +101,9 @@ class Test_ACM_Service:
assert kms.keys[0].state == key1["KeyState"]
assert kms.keys[0].origin == key1["Origin"]
assert kms.keys[0].manager == key1["KeyManager"]
assert kms.keys[0].tags == [
{"TagKey": "test", "TagValue": "test"},
]
# Test KMS Get rotation status
@mock_kms

View File

@@ -8,8 +8,8 @@ class Test_macie_is_enabled:
macie_client = mock.MagicMock
macie_client.sessions = [
Session(
"DISABLED",
"eu-west-1",
status="DISABLED",
region="eu-west-1",
)
]
with mock.patch(
@@ -33,8 +33,8 @@ class Test_macie_is_enabled:
macie_client = mock.MagicMock
macie_client.sessions = [
Session(
"ENABLED",
"eu-west-1",
status="ENABLED",
region="eu-west-1",
)
]
with mock.patch(
@@ -58,8 +58,8 @@ class Test_macie_is_enabled:
macie_client = mock.MagicMock
macie_client.sessions = [
Session(
"PAUSED",
"eu-west-1",
status="PAUSED",
region="eu-west-1",
)
]
with mock.patch(

View File

@@ -66,8 +66,8 @@ class Test_Macie_Service:
macie = Macie(current_audit_info)
macie.sessions = [
Session(
"ENABLED",
"eu-west-1",
status="ENABLED",
region="eu-west-1",
)
]
assert len(macie.sessions) == 1

View File

@@ -82,6 +82,12 @@ def mock_make_api_call(self, operation_name, kwarg):
"AdvancedSecurityOptions": {"InternalUserDatabaseEnabled": True},
}
}
if operation_name == "ListTags":
return {
"TagList": [
{"Key": "test", "Value": "test"},
]
}
return make_api_call(self, operation_name, kwarg)
@@ -183,3 +189,6 @@ class Test_OpenSearchService_Service:
assert opensearch.opensearch_domains[0].internal_user_database
assert opensearch.opensearch_domains[0].update_available
assert opensearch.opensearch_domains[0].version == "opensearch-version1"
assert opensearch.opensearch_domains[0].tags == [
{"Key": "test", "Value": "test"},
]

View File

@@ -82,6 +82,9 @@ class Test_RDS_Service:
BackupRetentionPeriod=10,
EnableCloudwatchLogsExports=["audit", "error"],
MultiAZ=True,
Tags=[
{"Key": "test", "Value": "test"},
],
)
# RDS client for this test class
audit_info = self.set_mocked_audit_info()
@@ -101,6 +104,9 @@ class Test_RDS_Service:
assert rds.db_instances[0].deletion_protection
assert rds.db_instances[0].auto_minor_version_upgrade
assert rds.db_instances[0].multi_az
assert rds.db_instances[0].tags == [
{"Key": "test", "Value": "test"},
]
# Test RDS Describe DB Snapshots
@mock_rds

View File

@@ -110,6 +110,9 @@ class Test_Redshift_Service:
MasterUsername="user",
MasterUserPassword="password",
PubliclyAccessible=True,
Tags=[
{"Key": "test", "Value": "test"},
],
)
audit_info = self.set_mocked_audit_info()
redshift = Redshift(audit_info)
@@ -126,6 +129,9 @@ class Test_Redshift_Service:
redshift.clusters[0].allow_version_upgrade
== response["Cluster"]["AllowVersionUpgrade"]
)
assert redshift.clusters[0].tags == [
{"Key": "test", "Value": "test"},
]
@mock_redshift
def test_describe_logging_status(self):

View File

@@ -18,7 +18,16 @@ def mock_make_api_call(self, operation_name, kwarg):
"""We have to mock every AWS API call using Boto3"""
if operation_name == "DescribeDirectories":
return {}
if operation_name == "ListTagsForResource":
return {
"ResourceTagSet": {
"ResourceType": "hostedzone",
"ResourceId": "test",
"Tags": [
{"Key": "test", "Value": "test"},
],
}
}
return make_api_call(self, operation_name, kwarg)
@@ -107,6 +116,9 @@ class Test_Route53_Service:
== log_group_arn
)
assert route53.hosted_zones[hosted_zone_id].region == AWS_REGION
assert route53.hosted_zones[hosted_zone_id].tags == [
{"Key": "test", "Value": "test"},
]
@mock_route53
@mock_logs

View File

@@ -28,6 +28,12 @@ def mock_make_api_call(self, operation_name, kwarg):
],
"NextPageMarker": "string",
}
if operation_name == "ListTagsForDomain":
return {
"TagList": [
{"Key": "test", "Value": "test"},
]
}
if operation_name == "GetDomainDetail":
return {
"DomainName": "test.domain.com",
@@ -117,3 +123,6 @@ class Test_Route53_Service:
"clientTransferProhibited"
in route53domains.domains[domain_name].status_list
)
assert route53domains.domains[domain_name].tags == [
{"Key": "test", "Value": "test"},
]