feat(tags): add resource tags in A services (#1997)

This commit is contained in:
Sergio Garcia
2023-03-02 10:59:49 +01:00
committed by GitHub
parent eabccba3fa
commit 032feb343f
40 changed files with 195 additions and 123 deletions

View File

@@ -187,6 +187,7 @@ def add_html_header(file_descriptor, audit_info):
<th scope="col">Region</th>
<th style="width:20%" scope="col">Check Title</th>
<th scope="col">Resource ID</th>
<th scope="col">Resource Tags</th>
<th style="width:15%" scope="col">Check Description</th>
<th scope="col">Check ID</th>
<th scope="col">Status Extended</th>
@@ -221,6 +222,7 @@ def fill_html(file_descriptor, finding):
<td>{finding.region}</td>
<td>{finding.check_metadata.CheckTitle}</td>
<td>{finding.resource_id.replace("<", "&lt;").replace(">", "&gt;").replace("_", "<wbr>_")}</td>
<td>{str(finding.resource_tags)}</td>
<td>{finding.check_metadata.Description}</td>
<td>{finding.check_metadata.CheckID.replace("_", "<wbr>_")}</td>
<td>{finding.status_extended.replace("<", "&lt;").replace(">", "&gt;").replace("_", "<wbr>_")}</td>

View File

@@ -162,7 +162,7 @@ class Check_Output_CSV(BaseModel):
severity: str
resource_type: str
resource_details: str
resource_tags: list
resource_tags: Optional[list]
description: str
risk: str
related_url: str
@@ -235,6 +235,7 @@ def generate_provider_output_json(provider: str, finding, audit_info, mode: str,
finding_output.Region = finding.region
finding_output.ResourceId = finding.resource_id
finding_output.ResourceArn = finding.resource_arn
finding_output.ResourceTags = finding.resource_tags
finding_output.FindingUniqueId = f"prowler-{provider}-{finding.check_metadata.CheckID}-{audit_info.audited_account}-{finding.region}-{finding.resource_id}"
if audit_info.organizations_metadata:
@@ -292,6 +293,7 @@ class Aws_Check_Output_JSON(Check_Output_JSON):
Region: str = ""
ResourceId: str = ""
ResourceArn: str = ""
ResourceTags: list = []
def __init__(self, **metadata):
super().__init__(**metadata)
@@ -299,7 +301,7 @@ class Aws_Check_Output_JSON(Check_Output_JSON):
class Azure_Check_Output_JSON(Check_Output_JSON):
"""
Aws_Check_Output_JSON generates a finding's output in JSON format for the AWS provider.
Azure_Check_Output_JSON generates a finding's output in JSON format for the AWS provider.
"""
Tenant_Domain: str = ""

View File

@@ -17,6 +17,7 @@ class accessanalyzer_enabled(Check):
)
report.resource_id = analyzer.name
report.resource_arn = analyzer.arn
report.resource_tags = analyzer.tags
elif analyzer.status == "NOT_AVAILABLE":
report.status = "FAIL"
@@ -31,6 +32,7 @@ class accessanalyzer_enabled(Check):
)
report.resource_id = analyzer.name
report.resource_arn = analyzer.arn
report.resource_tags = analyzer.tags
findings.append(report)
return findings

View File

@@ -17,6 +17,7 @@ class accessanalyzer_enabled_without_findings(Check):
)
report.resource_id = analyzer.name
report.resource_arn = analyzer.arn
report.resource_tags = analyzer.tags
if len(analyzer.findings) != 0:
active_finding_counter = 0
for finding in analyzer.findings:
@@ -28,6 +29,7 @@ class accessanalyzer_enabled_without_findings(Check):
report.status_extended = f"IAM Access Analyzer {analyzer.name} has {active_finding_counter} active findings"
report.resource_id = analyzer.name
report.resource_arn = analyzer.arn
report.resource_tags = analyzer.tags
elif analyzer.status == "NOT_AVAILABLE":
report.status = "FAIL"
report.status_extended = (
@@ -41,6 +43,7 @@ class accessanalyzer_enabled_without_findings(Check):
)
report.resource_id = analyzer.name
report.resource_arn = analyzer.arn
report.resource_tags = analyzer.tags
findings.append(report)
return findings

