diff --git a/prowler/providers/aws/lib/service/service.py b/prowler/providers/aws/lib/service/service.py
index 9fb8dd5c..8f9b2bdc 100644
--- a/prowler/providers/aws/lib/service/service.py
+++ b/prowler/providers/aws/lib/service/service.py
@@ -1,17 +1,21 @@
-import threading
+from concurrent.futures import ThreadPoolExecutor, as_completed
 
+from prowler.lib.logger import logger
 from prowler.providers.aws.aws_provider import (
     generate_regional_clients,
     get_default_region,
 )
 from prowler.providers.aws.lib.audit_info.models import AWS_Audit_Info
 
+MAX_WORKERS = 10
+
 
 class AWSService:
     """The AWSService class offers a parent class for each AWS Service to generate:
     - AWS Regional Clients
-    - Shared information like the account ID and ARN, the the AWS partition and the checks audited
+    - Shared information like the account ID and ARN, the AWS partition and the checks audited
     - AWS Session
+    - Thread pool for the __threading_call__
     - Also handles if the AWS Service is Global
     """
 
@@ -42,14 +46,51 @@ class AWSService:
         self.region = get_default_region(self.service, audit_info)
         self.client = self.session.client(self.service, self.region)
 
+        # Thread pool for __threading_call__
+        self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS)
+
     def __get_session__(self):
         return self.session
 
-    def __threading_call__(self, call):
-        threads = []
-        for regional_client in self.regional_clients.values():
-            threads.append(threading.Thread(target=call, args=(regional_client,)))
-        for t in threads:
-            t.start()
-        for t in threads:
-            t.join()
+    def __threading_call__(self, call, iterator=None):
+        """Run `call` once per item on the service thread pool and wait for every task.
+
+        Args:
+            call: Callable invoked with a single item as its only argument.
+            iterator: Optional iterable of items; defaults to self.regional_clients.values().
+        """
+        # Materialize the items so len() also works for generators and other unsized iterables
+        items = list(
+            iterator if iterator is not None else self.regional_clients.values()
+        )
+        # Determine the total count for logging
+        item_count = len(items)
+
+        # Trim leading and trailing underscores from the call's name
+        # (callables such as functools.partial do not define __name__, so fall back to the type name)
+        call_name = getattr(call, "__name__", type(call).__name__).strip("_")
+        # Add Capitalization
+        call_name = " ".join([x.capitalize() for x in call_name.split("_")])
+
+        # Print a message based on the call's name, and if it's regional or processing a list of items
+        if iterator is None:
+            logger.info(
+                f"{self.service.upper()} - Starting threads for '{call_name}' function across {item_count} regions..."
+            )
+        else:
+            logger.info(
+                f"{self.service.upper()} - Starting threads for '{call_name}' function to process {item_count} items..."
+            )
+
+        # Submit tasks to the thread pool
+        futures = [self.thread_pool.submit(call, item) for item in items]
+
+        # Wait for all tasks to complete
+        for future in as_completed(futures):
+            try:
+                future.result()  # Raises exceptions from the thread, if any
+            except Exception as error:
+                # Called functions are expected to handle their own errors; log whatever still escapes instead of hiding it
+                logger.error(
+                    f"{self.service.upper()} - '{call_name}' -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
+                )
diff --git a/prowler/providers/aws/services/ec2/ec2_service.py b/prowler/providers/aws/services/ec2/ec2_service.py
index 4838f041..bb1863bd 100644
--- a/prowler/providers/aws/services/ec2/ec2_service.py
+++ b/prowler/providers/aws/services/ec2/ec2_service.py
@@ -17,7 +17,7 @@ class EC2(AWSService):
         super().__init__(__class__.__name__, audit_info)
         self.instances = []
         self.__threading_call__(self.__describe_instances__)
-        self.__get_instance_user_data__()
+        self.__threading_call__(self.__get_instance_user_data__, self.instances)
         self.security_groups = []
         self.regions_with_sgs = []
         self.__threading_call__(self.__describe_security_groups__)
