From e52ab12696895faa366758df454b7119725bb919 Mon Sep 17 00:00:00 2001 From: Nacho Rivera <59198746+n4ch04@users.noreply.github.com> Date: Tue, 21 Jun 2022 07:53:49 +0200 Subject: [PATCH] feat(global_aws_session): Global data structure for the current AWS audit (#1212) * fix(audit info): Common data structure for current audit * fix(iam): iam session audit fixed * feat(aws_session): Include else block Co-authored-by: Pepe Fagoaga --- providers/aws/aws_provider.py | 236 +++++++++------------- providers/aws/models.py | 31 +++ providers/aws/services/ec2/ec2_service.py | 53 ++--- providers/aws/services/iam/iam_service.py | 10 +- prowler.py | 18 +- 5 files changed, 172 insertions(+), 176 deletions(-) create mode 100644 providers/aws/models.py diff --git a/providers/aws/aws_provider.py b/providers/aws/aws_provider.py index acf62ba5..b5c164f0 100644 --- a/providers/aws/aws_provider.py +++ b/providers/aws/aws_provider.py @@ -1,6 +1,3 @@ -from dataclasses import dataclass -from datetime import datetime - from arnparse import arnparse from boto3 import session from botocore.credentials import RefreshableCredentials @@ -8,63 +5,32 @@ from botocore.session import get_session from lib.arn.arn import arn_parsing from lib.logger import logger - - -@dataclass -class AWS_Credentials: - aws_access_key_id: str - aws_session_token: str - aws_secret_access_key: str - expiration: datetime - - -@dataclass -class Input_Data: - profile: str - role_arn: str - session_duration: int - external_id: str - regions: list - - -@dataclass -class AWS_Assume_Role: - role_arn: str - session_duration: int - external_id: str - sts_session: session - partition: str - - -@dataclass -class AWS_Session_Info: - profile: str - credentials: AWS_Credentials - role_info: AWS_Assume_Role +from providers.aws.models import AWS_Assume_Role, AWS_Audit_Info, AWS_Credentials ################## AWS PROVIDER class AWS_Provider: - def __init__(self, session_info): - self.aws_session = self.set_session(session_info) - self.role_info = session_info.role_info + def __init__(self, audit_info): + logger.info("Instantiating aws provider ...") + self.aws_session = self.set_session(audit_info) + self.role_info = audit_info.assumed_role_info def get_session(self): return self.aws_session - def set_session(self, session_info): + def set_session(self, audit_info): try: - if session_info.credentials: + if audit_info.credentials: # If we receive a credentials object filled is coming form an assumed role, so renewal is needed logger.info("Creating session for assumed role ...") # From botocore we can use RefreshableCredentials class, which has an attribute (refresh_using) # that needs to be a method without arguments that retrieves a new set of fresh credentials # asuming the role again. -> https://github.com/boto/botocore/blob/098cc255f81a25b852e1ecdeb7adebd94c7b1b73/botocore/credentials.py#L395 assumed_refreshable_credentials = RefreshableCredentials( - access_key=session_info.credentials.aws_access_key_id, - secret_key=session_info.credentials.aws_secret_access_key, - token=session_info.credentials.aws_session_token, - expiry_time=session_info.credentials.expiration, + access_key=audit_info.credentials.aws_access_key_id, + secret_key=audit_info.credentials.aws_secret_access_key, + token=audit_info.credentials.aws_session_token, + expiry_time=audit_info.credentials.expiration, refresh_using=self.refresh, method="sts-assume-role", ) @@ -74,13 +40,13 @@ class AWS_Provider: assumed_botocore_session.set_config_variable("region", "us-east-1") return session.Session( - profile_name=session_info.profile, + profile_name=audit_info.profile, botocore_session=assumed_botocore_session, ) # If we do not receive credentials start the session using the profile else: logger.info("Creating session for not assumed identity ...") - return session.Session(profile_name=session_info.profile) + return session.Session(profile_name=audit_info.profile) except Exception as error: logger.critical(f"{error.__class__.__name__} -- {error}") quit() @@ -105,6 +71,84 @@ class AWS_Provider: return refreshed_credentials +def provider_set_session( + input_profile, input_role, input_session_duration, input_external_id, input_regions +): + + # Mark variable that stores all the info about the audit as global + global current_audit_info + + assumed_session = None + + # Setting session + current_audit_info = AWS_Audit_Info( + original_session=None, + audit_session=None, + audited_account=None, + audited_partition=None, + profile=input_profile, + credentials=None, + assumed_role_info=AWS_Assume_Role( + role_arn=input_role, + session_duration=input_session_duration, + external_id=input_external_id, + ), + audited_regions=input_regions, + ) + + logger.info("Generating original session ...") + # Create an global original session using only profile/basic credentials info + current_audit_info.original_session = AWS_Provider(current_audit_info).get_session() + logger.info("Validating credentials ...") + # Verificate if we have valid credentials + caller_identity = validate_credentials(current_audit_info.original_session) + + logger.info("Credentials validated") + logger.info(f"Original caller identity UserId : {caller_identity['UserId']}") + logger.info(f"Original caller identity ARN : {caller_identity['Arn']}") + + current_audit_info.audited_account = caller_identity["Account"] + current_audit_info.audited_partition = arnparse(caller_identity["Arn"]).partition + + logger.info("Checking if role assumption is needed ...") + if current_audit_info.assumed_role_info.role_arn: + # Check if role arn is valid + try: + # this returns the arn already parsed, calls arnparse, into a dict to be used when it is needed to access its fields + role_arn_parsed = arn_parsing(current_audit_info.assumed_role_info.role_arn) + + except Exception as error: + logger.critical(f"{error.__class__.__name__} -- {error}") + quit() + + else: + logger.info( + f"Assuming role {current_audit_info.assumed_role_info.role_arn}" + ) + # Assume the role + assumed_role_response = assume_role() + logger.info("Role assumed") + # Set the info needed to create a session with an assumed role + current_audit_info.credentials = AWS_Credentials( + aws_access_key_id=assumed_role_response["Credentials"]["AccessKeyId"], + aws_session_token=assumed_role_response["Credentials"]["SessionToken"], + aws_secret_access_key=assumed_role_response["Credentials"][ + "SecretAccessKey" + ], + expiration=assumed_role_response["Credentials"]["Expiration"], + ) + assumed_session = AWS_Provider(current_audit_info).get_session() + + if assumed_session: + logger.info("Audit session is the new session created assuming role") + current_audit_info.audit_session = assumed_session + current_audit_info.audited_account = role_arn_parsed.account_id + current_audit_info.audited_partition = role_arn_parsed.partition + else: + logger.info("Audit session is the original one") + current_audit_info.audit_session = current_audit_info.original_session + + def validate_credentials(validate_session): try: validate_credentials_client = validate_session.client("sts") @@ -116,103 +160,25 @@ def validate_credentials(validate_session): return caller_identity -def provider_set_session(session_input): - - # global variables that are going to be shared accross the project - global aws_session - global original_session - global audited_regions - global audited_partition - global audited_account - - assumed_session = None - - # Initialize a session info dataclass only with info about the profile - session_info = AWS_Session_Info( - session_input.profile, - None, - None, - ) - - # Create an global original session using only profile/basic credentials info - original_session = AWS_Provider(session_info).get_session() - logger.info("Validating credentials ...") - # Verificate if we have valid credentials - caller_identity = validate_credentials(original_session) - - logger.info("Credentials validated") - logger.info(f"Original caller identity UserId : {caller_identity['UserId']}") - logger.info(f"Original caller identity ARN : {caller_identity['Arn']}") - - # Set some global values for original session - audited_regions = session_input.regions - audited_account = caller_identity["Account"] - audited_partition = arnparse(caller_identity["Arn"]).partition - - if session_input.role_arn: - # Check if role arn is valid - try: - # this returns the arn already parsed, calls arnparse, into a dict to be used when it is needed to access its fields - role_arn_parsed = arn_parsing(session_input.role_arn) - - except Exception as error: - logger.critical(f"{error.__class__.__name__} -- {error}") - quit() - - # Set info for role assumption if needed - role_info = AWS_Assume_Role( - session_input.role_arn, - session_input.session_duration, - session_input.external_id, - original_session, - audited_partition, - ) - logger.info(f"Assuming role {role_info.role_arn}") - # Assume the role - assumed_role_response = assume_role(role_info) - logger.info("Role assumed") - # Set the info needed to create a session with an assumed role - session_info = AWS_Session_Info( - session_input.profile, - AWS_Credentials( - aws_access_key_id=assumed_role_response["Credentials"]["AccessKeyId"], - aws_session_token=assumed_role_response["Credentials"]["SessionToken"], - aws_secret_access_key=assumed_role_response["Credentials"][ - "SecretAccessKey" - ], - expiration=assumed_role_response["Credentials"]["Expiration"], - ), - role_info, - ) - assumed_session = AWS_Provider(session_info).get_session() - - if assumed_session: - aws_session = assumed_session - audited_account = role_arn_parsed.account_id - audited_partition = role_arn_parsed.partition - else: - aws_session = original_session - - -def assume_role(role_info): +def assume_role(): try: # set the info to assume the role from the partition, account and role name - sts_client = role_info.sts_session.client("sts") + sts_client = current_audit_info.original_session.client("sts") # If external id, set it to the assume role api call - if role_info.external_id: + if current_audit_info.assumed_role_info.external_id: assumed_credentials = sts_client.assume_role( - RoleArn=role_info.role_arn, - RoleSessionName="ProwlerProSession", - DurationSeconds=role_info.session_duration, - ExternalId=role_info.external_id, + RoleArn=current_audit_info.assumed_role_info.role_arn, + RoleSessionName="ProwlerProAsessmentSession", + DurationSeconds=current_audit_info.assumed_role_info.session_duration, + ExternalId=current_audit_info.assumed_role_info.external_id, ) # else assume the role without the external id else: assumed_credentials = sts_client.assume_role( - RoleArn=role_info.role_arn, - RoleSessionName="ProwlerProSession", - DurationSeconds=role_info.session_duration, + RoleArn=current_audit_info.assumed_role_info.role_arn, + RoleSessionName="ProwlerProAsessmentSession", + DurationSeconds=current_audit_info.assumed_role_info.session_duration, ) except Exception as error: logger.critical(f"{error.__class__.__name__} -- {error}") diff --git a/providers/aws/models.py b/providers/aws/models.py new file mode 100644 index 00000000..49fb8a8c --- /dev/null +++ b/providers/aws/models.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass +from datetime import datetime + +from boto3 import session + + +@dataclass +class AWS_Credentials: + aws_access_key_id: str + aws_session_token: str + aws_secret_access_key: str + expiration: datetime + + +@dataclass +class AWS_Assume_Role: + role_arn: str + session_duration: int + external_id: str + + +@dataclass +class AWS_Audit_Info: + original_session: session.Session + audit_session: session.Session + audited_account: int + audited_partition: str + profile: str + credentials: AWS_Credentials + assumed_role_info: AWS_Assume_Role + audited_regions: list diff --git a/providers/aws/services/ec2/ec2_service.py b/providers/aws/services/ec2/ec2_service.py index 465f4c31..e88c5c7c 100644 --- a/providers/aws/services/ec2/ec2_service.py +++ b/providers/aws/services/ec2/ec2_service.py @@ -5,31 +5,24 @@ import urllib.request from config.config import aws_services_json_file, aws_services_json_url from lib.logger import logger from lib.utils.utils import open_file, parse_json_file -from providers.aws.aws_provider import ( - audited_account, - audited_partition, - audited_regions, - aws_session, -) +from providers.aws.aws_provider import current_audit_info ################## EC2 class EC2: - def __init__(self, aws_session, audited_regions): + def __init__(self, audit_info): self.service = "ec2" - self.aws_session = aws_session + self.session = audit_info.audit_session + self.audited_account = audit_info.audited_account self.regional_clients = self.__generate_regional_clients__( - self.service, audited_regions + self.service, audit_info ) self.__threading_call__(self.__describe_snapshots__) - def __get_clients__(self): - return self.clients - def __get_session__(self): - return self.aws_session + return self.session - def __generate_regional_clients__(self, service, audited_regions): + def __generate_regional_clients__(self, service, audit_info): regional_clients = [] try: # Try to get the list online with urllib.request.urlopen(aws_services_json_url) as url: @@ -40,46 +33,50 @@ class EC2: data = parse_json_file(f) for att in data["prices"]: - if audited_regions: # Check for input aws audited_regions + if ( + audit_info.audited_regions + ): # Check for input aws audit_info.audited_regions if ( service in att["id"].split(":")[0] - and att["attributes"]["aws:region"] in audited_regions + and att["attributes"]["aws:region"] in audit_info.audited_regions ): # Check if service has this region region = att["attributes"]["aws:region"] - regional_client = aws_session.client(service, region_name=region) + regional_client = audit_info.audit_session.client( + service, region_name=region + ) regional_client.region = region regional_clients.append(regional_client) else: - if audited_partition in "aws": + if audit_info.audited_partition in "aws": if ( service in att["id"].split(":")[0] and "gov" not in att["attributes"]["aws:region"] and "cn" not in att["attributes"]["aws:region"] ): region = att["attributes"]["aws:region"] - regional_client = aws_session.client( + regional_client = audit_info.audit_session.client( service, region_name=region ) regional_client.region = region regional_clients.append(regional_client) - elif audited_partition in "cn": + elif audit_info.audited_partition in "cn": if ( service in att["id"].split(":")[0] and "cn" in att["attributes"]["aws:region"] ): region = att["attributes"]["aws:region"] - regional_client = aws_session.client( + regional_client = audit_info.audit_session.client( service, region_name=region ) regional_client.region = region regional_clients.append(regional_client) - elif audited_partition in "gov": + elif audit_info.audited_partition in "gov": if ( service in att["id"].split(":")[0] and "gov" in att["attributes"]["aws:region"] ): region = att["attributes"]["aws:region"] - regional_client = aws_session.client( + regional_client = audit_info.audit_session.client( service, region_name=region ) regional_client.region = region @@ -90,13 +87,17 @@ class EC2: def __threading_call__(self, call): threads = [] for regional_client in self.regional_clients: - threads.append(threading.Thread(target=call, args=(regional_client,))) + threads.append( + threading.Thread( + target=call, args=(regional_client, self.audited_account) + ) + ) for t in threads: t.start() for t in threads: t.join() - def __describe_snapshots__(self, regional_client): + def __describe_snapshots__(self, regional_client, audited_account): logger.info("EC2 - Describing Snapshots...") try: describe_snapshots_paginator = regional_client.get_paginator( @@ -114,4 +115,4 @@ class EC2: regional_client.snapshots = snapshots -ec2_client = EC2(aws_session, audited_regions) +ec2_client = EC2(current_audit_info) diff --git a/providers/aws/services/iam/iam_service.py b/providers/aws/services/iam/iam_service.py index af20bcc8..ed11f6b2 100644 --- a/providers/aws/services/iam/iam_service.py +++ b/providers/aws/services/iam/iam_service.py @@ -1,13 +1,13 @@ from lib.logger import logger -from providers.aws.aws_provider import aws_session +from providers.aws.aws_provider import current_audit_info ################## IAM class IAM: - def __init__(self, session): + def __init__(self, audit_info): self.service = "iam" - self.session = session - self.client = session.client(self.service) + self.session = audit_info.audit_session + self.client = self.session.client(self.service) self.users = self.__get_users__() self.roles = self.__get_roles__() self.customer_managed_policies = self.__get_customer_managed_policies__() @@ -89,7 +89,7 @@ class IAM: try: - iam_client = IAM(aws_session) + iam_client = IAM(current_audit_info) except Exception as error: logger.critical(f"{error.__class__.__name__} -- {error}") quit() diff --git a/prowler.py b/prowler.py index 06694b54..d5d44250 100644 --- a/prowler.py +++ b/prowler.py @@ -11,7 +11,7 @@ from lib.check.check import ( run_check, ) from lib.logger import logger, logging_levels -from providers.aws.aws_provider import Input_Data, provider_set_session +from providers.aws.aws_provider import provider_set_session if __name__ == "__main__": # CLI Arguments @@ -103,17 +103,15 @@ if __name__ == "__main__": if args.no_banner: print_banner() - # Setting session - session_input = Input_Data( - profile=args.profile, - role_arn=args.role, - session_duration=args.session_duration, - external_id=args.external_id, - regions=args.filter_region, + # Set global session + provider_set_session( + args.profile, + args.role, + args.session_duration, + args.external_id, + args.filter_region, ) - provider_set_session(session_input) - # Load checks to execute logger.debug("Loading checks") checks_to_execute = load_checks_to_execute(