View File

@@ -1,4 +1,5 @@
import threading
from typing import Optional
from pydantic import BaseModel
@@ -48,7 +49,7 @@ class AccessAnalyzer:
arn=analyzer["arn"],
name=analyzer["name"],
status=analyzer["status"],
tags=str(analyzer["tags"]),
tags=[analyzer.get("tags")],
type=analyzer["type"],
region=regional_client.region,
)
@@ -60,7 +61,7 @@ class AccessAnalyzer:
arn="",
name=self.audited_account,
status="NOT_AVAILABLE",
tags="",
tags=[],
type="",
region=regional_client.region,
)
@@ -119,6 +120,6 @@ class Analyzer(BaseModel):
name: str
status: str
findings: list[Finding] = []
tags: str
tags: Optional[list] = []
type: str
region: str

View File

@@ -15,11 +15,13 @@ class acm_certificates_expiration_check(Check):
report.status_extended = f"ACM Certificate for {certificate.name} expires in {certificate.expiration_days} days."
report.resource_id = certificate.name
report.resource_arn = certificate.arn
report.resource_tags = certificate.tags
else:
report.status = "FAIL"
report.status_extended = f"ACM Certificate for {certificate.name} is about to expire in {DAYS_TO_EXPIRE_THRESHOLD} days."
report.resource_id = certificate.name
report.resource_arn = certificate.arn
report.resource_tags = certificate.tags
findings.append(report)
return findings

View File

@@ -15,16 +15,19 @@ class acm_certificates_transparency_logs_enabled(Check):
)
report.resource_id = certificate.name
report.resource_arn = certificate.arn
report.resource_tags = certificate.tags
else:
if not certificate.transparency_logging:
report.status = "FAIL"
report.status_extended = f"ACM Certificate for {certificate.name} has Certificate Transparency logging disabled."
report.resource_id = certificate.name
report.resource_arn = certificate.arn
report.resource_tags = certificate.tags
else:
report.status = "PASS"
report.status_extended = f"ACM Certificate for {certificate.name} has Certificate Transparency logging enabled."
report.resource_id = certificate.name
report.resource_arn = certificate.arn
report.resource_tags = certificate.tags
findings.append(report)
return findings

View File

@@ -20,6 +20,7 @@ class ACM:
self.certificates = []
self.__threading_call__(self.__list_certificates__)
self.__describe_certificates__()
self.__list_tags_for_certificate__()
def __get_session__(self):
return self.session
@@ -91,11 +92,26 @@ class ACM:
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def __list_tags_for_certificate__(self):
logger.info("ACM - List Tags...")
try:
for certificate in self.certificates:
regional_client = self.regional_clients[certificate.region]
response = regional_client.list_tags_for_certificate(
CertificateArn=certificate.arn
)["Tags"]
certificate.tags = response
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
class Certificate(BaseModel):
arn: str
name: str
type: str
tags: Optional[list] = []
expiration_days: int
transparency_logging: Optional[bool]
region: str

View File

@@ -12,6 +12,7 @@ class apigateway_authorizers_enabled(Check):
report.region = rest_api.region
report.resource_id = rest_api.name
report.resource_arn = rest_api.arn
report.resource_tags = rest_api.tags
if rest_api.authorizer:
report.status = "PASS"
report.status_extended = f"API Gateway {rest_api.name} ID {rest_api.id} has authorizer configured."

View File

@@ -13,6 +13,7 @@ class apigateway_client_certificate_enabled(Check):
report.resource_id = rest_api.name
report.region = rest_api.region
report.resource_arn = stage.arn
report.resource_tags = stage.tags
if stage.client_certificate:
report.status = "PASS"
report.status_extended = f"API Gateway {rest_api.name} ID {rest_api.id} in stage {stage.name} has client certificate enabled."

View File

@@ -12,6 +12,7 @@ class apigateway_endpoint_public(Check):
report.region = rest_api.region
report.resource_id = rest_api.name
report.resource_arn = rest_api.arn
report.resource_tags = rest_api.tags
if rest_api.public_endpoint:
report.status = "FAIL"
report.status_extended = f"API Gateway {rest_api.name} ID {rest_api.id} is internet accesible."

View File

@@ -13,6 +13,7 @@ class apigateway_logging_enabled(Check):
report.region = rest_api.region
report.resource_id = rest_api.name
report.resource_arn = stage.arn
report.resource_tags = stage.tags
if stage.logging:
report.status = "PASS"
report.status_extended = f"API Gateway {rest_api.name} ID {rest_api.id} in stage {stage.name} has logging enabled."

View File

@@ -1,5 +1,7 @@
import threading
from dataclasses import dataclass
from typing import Optional
from pydantic import BaseModel
from prowler.lib.logger import logger
from prowler.lib.scan_filters.scan_filters import is_resource_filtered
@@ -45,10 +47,11 @@ class APIGateway:
):
self.rest_apis.append(
RestAPI(
apigw["id"],
arn,
regional_client.region,
apigw["name"],
id=apigw["id"],
arn=arn,
region=regional_client.region,
name=apigw["name"],
tags=[apigw.get("tags")],
)
)
except Exception as error:
@@ -100,61 +103,33 @@ class APIGateway:
arn = f"arn:{self.audited_partition}:apigateway:{regional_client.region}::/apis/{rest_api.id}/stages/{stage['stageName']}"
rest_api.stages.append(
Stage(
stage["stageName"],
arn,
logging,
client_certificate,
waf,
name=stage["stageName"],
arn=arn,
logging=logging,
client_certificate=client_certificate,
waf=waf,
tags=[stage.get("tags")],
)
)
except Exception as error:
logger.error(f"{error.__class__.__name__}: {error}")
@dataclass
class Stage:
class Stage(BaseModel):
name: str
arn: str
logging: bool
client_certificate: bool
waf: str
def __init__(
self,
name,
arn,
logging,
client_certificate,
waf,
):
self.name = name
self.arn = arn
self.logging = logging
self.client_certificate = client_certificate
self.waf = waf
waf: Optional[str]
tags: Optional[list] = []
@dataclass
class RestAPI:
class RestAPI(BaseModel):
id: str
arn: str
region: str
name: str
authorizer: bool
public_endpoint: bool
stages: list[Stage]
def __init__(
self,
id,
arn,
region,
name,
):
self.id = id
self.arn = arn
self.region = region
self.name = name
self.authorizer = False
self.public_endpoint = True
self.stages = []
authorizer: bool = False
public_endpoint: bool = True
stages: list[Stage] = []
tags: Optional[list] = []

View File

@@ -13,6 +13,7 @@ class apigateway_waf_acl_attached(Check):
report.region = rest_api.region
report.resource_id = rest_api.name
report.resource_arn = stage.arn
report.resource_tags = stage.tags
if stage.waf:
report.status = "PASS"
report.status_extended = f"API Gateway {rest_api.name} ID {rest_api.id} in stage {stage.name} has {stage.waf} WAF ACL attached."

View File

@@ -15,10 +15,12 @@ class apigatewayv2_access_logging_enabled(Check):
report.status = "PASS"
report.status_extended = f"API Gateway V2 {api.name} ID {api.id} in stage {stage.name} has access logging enabled."
report.resource_id = api.name
report.resource_tags = api.tags
else:
report.status = "FAIL"
report.status_extended = f"API Gateway V2 {api.name} ID {api.id} in stage {stage.name} has access logging disabled."
report.resource_id = api.name
report.resource_tags = api.tags
findings.append(report)
return findings

View File

@@ -16,10 +16,12 @@ class apigatewayv2_authorizers_enabled(Check):
f"API Gateway V2 {api.name} ID {api.id} has authorizer configured."
)
report.resource_id = api.name
report.resource_tags = api.tags
else:
report.status = "FAIL"
report.status_extended = f"API Gateway V2 {api.name} ID {api.id} has not authorizer configured."
report.resource_id = api.name
report.resource_tags = api.tags
findings.append(report)
return findings

View File

@@ -1,5 +1,7 @@
import threading
from dataclasses import dataclass
from typing import Optional
from pydantic import BaseModel
from prowler.lib.logger import logger
from prowler.lib.scan_filters.scan_filters import is_resource_filtered
@@ -42,9 +44,10 @@ class ApiGatewayV2:
):
self.apis.append(
API(
apigw["ApiId"],
regional_client.region,
apigw["Name"],
id=apigw["ApiId"],
region=regional_client.region,
name=apigw["Name"],
tags=[apigw.get("Tags")],
)
)
except Exception as error:
@@ -77,8 +80,9 @@ class ApiGatewayV2:
logging = True
api.stages.append(
Stage(
stage["StageName"],
logging,
name=stage["StageName"],
logging=logging,
tags=[stage.get("Tags")],
)
)
except Exception as error:
@@ -87,36 +91,16 @@ class ApiGatewayV2:
)
@dataclass
class Stage:
class Stage(BaseModel):
name: str
logging: bool
def __init__(
self,
name,
logging,
):
self.name = name
self.logging = logging
tags: Optional[list] = []
@dataclass
class API:
class API(BaseModel):
id: str
region: str
name: str
authorizer: bool
stages: list[Stage]
def __init__(
self,
id,
region,
name,
):
self.id = id
self.region = region
self.name = name
self.authorizer = False
self.stages = []
authorizer: bool = False
stages: list[Stage] = []
tags: Optional[list] = []

View File

@@ -14,6 +14,7 @@ class appstream_fleet_default_internet_access_disabled(Check):
report.region = fleet.region
report.resource_id = fleet.name
report.resource_arn = fleet.arn
report.resource_tags = fleet.tags
if fleet.enable_default_internet_access:
report.status = "FAIL"

View File

@@ -17,6 +17,7 @@ class appstream_fleet_maximum_session_duration(Check):
report.region = fleet.region
report.resource_id = fleet.name
report.resource_arn = fleet.arn
report.resource_tags = fleet.tags
if fleet.max_user_duration_in_seconds < max_session_duration_seconds:
report.status = "PASS"

View File

@@ -17,6 +17,7 @@ class appstream_fleet_session_disconnect_timeout(Check):
report.region = fleet.region
report.resource_id = fleet.name
report.resource_arn = fleet.arn
report.resource_tags = fleet.tags
if fleet.disconnect_timeout_in_seconds <= max_disconnect_timeout_in_seconds:
report.status = "PASS"

View File

@@ -19,6 +19,7 @@ class appstream_fleet_session_idle_disconnect_timeout(Check):
report.region = fleet.region
report.resource_id = fleet.name
report.resource_arn = fleet.arn
report.resource_tags = fleet.tags
if (
fleet.idle_disconnect_timeout_in_seconds

View File

@@ -18,6 +18,7 @@ class AppStream:
self.regional_clients = generate_regional_clients(self.service, audit_info)
self.fleets = []
self.__threading_call__(self.__describe_fleets__)
self.__list_tags_for_resource__()
def __get_session__(self):
return self.session
@@ -65,6 +66,20 @@ class AppStream:
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def __list_tags_for_resource__(self):
logger.info("AppStream - List Tags...")
try:
for fleet in self.fleets:
regional_client = self.regional_clients[fleet.region]
response = regional_client.list_tags_for_resource(
ResourceArn=fleet.arn
)["Tags"]
fleet.tags = [response]
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
class Fleet(BaseModel):
arn: str
@@ -74,3 +89,4 @@ class Fleet(BaseModel):
idle_disconnect_timeout_in_seconds: Optional[int]
enable_default_internet_access: bool
region: str
tags: Optional[list] = []

View File

@@ -1,5 +1,6 @@
import threading
from dataclasses import dataclass
from pydantic import BaseModel
from prowler.lib.logger import logger
from prowler.lib.scan_filters.scan_filters import is_resource_filtered
@@ -45,11 +46,11 @@ class AutoScaling:
):
self.launch_configurations.append(
LaunchConfiguration(
configuration["LaunchConfigurationARN"],
configuration["LaunchConfigurationName"],
configuration["UserData"],
configuration["ImageId"],
regional_client.region,
arn=configuration["LaunchConfigurationARN"],
name=configuration["LaunchConfigurationName"],
user_data=configuration["UserData"],
image_id=configuration["ImageId"],
region=regional_client.region,
)
)
@@ -59,24 +60,9 @@ class AutoScaling:
)
@dataclass
class LaunchConfiguration:
class LaunchConfiguration(BaseModel):
arn: str
name: str
user_data: str
image_id: int
image_id: str
region: str
def __init__(
self,
arn,
name,
user_data,
image_id,
region,
):
self.arn = arn
self.name = name
self.image_id = image_id
self.user_data = user_data
self.region = region

View File

@@ -13,6 +13,7 @@ class awslambda_function_invoke_api_operations_cloudtrail_logging_enabled(Check)
report.region = function.region
report.resource_id = function.name
report.resource_arn = function.arn
report.resource_tags = function.tags
report.status = "FAIL"
report.status_extended = (

View File

@@ -17,6 +17,7 @@ class awslambda_function_no_secrets_in_code(Check):
report.region = function.region
report.resource_id = function.name
report.resource_arn = function.arn
report.resource_tags = function.tags
report.status = "PASS"
report.status_extended = (

View File

@@ -17,6 +17,7 @@ class awslambda_function_no_secrets_in_variables(Check):
report.region = function.region
report.resource_id = function.name
report.resource_arn = function.arn
report.resource_tags = function.tags
report.status = "PASS"
report.status_extended = (

View File

@@ -10,6 +10,7 @@ class awslambda_function_not_publicly_accessible(Check):
report.region = function.region
report.resource_id = function.name
report.resource_arn = function.arn
report.resource_tags = function.tags
report.status = "PASS"
report.status_extended = f"Lambda function {function.name} has a policy resource-based policy not public"

View File

@@ -10,6 +10,7 @@ class awslambda_function_url_cors_policy(Check):
report.region = function.region
report.resource_id = function.name
report.resource_arn = function.arn
report.resource_tags = function.tags
if function.url_config:
if "*" in function.url_config.cors_config.allow_origins:
report.status = "FAIL"

View File

@@ -11,6 +11,7 @@ class awslambda_function_url_public(Check):
report.region = function.region
report.resource_id = function.name
report.resource_arn = function.arn
report.resource_tags = function.tags
if function.url_config:
if function.url_config.auth_type == AuthType.AWS_IAM:
report.status = "PASS"

View File

@@ -12,6 +12,7 @@ class awslambda_function_using_supported_runtimes(Check):
report.region = function.region
report.resource_id = function.name
report.resource_arn = function.arn
report.resource_tags = function.tags
if function.runtime in get_config_var("obsolete_lambda_runtimes"):
report.status = "FAIL"

View File

@@ -24,6 +24,7 @@ class Lambda:
self.regional_clients = generate_regional_clients(self.service, audit_info)
self.functions = {}
self.__threading_call__(self.__list_functions__)
self.__list_tags_for_resource__()
# We only want to retrieve the Lambda code if the
# awslambda_function_no_secrets_in_code check is set
@@ -156,6 +157,18 @@ class Lambda:
f" {error}"
)
def __list_tags_for_resource__(self):
logger.info("Lambda - List Tags...")
try:
for function in self.functions.values():
regional_client = self.regional_clients[function.region]
response = regional_client.list_tags(Resource=function.arn)["Tags"]
function.tags = [response]
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
class LambdaCode(BaseModel):
location: str
@@ -186,3 +199,4 @@ class Function(BaseModel):
policy: dict = None
code: LambdaCode = None
url_config: URLConfig = None
tags: Optional[list] = []

View File

@@ -31,7 +31,7 @@ class Test_accessanalyzer_enabled:
arn="",
name="012345678910",
status="NOT_AVAILABLE",
tags="",
tags=[],
type="",
region="eu-west-1",
)
@@ -62,7 +62,7 @@ class Test_accessanalyzer_enabled:
arn="",
name="012345678910",
status="NOT_AVAILABLE",
tags="",
tags=[],
type="",
region="eu-west-1",
),
@@ -70,7 +70,7 @@ class Test_accessanalyzer_enabled:
arn="",
name="Test Analyzer",
status="ACTIVE",
tags="",
tags=[],
type="",
region="eu-west-2",
),
@@ -112,7 +112,7 @@ class Test_accessanalyzer_enabled:
arn="",
name="Test Analyzer",
status="ACTIVE",
tags="",
tags=[],
type="",
region="eu-west-2",
)

View File

@@ -32,7 +32,7 @@ class Test_accessanalyzer_enabled_without_findings:
arn="",
name="012345678910",
status="NOT_AVAILABLE",
tags="",
tags=[],
type="",
region="eu-west-1",
)
@@ -63,7 +63,7 @@ class Test_accessanalyzer_enabled_without_findings:
arn="",
name="012345678910",
status="NOT_AVAILABLE",
tags="",
tags=[],
type="",
region="eu-west-1",
),
@@ -81,7 +81,7 @@ class Test_accessanalyzer_enabled_without_findings:
status="ARCHIVED",
),
],
tags="",
tags=[],
type="",
region="eu-west-2",
),
@@ -123,7 +123,7 @@ class Test_accessanalyzer_enabled_without_findings:
arn="",
name="Test Analyzer",
status="ACTIVE",
tags="",
tags=[],
type="",
region="eu-west-2",
)
@@ -157,7 +157,7 @@ class Test_accessanalyzer_enabled_without_findings:
arn="",
name="012345678910",
status="NOT_AVAILABLE",
tags="",
tags=[],
type="",
region="eu-west-1",
),

View File

@@ -30,7 +30,7 @@ def mock_make_api_call(self, operation_name, kwarg):
"name": "Test Analyzer",
"status": "ACTIVE",
"findings": 0,
"tags": "",
"tags": {"test": "test"},
"type": "ACCOUNT",
"region": "eu-west-1",
}
@@ -92,7 +92,7 @@ class Test_AccessAnalyzer_Service:
assert access_analyzer.analyzers[0].arn == "ARN"
assert access_analyzer.analyzers[0].name == "Test Analyzer"
assert access_analyzer.analyzers[0].status == "ACTIVE"
assert access_analyzer.analyzers[0].tags == ""
assert access_analyzer.analyzers[0].tags == [{"test": "test"}]
assert access_analyzer.analyzers[0].type == "ACCOUNT"
assert access_analyzer.analyzers[0].region == AWS_REGION

View File

@@ -67,6 +67,14 @@ def mock_make_api_call(self, operation_name, kwargs):
"Options": {"CertificateTransparencyLoggingPreference": "DISABLED"},
}
}
if operation_name == "ListTagsForCertificate":
if kwargs["CertificateArn"] == certificate_arn:
return {
"Tags": [
{"Key": "test", "Value": "test"},
]
}
return make_api_call(self, operation_name, kwargs)
@@ -163,3 +171,21 @@ class Test_ACM_Service:
assert acm.certificates[0].expiration_days == 365
assert acm.certificates[0].transparency_logging is False
assert acm.certificates[0].region == AWS_REGION
# Test ACM List Tags
# @mock_acm
def test__list_tags_for_certificate__(self):
# Generate ACM Client
# acm_client = client("acm", region_name=AWS_REGION)
# Request ACM certificate
# certificate = acm_client.request_certificate(
# DomainName="test.com",
# )
# ACM client for this test class
audit_info = self.set_mocked_audit_info()
acm = ACM(audit_info)
assert len(acm.certificates) == 1
assert acm.certificates[0].tags == [
{"Key": "test", "Value": "test"},
]

View File

@@ -106,7 +106,6 @@ class Test_apigateway_client_certificate_enabled:
@mock_apigateway
def test_apigateway_one_stage_with_certificate(self):
# Create APIGateway Mocked Resources
apigateway_client = client("apigateway", region_name=AWS_REGION)
# Create APIGateway Deployment Stage
@@ -131,8 +130,8 @@ class Test_apigateway_client_certificate_enabled:
service_client.rest_apis[0].stages.append(
Stage(
"test",
f"arn:{current_audit_info.audited_partition}:apigateway:{AWS_REGION}::/apis/test-rest-api/stages/test",
name="test",
arn=f"arn:{current_audit_info.audited_partition}:apigateway:{AWS_REGION}::/apis/test-rest-api/stages/test",
logging=True,
client_certificate=True,
waf=True,

View File

@@ -108,12 +108,15 @@ class Test_APIGateway_Service:
apigateway_client = client("apigateway", region_name=AWS_REGION)
# Create private APIGateway Rest API
apigateway_client.create_rest_api(
name="test-rest-api", endpointConfiguration={"types": ["PRIVATE"]}
name="test-rest-api",
endpointConfiguration={"types": ["PRIVATE"]},
tags={"test": "test"},
)
# APIGateway client for this test class
audit_info = self.set_mocked_audit_info()
apigateway = APIGateway(audit_info)
assert apigateway.rest_apis[0].public_endpoint is False
assert apigateway.rest_apis[0].tags == [{"test": "test"}]
# Test APIGateway Get Stages
@mock_apigateway

View File

@@ -102,11 +102,14 @@ class Test_ApiGatewayV2_Service:
# Generate ApiGatewayV2 Client
apigatewayv2_client = client("apigatewayv2", region_name=AWS_REGION)
# Create ApiGatewayV2 API
apigatewayv2_client.create_api(Name="test-api", ProtocolType="HTTP")
apigatewayv2_client.create_api(
Name="test-api", ProtocolType="HTTP", Tags={"test": "test"}
)
# ApiGatewayV2 client for this test class
audit_info = self.set_mocked_audit_info()
apigatewayv2 = ApiGatewayV2(audit_info)
assert len(apigatewayv2.apis) == len(apigatewayv2_client.get_apis()["Items"])
assert apigatewayv2.apis[0].tags == [{"test": "test"}]
# Test ApiGatewayV2 Get Authorizers
@mock_apigatewayv2

View File

@@ -43,6 +43,8 @@ def mock_make_api_call(self, operation_name, kwarg):
},
]
}
if operation_name == "ListTagsForResource":
return {"Tags": {"test": "test"}}
return make_api_call(self, operation_name, kwarg)
@@ -102,3 +104,13 @@ class Test_AppStream_Service:
assert appstream.fleets[1].idle_disconnect_timeout_in_seconds == 900
assert appstream.fleets[1].enable_default_internet_access is True
assert appstream.fleets[1].region == AWS_REGION
def test__list_tags_for_resource__(self):
# Set partition for the service
current_audit_info.audited_partition = "aws"
appstream = AppStream(current_audit_info)
assert len(appstream.fleets) == 2
assert appstream.fleets[0].tags == [{"test": "test"}]
assert appstream.fleets[1].tags == [{"test": "test"}]

View File

@@ -137,6 +137,7 @@ class Test_Lambda_Service:
"SubnetIds": ["subnet-123abc"],
},
Environment={"Variables": {"db-password": "test-password"}},
Tags={"test": "test"},
)
# Update Lambda Policy
lambda_policy = {
@@ -218,6 +219,8 @@ class Test_Lambda_Service:
lambda_name
].url_config.cors_config.allow_origins == ["*"]
assert awslambda.functions[lambda_name].tags == [{"test": "test"}]
# Pending ZipFile tests
with tempfile.TemporaryDirectory() as tmp_dir_name:
awslambda.functions[lambda_name].code.code_zip.extractall(tmp_dir_name)