@@ -27,7 +27,7 @@ class EC2(AWSService):
         self.volumes_with_snapshots = {}
         self.regions_with_snapshots = {}
         self.__threading_call__(self.__describe_snapshots__)
-        self.__get_snapshot_public__()
+        self.__threading_call__(self.__determine_public_snapshots__, self.snapshots)
         self.network_interfaces = []
         self.__threading_call__(self.__describe_public_network_interfaces__)
         self.__threading_call__(self.__describe_sg_network_interfaces__)
@@ -36,12 +36,11 @@ class EC2(AWSService):
         self.volumes = []
         self.__threading_call__(self.__describe_volumes__)
         self.ebs_encryption_by_default = []
-        self.__threading_call__(self.__get_ebs_encryption_by_default__)
+        self.__threading_call__(self.__get_ebs_encryption_settings__)
         self.elastic_ips = []
-        self.__threading_call__(self.__describe_addresses__)
+        self.__threading_call__(self.__describe_ec2_addresses__)
 
     def __describe_instances__(self, regional_client):
-        logger.info("EC2 - Describing EC2 Instances...")
         try:
             describe_instances_paginator = regional_client.get_paginator(
                 "describe_instances"
@@ -106,7 +105,6 @@ class EC2(AWSService):
             )
 
     def __describe_security_groups__(self, regional_client):
-        logger.info("EC2 - Describing Security Groups...")
         try:
             describe_security_groups_paginator = regional_client.get_paginator(
                 "describe_security_groups"
@@ -155,7 +153,6 @@ class EC2(AWSService):
             )
 
     def __describe_network_acls__(self, regional_client):
-        logger.info("EC2 - Describing Network ACLs...")
         try:
             describe_network_acls_paginator = regional_client.get_paginator(
                 "describe_network_acls"
@@ -186,7 +183,6 @@ class EC2(AWSService):
             )
 
     def __describe_snapshots__(self, regional_client):
-        logger.info("EC2 - Describing Snapshots...")
         try:
             snapshots_in_region = False
             describe_snapshots_paginator = regional_client.get_paginator(
@@ -219,35 +215,35 @@ class EC2(AWSService):
                 f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
             )
 
-    def __get_snapshot_public__(self):
-        logger.info("EC2 - Getting snapshot volume attribute permissions...")
-        for snapshot in self.snapshots:
-            try:
-                regional_client = self.regional_clients[snapshot.region]
-                snapshot_public = regional_client.describe_snapshot_attribute(
-                    Attribute="createVolumePermission", SnapshotId=snapshot.id
-                )
-                for permission in snapshot_public["CreateVolumePermissions"]:
-                    if "Group" in permission:
-                        if permission["Group"] == "all":
-                            snapshot.public = True
+    def __determine_public_snapshots__(self, snapshot):
+        """Set snapshot.public to True when its createVolumePermission is granted to the "all" group."""
+        try:
+            regional_client = self.regional_clients[snapshot.region]
+            snapshot_public = regional_client.describe_snapshot_attribute(
+                Attribute="createVolumePermission", SnapshotId=snapshot.id
+            )
+            for permission in snapshot_public["CreateVolumePermissions"]:
+                if "Group" in permission:
+                    if permission["Group"] == "all":
+                        snapshot.public = True
 
-            except ClientError as error:
-                if error.response["Error"]["Code"] == "InvalidSnapshot.NotFound":
-                    logger.warning(
-                        f"{snapshot.region} --"
-                        f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:"
-                        f" {error}"
-                    )
-                    continue
-
-            except Exception as error:
-                logger.error(
-                    f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
+        except ClientError as error:
+            if error.response["Error"]["Code"] == "InvalidSnapshot.NotFound":
+                logger.warning(
+                    f"{snapshot.region} --"
+                    f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:"
+                    f" {error}"
                 )
+            else:
+                logger.error(
+                    f"{snapshot.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
+                )
+        except Exception as error:
+            logger.error(
+                f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
+            )
 
     def __describe_public_network_interfaces__(self, regional_client):
-        logger.info("EC2 - Describing Network Interfaces...")
         try:
             # Get Network Interfaces with Public IPs
             describe_network_interfaces_paginator = regional_client.get_paginator(
@@ -274,7 +270,6 @@ class EC2(AWSService):
             )
 
     def __describe_sg_network_interfaces__(self, regional_client):
-        logger.info("EC2 - Describing Network Interfaces...")
         try:
             # Get Network Interfaces for Security Groups
             for sg in self.security_groups:
@@ -299,30 +294,30 @@ class EC2(AWSService):
                 f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
             )
 
