From eed7ab979370fb3cfdca89c838946fe929f907ed Mon Sep 17 00:00:00 2001 From: Nacho Rivera Date: Thu, 2 Mar 2023 11:16:05 +0100 Subject: [PATCH] fix(iam): refactor IAM service (#2010) --- .../iam_check_saml_providers_sts.py | 19 +- .../iam_root_hardware_mfa_enabled.py | 43 ++-- .../iam_root_mfa_enabled.py | 28 +-- .../providers/aws/services/iam/iam_service.py | 218 +++++++----------- 4 files changed, 127 insertions(+), 181 deletions(-) diff --git a/prowler/providers/aws/services/iam/iam_check_saml_providers_sts/iam_check_saml_providers_sts.py b/prowler/providers/aws/services/iam/iam_check_saml_providers_sts/iam_check_saml_providers_sts.py index d71705e9..28fe4754 100644 --- a/prowler/providers/aws/services/iam/iam_check_saml_providers_sts/iam_check_saml_providers_sts.py +++ b/prowler/providers/aws/services/iam/iam_check_saml_providers_sts/iam_check_saml_providers_sts.py @@ -5,14 +5,15 @@ from prowler.providers.aws.services.iam.iam_client import iam_client class iam_check_saml_providers_sts(Check): def execute(self) -> Check_Report_AWS: findings = [] - for provider in iam_client.saml_providers: - report = Check_Report_AWS(self.metadata()) - provider_name = provider["Arn"].split("/")[1] - report.resource_id = provider_name - report.resource_arn = provider["Arn"] - report.region = iam_client.region - report.status = "PASS" - report.status_extended = f"SAML Provider {provider_name} has been found" - findings.append(report) + if iam_client.saml_providers: + for provider in iam_client.saml_providers: + report = Check_Report_AWS(self.metadata()) + provider_name = provider["Arn"].split("/")[1] + report.resource_id = provider_name + report.resource_arn = provider["Arn"] + report.region = iam_client.region + report.status = "PASS" + report.status_extended = f"SAML Provider {provider_name} has been found" + findings.append(report) return findings diff --git a/prowler/providers/aws/services/iam/iam_root_hardware_mfa_enabled/iam_root_hardware_mfa_enabled.py b/prowler/providers/aws/services/iam/iam_root_hardware_mfa_enabled/iam_root_hardware_mfa_enabled.py index 717cb5de..318ae020 100644 --- a/prowler/providers/aws/services/iam/iam_root_hardware_mfa_enabled/iam_root_hardware_mfa_enabled.py +++ b/prowler/providers/aws/services/iam/iam_root_hardware_mfa_enabled/iam_root_hardware_mfa_enabled.py @@ -7,28 +7,29 @@ class iam_root_hardware_mfa_enabled(Check): findings = [] # This check is only avaible in Commercial Partition if iam_client.partition == "aws": - virtual_mfa = False - report = Check_Report_AWS(self.metadata()) - report.region = iam_client.region - report.resource_id = "root" - report.resource_arn = f"arn:aws:iam::{iam_client.account}:root" + if iam_client.account_summary: + virtual_mfa = False + report = Check_Report_AWS(self.metadata()) + report.region = iam_client.region + report.resource_id = "root" + report.resource_arn = f"arn:aws:iam::{iam_client.account}:root" - if iam_client.account_summary["SummaryMap"]["AccountMFAEnabled"] > 0: - virtual_mfas = iam_client.virtual_mfa_devices - for mfa in virtual_mfas: - if "root" in mfa["SerialNumber"]: - virtual_mfa = True - report.status = "FAIL" - report.status_extended = "Root account has a virtual MFA instead of a hardware MFA device enabled." - if not virtual_mfa: - report.status = "PASS" - report.status_extended = ( - "Root account has a hardware MFA device enabled." - ) - else: - report.status = "FAIL" - report.status_extended = "MFA is not enabled for root account." + if iam_client.account_summary["SummaryMap"]["AccountMFAEnabled"] > 0: + virtual_mfas = iam_client.virtual_mfa_devices + for mfa in virtual_mfas: + if "root" in mfa["SerialNumber"]: + virtual_mfa = True + report.status = "FAIL" + report.status_extended = "Root account has a virtual MFA instead of a hardware MFA device enabled." + if not virtual_mfa: + report.status = "PASS" + report.status_extended = ( + "Root account has a hardware MFA device enabled." + ) + else: + report.status = "FAIL" + report.status_extended = "MFA is not enabled for root account." - findings.append(report) + findings.append(report) return findings diff --git a/prowler/providers/aws/services/iam/iam_root_mfa_enabled/iam_root_mfa_enabled.py b/prowler/providers/aws/services/iam/iam_root_mfa_enabled/iam_root_mfa_enabled.py index e073b450..d515e411 100644 --- a/prowler/providers/aws/services/iam/iam_root_mfa_enabled/iam_root_mfa_enabled.py +++ b/prowler/providers/aws/services/iam/iam_root_mfa_enabled/iam_root_mfa_enabled.py @@ -5,19 +5,19 @@ from prowler.providers.aws.services.iam.iam_client import iam_client class iam_root_mfa_enabled(Check): def execute(self) -> Check_Report_AWS: findings = [] - - for user in iam_client.credential_report: - if user["user"] == "": - report = Check_Report_AWS(self.metadata()) - report.region = iam_client.region - report.resource_id = user["user"] - report.resource_arn = user["arn"] - if user["mfa_active"] == "false": - report.status = "FAIL" - report.status_extended = "MFA is not enabled for root account." - else: - report.status = "PASS" - report.status_extended = "MFA is enabled for root account." - findings.append(report) + if iam_client.credential_report: + for user in iam_client.credential_report: + if user["user"] == "": + report = Check_Report_AWS(self.metadata()) + report.region = iam_client.region + report.resource_id = user["user"] + report.resource_arn = user["arn"] + if user["mfa_active"] == "false": + report.status = "FAIL" + report.status_extended = "MFA is not enabled for root account." + else: + report.status = "PASS" + report.status_extended = "MFA is enabled for root account." + findings.append(report) return findings diff --git a/prowler/providers/aws/services/iam/iam_service.py b/prowler/providers/aws/services/iam/iam_service.py index 9bbf81af..6e5edc2b 100644 --- a/prowler/providers/aws/services/iam/iam_service.py +++ b/prowler/providers/aws/services/iam/iam_service.py @@ -1,6 +1,8 @@ import csv -from dataclasses import dataclass from datetime import datetime +from typing import Optional + +from pydantic import BaseModel from prowler.lib.logger import logger from prowler.lib.scan_filters.scan_filters import is_resource_filtered @@ -66,8 +68,8 @@ class IAM: def __get_roles__(self): try: - get_roles_paginator = self.client.get_paginator("list_roles") roles = [] + get_roles_paginator = self.client.get_paginator("list_roles") for page in get_roles_paginator.paginate(): for role in page["Roles"]: if not self.audit_resources or ( @@ -81,14 +83,16 @@ class IAM: is_service_role=is_service_role(role), ) ) - return roles except Exception as error: logger.error( f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" ) + finally: + return roles def __get_credential_report__(self): report_is_completed = False + credential_list = [] try: while not report_is_completed: report_status = self.client.generate_credential_report() @@ -99,29 +103,29 @@ class IAM: credential_lines = credential.split("\n") csv_reader = csv.DictReader(credential_lines, delimiter=",") credential_list = list(csv_reader) - return credential_list except Exception as error: logger.error( f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" ) - return [] + finally: + return credential_list def __get_groups__(self): try: - get_groups_paginator = self.client.get_paginator("list_groups") - except Exception as error: - logger.error( - f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" - ) - else: groups = [] + get_groups_paginator = self.client.get_paginator("list_groups") for page in get_groups_paginator.paginate(): for group in page["Groups"]: if not self.audit_resources or ( is_resource_filtered(group["Arn"], self.audit_resources) ): - groups.append(Group(group["GroupName"], group["Arn"])) + groups.append(Group(name=group["GroupName"], arn=group["Arn"])) + except Exception as error: + logger.error( + f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" + ) + finally: return groups def __get_account_summary__(self): @@ -131,12 +135,13 @@ class IAM: logger.error( f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" ) - else: - + account_summary = None + finally: return account_summary def __get_password_policy__(self): try: + stored_password_policy = None password_policy = self.client.get_account_password_policy()[ "PasswordPolicy" ] @@ -150,36 +155,33 @@ class IAM: reuse_prevention = password_policy["PasswordReusePrevention"] if "HardExpiry" in password_policy: hard_expiry = password_policy["HardExpiry"] + + stored_password_policy = PasswordPolicy( + length=password_policy["MinimumPasswordLength"], + symbols=password_policy["RequireSymbols"], + numbers=password_policy["RequireNumbers"], + uppercase=password_policy["RequireUppercaseCharacters"], + lowercase=password_policy["RequireLowercaseCharacters"], + allow_change=password_policy["AllowUsersToChangePassword"], + expiration=password_policy["RequireNumbers"], + max_age=max_age, + reuse_prevention=reuse_prevention, + hard_expiry=hard_expiry, + ) except Exception as error: if "NoSuchEntity" in str(error): # Password policy does not exist - password_policy = None + stored_password_policy = None else: logger.error( f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" ) - else: - return PasswordPolicy( - password_policy["MinimumPasswordLength"], - password_policy["RequireSymbols"], - password_policy["RequireNumbers"], - password_policy["RequireUppercaseCharacters"], - password_policy["RequireLowercaseCharacters"], - password_policy["AllowUsersToChangePassword"], - password_policy["ExpirePasswords"], - max_age, - reuse_prevention, - hard_expiry, - ) + finally: + return stored_password_policy def __get_users__(self): try: get_users_paginator = self.client.get_paginator("list_users") - except Exception as error: - logger.error( - f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" - ) - else: users = [] for page in get_users_paginator.paginate(): for user in page["Users"]: @@ -187,33 +189,37 @@ class IAM: is_resource_filtered(user["Arn"], self.audit_resources) ): if "PasswordLastUsed" not in user: - users.append(User(user["UserName"], user["Arn"], None)) + users.append(User(name=user["UserName"], arn=user["Arn"])) else: users.append( User( - user["UserName"], - user["Arn"], - user["PasswordLastUsed"], + name=user["UserName"], + arn=user["Arn"], + password_last_used=user["PasswordLastUsed"], ) ) - - return users - - def __list_virtual_mfa_devices__(self): - try: - list_virtual_mfa_devices_paginator = self.client.get_paginator( - "list_virtual_mfa_devices" - ) except Exception as error: logger.error( f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" ) - else: + finally: + return users + + def __list_virtual_mfa_devices__(self): + try: mfa_devices = [] + list_virtual_mfa_devices_paginator = self.client.get_paginator( + "list_virtual_mfa_devices" + ) + for page in list_virtual_mfa_devices_paginator.paginate(): for mfa_device in page["VirtualMFADevices"]: mfa_devices.append(mfa_device) - + except Exception as error: + logger.error( + f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" + ) + finally: return mfa_devices def __list_attached_group_policies__(self): @@ -244,14 +250,14 @@ class IAM: for user in page["Users"]: if "PasswordLastUsed" not in user: group_users.append( - User(user["UserName"], user["Arn"], None) + User(name=user["UserName"], arn=user["Arn"]) ) else: group_users.append( User( - user["UserName"], - user["Arn"], - user["PasswordLastUsed"], + name=user["UserName"], + arn=user["Arn"], + password_last_used=user["PasswordLastUsed"], ) ) group.users = group_users @@ -273,7 +279,9 @@ class IAM: mfa_type = ( mfa_device["SerialNumber"].split(":")[5].split("/")[0] ) - mfa_devices.append(MFADevice(mfa_serial_number, mfa_type)) + mfa_devices.append( + MFADevice(serial_number=mfa_serial_number, type=mfa_type) + ) user.mfa_devices = mfa_devices except Exception as error: logger.error( @@ -333,7 +341,6 @@ class IAM: logger.error( f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" ) - finally: return support_roles @@ -351,13 +358,11 @@ class IAM: logger.error( f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" ) - else: + finally: return policies def __list_policies_version__(self, policies): try: - pass - for policy in policies: policy_version = self.client.get_policy_version( PolicyArn=policy["Arn"], VersionId=policy["DefaultVersionId"] @@ -375,7 +380,7 @@ class IAM: logger.error( f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" ) - + saml_providers = None finally: return saml_providers @@ -390,79 +395,49 @@ class IAM: ): server_certificates.append( Certificate( - certificate["ServerCertificateName"], - certificate["ServerCertificateId"], - certificate["Arn"], - certificate["Expiration"], + name=certificate["ServerCertificateName"], + id=certificate["ServerCertificateId"], + arn=certificate["Arn"], + expiration=certificate["Expiration"], ) ) except Exception as error: logger.error( f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" ) - finally: return server_certificates -@dataclass -class MFADevice: +class MFADevice(BaseModel): serial_number: str type: str - def __init__(self, serial_number, type): - self.serial_number = serial_number - self.type = type - -@dataclass -class User: +class User(BaseModel): name: str arn: str - mfa_devices: list[MFADevice] - password_last_used: str - attached_policies: list[dict] - inline_policies: list[str] - - def __init__(self, name, arn, password_last_used): - self.name = name - self.arn = arn - self.password_last_used = password_last_used - self.mfa_devices = [] - self.attached_policies = [] - self.inline_policies = [] + mfa_devices: list[MFADevice] = [] + password_last_used: Optional[datetime] + attached_policies: list[dict] = [] + inline_policies: list[str] = [] -@dataclass -class Role: +class Role(BaseModel): name: str arn: str assume_role_policy: dict is_service_role: bool - def __init__(self, name, arn, assume_role_policy, is_service_role): - self.name = name - self.arn = arn - self.assume_role_policy = assume_role_policy - self.is_service_role = is_service_role - -@dataclass -class Group: +class Group(BaseModel): name: str arn: str - attached_policies: list[dict] - users: list[User] - - def __init__(self, name, arn): - self.name = name - self.arn = arn - self.attached_policies = [] - self.users = [] + attached_policies: list[dict] = [] + users: list[User] = [] -@dataclass -class PasswordPolicy: +class PasswordPolicy(BaseModel): length: int symbols: bool numbers: bool @@ -470,44 +445,13 @@ class PasswordPolicy: lowercase: bool allow_change: bool expiration: bool - max_age: int - reuse_prevention: int - hard_expiry: bool - - def __init__( - self, - length, - symbols, - numbers, - uppercase, - lowercase, - allow_change, - expiration, - max_age, - reuse_prevention, - hard_expiry, - ): - self.length = length - self.symbols = symbols - self.numbers = numbers - self.uppercase = uppercase - self.lowercase = lowercase - self.allow_change = allow_change - self.expiration = expiration - self.max_age = max_age - self.reuse_prevention = reuse_prevention - self.hard_expiry = hard_expiry + max_age: Optional[int] + reuse_prevention: Optional[int] + hard_expiry: Optional[bool] -@dataclass -class Certificate: +class Certificate(BaseModel): name: str id: str arn: str expiration: datetime - - def __init__(self, name, id, arn, expiration): - self.name = name - self.id = id - self.arn = arn - self.expiration = expiration