fix(iam): refactor IAM service (#2010)

This commit is contained in:
Nacho Rivera
2023-03-02 11:16:05 +01:00
committed by GitHub
parent 032feb343f
commit eed7ab9793
4 changed files with 127 additions and 181 deletions

View File

@@ -5,14 +5,15 @@ from prowler.providers.aws.services.iam.iam_client import iam_client
class iam_check_saml_providers_sts(Check):
def execute(self) -> Check_Report_AWS:
findings = []
for provider in iam_client.saml_providers:
report = Check_Report_AWS(self.metadata())
provider_name = provider["Arn"].split("/")[1]
report.resource_id = provider_name
report.resource_arn = provider["Arn"]
report.region = iam_client.region
report.status = "PASS"
report.status_extended = f"SAML Provider {provider_name} has been found"
findings.append(report)
if iam_client.saml_providers:
for provider in iam_client.saml_providers:
report = Check_Report_AWS(self.metadata())
provider_name = provider["Arn"].split("/")[1]
report.resource_id = provider_name
report.resource_arn = provider["Arn"]
report.region = iam_client.region
report.status = "PASS"
report.status_extended = f"SAML Provider {provider_name} has been found"
findings.append(report)
return findings

View File

@@ -7,28 +7,29 @@ class iam_root_hardware_mfa_enabled(Check):
findings = []
# This check is only avaible in Commercial Partition
if iam_client.partition == "aws":
virtual_mfa = False
report = Check_Report_AWS(self.metadata())
report.region = iam_client.region
report.resource_id = "root"
report.resource_arn = f"arn:aws:iam::{iam_client.account}:root"
if iam_client.account_summary:
virtual_mfa = False
report = Check_Report_AWS(self.metadata())
report.region = iam_client.region
report.resource_id = "root"
report.resource_arn = f"arn:aws:iam::{iam_client.account}:root"
if iam_client.account_summary["SummaryMap"]["AccountMFAEnabled"] > 0:
virtual_mfas = iam_client.virtual_mfa_devices
for mfa in virtual_mfas:
if "root" in mfa["SerialNumber"]:
virtual_mfa = True
report.status = "FAIL"
report.status_extended = "Root account has a virtual MFA instead of a hardware MFA device enabled."
if not virtual_mfa:
report.status = "PASS"
report.status_extended = (
"Root account has a hardware MFA device enabled."
)
else:
report.status = "FAIL"
report.status_extended = "MFA is not enabled for root account."
if iam_client.account_summary["SummaryMap"]["AccountMFAEnabled"] > 0:
virtual_mfas = iam_client.virtual_mfa_devices
for mfa in virtual_mfas:
if "root" in mfa["SerialNumber"]:
virtual_mfa = True
report.status = "FAIL"
report.status_extended = "Root account has a virtual MFA instead of a hardware MFA device enabled."
if not virtual_mfa:
report.status = "PASS"
report.status_extended = (
"Root account has a hardware MFA device enabled."
)
else:
report.status = "FAIL"
report.status_extended = "MFA is not enabled for root account."
findings.append(report)
findings.append(report)
return findings

View File

@@ -5,19 +5,19 @@ from prowler.providers.aws.services.iam.iam_client import iam_client
class iam_root_mfa_enabled(Check):
def execute(self) -> Check_Report_AWS:
findings = []
for user in iam_client.credential_report:
if user["user"] == "<root_account>":
report = Check_Report_AWS(self.metadata())
report.region = iam_client.region
report.resource_id = user["user"]
report.resource_arn = user["arn"]
if user["mfa_active"] == "false":
report.status = "FAIL"
report.status_extended = "MFA is not enabled for root account."
else:
report.status = "PASS"
report.status_extended = "MFA is enabled for root account."
findings.append(report)
if iam_client.credential_report:
for user in iam_client.credential_report:
if user["user"] == "<root_account>":
report = Check_Report_AWS(self.metadata())
report.region = iam_client.region
report.resource_id = user["user"]
report.resource_arn = user["arn"]
if user["mfa_active"] == "false":
report.status = "FAIL"
report.status_extended = "MFA is not enabled for root account."
else:
report.status = "PASS"
report.status_extended = "MFA is enabled for root account."
findings.append(report)
return findings

View File

@@ -1,6 +1,8 @@
import csv
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from pydantic import BaseModel
from prowler.lib.logger import logger
from prowler.lib.scan_filters.scan_filters import is_resource_filtered
@@ -66,8 +68,8 @@ class IAM:
def __get_roles__(self):
try:
get_roles_paginator = self.client.get_paginator("list_roles")
roles = []
get_roles_paginator = self.client.get_paginator("list_roles")
for page in get_roles_paginator.paginate():
for role in page["Roles"]:
if not self.audit_resources or (
@@ -81,14 +83,16 @@ class IAM:
is_service_role=is_service_role(role),
)
)
return roles
except Exception as error:
logger.error(
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
finally:
return roles
def __get_credential_report__(self):
report_is_completed = False
credential_list = []
try:
while not report_is_completed:
report_status = self.client.generate_credential_report()
@@ -99,29 +103,29 @@ class IAM:
credential_lines = credential.split("\n")
csv_reader = csv.DictReader(credential_lines, delimiter=",")
credential_list = list(csv_reader)
return credential_list
except Exception as error:
logger.error(
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return []
finally:
return credential_list
def __get_groups__(self):
try:
get_groups_paginator = self.client.get_paginator("list_groups")
except Exception as error:
logger.error(
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
else:
groups = []
get_groups_paginator = self.client.get_paginator("list_groups")
for page in get_groups_paginator.paginate():
for group in page["Groups"]:
if not self.audit_resources or (
is_resource_filtered(group["Arn"], self.audit_resources)
):
groups.append(Group(group["GroupName"], group["Arn"]))
groups.append(Group(name=group["GroupName"], arn=group["Arn"]))
except Exception as error:
logger.error(
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
finally:
return groups
def __get_account_summary__(self):
@@ -131,12 +135,13 @@ class IAM:
logger.error(
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
else:
account_summary = None
finally:
return account_summary
def __get_password_policy__(self):
try:
stored_password_policy = None
password_policy = self.client.get_account_password_policy()[
"PasswordPolicy"
]
@@ -150,36 +155,33 @@ class IAM:
reuse_prevention = password_policy["PasswordReusePrevention"]
if "HardExpiry" in password_policy:
hard_expiry = password_policy["HardExpiry"]
stored_password_policy = PasswordPolicy(
length=password_policy["MinimumPasswordLength"],
symbols=password_policy["RequireSymbols"],
numbers=password_policy["RequireNumbers"],
uppercase=password_policy["RequireUppercaseCharacters"],
lowercase=password_policy["RequireLowercaseCharacters"],
allow_change=password_policy["AllowUsersToChangePassword"],
expiration=password_policy["RequireNumbers"],
max_age=max_age,
reuse_prevention=reuse_prevention,
hard_expiry=hard_expiry,
)
except Exception as error:
if "NoSuchEntity" in str(error):
# Password policy does not exist
password_policy = None
stored_password_policy = None
else:
logger.error(
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
else:
return PasswordPolicy(
password_policy["MinimumPasswordLength"],
password_policy["RequireSymbols"],
password_policy["RequireNumbers"],
password_policy["RequireUppercaseCharacters"],
password_policy["RequireLowercaseCharacters"],
password_policy["AllowUsersToChangePassword"],
password_policy["ExpirePasswords"],
max_age,
reuse_prevention,
hard_expiry,
)
finally:
return stored_password_policy
def __get_users__(self):
try:
get_users_paginator = self.client.get_paginator("list_users")
except Exception as error:
logger.error(
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
else:
users = []
for page in get_users_paginator.paginate():
for user in page["Users"]:
@@ -187,33 +189,37 @@ class IAM:
is_resource_filtered(user["Arn"], self.audit_resources)
):
if "PasswordLastUsed" not in user:
users.append(User(user["UserName"], user["Arn"], None))
users.append(User(name=user["UserName"], arn=user["Arn"]))
else:
users.append(
User(
user["UserName"],
user["Arn"],
user["PasswordLastUsed"],
name=user["UserName"],
arn=user["Arn"],
password_last_used=user["PasswordLastUsed"],
)
)
return users
def __list_virtual_mfa_devices__(self):
try:
list_virtual_mfa_devices_paginator = self.client.get_paginator(
"list_virtual_mfa_devices"
)
except Exception as error:
logger.error(
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
else:
finally:
return users
def __list_virtual_mfa_devices__(self):
try:
mfa_devices = []
list_virtual_mfa_devices_paginator = self.client.get_paginator(
"list_virtual_mfa_devices"
)
for page in list_virtual_mfa_devices_paginator.paginate():
for mfa_device in page["VirtualMFADevices"]:
mfa_devices.append(mfa_device)
except Exception as error:
logger.error(
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
finally:
return mfa_devices
def __list_attached_group_policies__(self):
@@ -244,14 +250,14 @@ class IAM:
for user in page["Users"]:
if "PasswordLastUsed" not in user:
group_users.append(
User(user["UserName"], user["Arn"], None)
User(name=user["UserName"], arn=user["Arn"])
)
else:
group_users.append(
User(
user["UserName"],
user["Arn"],
user["PasswordLastUsed"],
name=user["UserName"],
arn=user["Arn"],
password_last_used=user["PasswordLastUsed"],
)
)
group.users = group_users
@@ -273,7 +279,9 @@ class IAM:
mfa_type = (
mfa_device["SerialNumber"].split(":")[5].split("/")[0]
)
mfa_devices.append(MFADevice(mfa_serial_number, mfa_type))
mfa_devices.append(
MFADevice(serial_number=mfa_serial_number, type=mfa_type)
)
user.mfa_devices = mfa_devices
except Exception as error:
logger.error(
@@ -333,7 +341,6 @@ class IAM:
logger.error(
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
finally:
return support_roles
@@ -351,13 +358,11 @@ class IAM:
logger.error(
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
else:
finally:
return policies
def __list_policies_version__(self, policies):
try:
pass
for policy in policies:
policy_version = self.client.get_policy_version(
PolicyArn=policy["Arn"], VersionId=policy["DefaultVersionId"]
@@ -375,7 +380,7 @@ class IAM:
logger.error(
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
saml_providers = None
finally:
return saml_providers
@@ -390,79 +395,49 @@ class IAM:
):
server_certificates.append(
Certificate(
certificate["ServerCertificateName"],
certificate["ServerCertificateId"],
certificate["Arn"],
certificate["Expiration"],
name=certificate["ServerCertificateName"],
id=certificate["ServerCertificateId"],
arn=certificate["Arn"],
expiration=certificate["Expiration"],
)
)
except Exception as error:
logger.error(
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
finally:
return server_certificates
@dataclass
class MFADevice:
class MFADevice(BaseModel):
serial_number: str
type: str
def __init__(self, serial_number, type):
self.serial_number = serial_number
self.type = type
@dataclass
class User:
class User(BaseModel):
name: str
arn: str
mfa_devices: list[MFADevice]
password_last_used: str
attached_policies: list[dict]
inline_policies: list[str]
def __init__(self, name, arn, password_last_used):
self.name = name
self.arn = arn
self.password_last_used = password_last_used
self.mfa_devices = []
self.attached_policies = []
self.inline_policies = []
mfa_devices: list[MFADevice] = []
password_last_used: Optional[datetime]
attached_policies: list[dict] = []
inline_policies: list[str] = []
@dataclass
class Role:
class Role(BaseModel):
name: str
arn: str
assume_role_policy: dict
is_service_role: bool
def __init__(self, name, arn, assume_role_policy, is_service_role):
self.name = name
self.arn = arn
self.assume_role_policy = assume_role_policy
self.is_service_role = is_service_role
@dataclass
class Group:
class Group(BaseModel):
name: str
arn: str
attached_policies: list[dict]
users: list[User]
def __init__(self, name, arn):
self.name = name
self.arn = arn
self.attached_policies = []
self.users = []
attached_policies: list[dict] = []
users: list[User] = []
@dataclass
class PasswordPolicy:
class PasswordPolicy(BaseModel):
length: int
symbols: bool
numbers: bool
@@ -470,44 +445,13 @@ class PasswordPolicy:
lowercase: bool
allow_change: bool
expiration: bool
max_age: int
reuse_prevention: int
hard_expiry: bool
def __init__(
self,
length,
symbols,
numbers,
uppercase,
lowercase,
allow_change,
expiration,
max_age,
reuse_prevention,
hard_expiry,
):
self.length = length
self.symbols = symbols
self.numbers = numbers
self.uppercase = uppercase
self.lowercase = lowercase
self.allow_change = allow_change
self.expiration = expiration
self.max_age = max_age
self.reuse_prevention = reuse_prevention
self.hard_expiry = hard_expiry
max_age: Optional[int]
reuse_prevention: Optional[int]
hard_expiry: Optional[bool]
@dataclass
class Certificate:
class Certificate(BaseModel):
name: str
id: str
arn: str
expiration: datetime
def __init__(self, name, id, arn, expiration):
self.name = name
self.id = id
self.arn = arn
self.expiration = expiration