-    def __get_instance_user_data__(self):
-        logger.info("EC2 - Getting instance user data...")
-        for instance in self.instances:
-            try:
-                regional_client = self.regional_clients[instance.region]
-                user_data = regional_client.describe_instance_attribute(
-                    Attribute="userData", InstanceId=instance.id
-                )["UserData"]
-                if "Value" in user_data:
-                    instance.user_data = user_data["Value"]
-
-            except ClientError as error:
-                if error.response["Error"]["Code"] == "InvalidInstanceID.NotFound":
-                    logger.warning(
-                        f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
-                    )
-                    continue
-            except Exception as error:
-                logger.error(
+    def __get_instance_user_data__(self, instance):
+        """Copy the instance's userData attribute value, when present, into instance.user_data."""
+        try:
+            regional_client = self.regional_clients[instance.region]
+            user_data = regional_client.describe_instance_attribute(
+                Attribute="userData", InstanceId=instance.id
+            )["UserData"]
+            if "Value" in user_data:
+                instance.user_data = user_data["Value"]
+        except ClientError as error:
+            if error.response["Error"]["Code"] == "InvalidInstanceID.NotFound":
+                logger.warning(
                     f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
                 )
+            else:
+                logger.error(
+                    f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
+                )
+        except Exception as error:
+            logger.error(
+                f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
+            )
 
     def __describe_images__(self, regional_client):
-        logger.info("EC2 - Describing Images...")
         try:
             for image in regional_client.describe_images(Owners=["self"])["Images"]:
                 arn = f"arn:{self.audited_partition}:ec2:{regional_client.region}:{self.audited_account}:image/{image['ImageId']}"
@@ -345,7 +340,6 @@ class EC2(AWSService):
             )
 
     def __describe_volumes__(self, regional_client):
-        logger.info("EC2 - Describing Volumes...")
         try:
             describe_volumes_paginator = regional_client.get_paginator(
                 "describe_volumes"
@@ -370,8 +364,7 @@ class EC2(AWSService):
                 f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
             )
 
-    def __describe_addresses__(self, regional_client):
-        logger.info("EC2 - Describing Elastic IPs...")
+    def __describe_ec2_addresses__(self, regional_client):
         try:
             for address in regional_client.describe_addresses()["Addresses"]:
                 public_ip = None
@@ -402,8 +395,7 @@ class EC2(AWSService):
                 f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
             )
 
-    def __get_ebs_encryption_by_default__(self, regional_client):
-        logger.info("EC2 - Get EBS Encryption By Default...")
+    def __get_ebs_encryption_settings__(self, regional_client):
         try:
             volumes_in_region = False
             for volume in self.volumes:
diff --git a/prowler/providers/aws/services/s3/s3_service.py b/prowler/providers/aws/services/s3/s3_service.py
index 12a85b41..1bcb9275 100644
--- a/prowler/providers/aws/services/s3/s3_service.py
+++ b/prowler/providers/aws/services/s3/s3_service.py
@@ -28,6 +28,7 @@ class S3(AWSService):
         self.__threading_call__(self.__get_bucket_tagging__)
 
     # In the S3 service we override the "__threading_call__" method because we spawn a process per bucket instead of per region
+    # TODO: Replace the function below with the service __threading_call__ using the buckets as the iterator
     def __threading_call__(self, call):
         threads = []
         for bucket in self.buckets: