fix(threading): Improved threading for the AWS Service (#3175)

Co-authored-by: Pepe Fagoaga <pepe@verica.io>
This commit is contained in:
Fennerr
2023-12-12 13:50:26 +02:00
committed by GitHub
parent 3c3dfb380b
commit 2441cca810
3 changed files with 81 additions and 68 deletions

View File

@@ -1,17 +1,21 @@
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from prowler.lib.logger import logger
from prowler.providers.aws.aws_provider import (
generate_regional_clients,
get_default_region,
)
from prowler.providers.aws.lib.audit_info.models import AWS_Audit_Info
MAX_WORKERS = 10
class AWSService:
"""The AWSService class offers a parent class for each AWS Service to generate:
- AWS Regional Clients
- Shared information like the account ID and ARN, the the AWS partition and the checks audited
- AWS Session
- Thread pool for the __threading_call__
- Also handles if the AWS Service is Global
"""
@@ -42,14 +46,40 @@ class AWSService:
self.region = get_default_region(self.service, audit_info)
self.client = self.session.client(self.service, self.region)
# Thread pool for __threading_call__
self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS)
def __get_session__(self):
return self.session
def __threading_call__(self, call):
threads = []
for regional_client in self.regional_clients.values():
threads.append(threading.Thread(target=call, args=(regional_client,)))
for t in threads:
t.start()
for t in threads:
t.join()
def __threading_call__(self, call, iterator=None):
# Use the provided iterator, or default to self.regional_clients
items = iterator if iterator is not None else self.regional_clients.values()
# Determine the total count for logging
item_count = len(items)
# Trim leading and trailing underscores from the call's name
call_name = call.__name__.strip("_")
# Add Capitalization
call_name = " ".join([x.capitalize() for x in call_name.split("_")])
# Print a message based on the call's name, and if its regional or processing a list of items
if iterator is None:
logger.info(
f"{self.service.upper()} - Starting threads for '{call_name}' function across {item_count} regions..."
)
else:
logger.info(
f"{self.service.upper()} - Starting threads for '{call_name}' function to process {item_count} items..."
)
# Submit tasks to the thread pool
futures = [self.thread_pool.submit(call, item) for item in items]
# Wait for all tasks to complete
for future in as_completed(futures):
try:
future.result() # Raises exceptions from the thread, if any
except Exception:
# Handle exceptions if necessary
pass # Replace 'pass' with any additional exception handling logic. Currently handled within the called function

View File

@@ -17,7 +17,7 @@ class EC2(AWSService):
super().__init__(__class__.__name__, audit_info)
self.instances = []
self.__threading_call__(self.__describe_instances__)
self.__get_instance_user_data__()
self.__threading_call__(self.__get_instance_user_data__, self.instances)
self.security_groups = []
self.regions_with_sgs = []
self.__threading_call__(self.__describe_security_groups__)
@@ -27,7 +27,7 @@ class EC2(AWSService):
self.volumes_with_snapshots = {}
self.regions_with_snapshots = {}
self.__threading_call__(self.__describe_snapshots__)
self.__get_snapshot_public__()
self.__threading_call__(self.__determine_public_snapshots__, self.snapshots)
self.network_interfaces = []
self.__threading_call__(self.__describe_public_network_interfaces__)
self.__threading_call__(self.__describe_sg_network_interfaces__)
@@ -36,12 +36,11 @@ class EC2(AWSService):
self.volumes = []
self.__threading_call__(self.__describe_volumes__)
self.ebs_encryption_by_default = []
self.__threading_call__(self.__get_ebs_encryption_by_default__)
self.__threading_call__(self.__get_ebs_encryption_settings__)
self.elastic_ips = []
self.__threading_call__(self.__describe_addresses__)
self.__threading_call__(self.__describe_ec2_addresses__)
def __describe_instances__(self, regional_client):
logger.info("EC2 - Describing EC2 Instances...")
try:
describe_instances_paginator = regional_client.get_paginator(
"describe_instances"
@@ -106,7 +105,6 @@ class EC2(AWSService):
)
def __describe_security_groups__(self, regional_client):
logger.info("EC2 - Describing Security Groups...")
try:
describe_security_groups_paginator = regional_client.get_paginator(
"describe_security_groups"
@@ -155,7 +153,6 @@ class EC2(AWSService):
)
def __describe_network_acls__(self, regional_client):
logger.info("EC2 - Describing Network ACLs...")
try:
describe_network_acls_paginator = regional_client.get_paginator(
"describe_network_acls"
@@ -186,7 +183,6 @@ class EC2(AWSService):
)
def __describe_snapshots__(self, regional_client):
logger.info("EC2 - Describing Snapshots...")
try:
snapshots_in_region = False
describe_snapshots_paginator = regional_client.get_paginator(
@@ -219,35 +215,30 @@ class EC2(AWSService):
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def __get_snapshot_public__(self):
logger.info("EC2 - Getting snapshot volume attribute permissions...")
for snapshot in self.snapshots:
try:
regional_client = self.regional_clients[snapshot.region]
snapshot_public = regional_client.describe_snapshot_attribute(
Attribute="createVolumePermission", SnapshotId=snapshot.id
)
for permission in snapshot_public["CreateVolumePermissions"]:
if "Group" in permission:
if permission["Group"] == "all":
snapshot.public = True
def __determine_public_snapshots__(self, snapshot):
try:
regional_client = self.regional_clients[snapshot.region]
snapshot_public = regional_client.describe_snapshot_attribute(
Attribute="createVolumePermission", SnapshotId=snapshot.id
)
for permission in snapshot_public["CreateVolumePermissions"]:
if "Group" in permission:
if permission["Group"] == "all":
snapshot.public = True
except ClientError as error:
if error.response["Error"]["Code"] == "InvalidSnapshot.NotFound":
logger.warning(
f"{snapshot.region} --"
f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:"
f" {error}"
)
continue
except Exception as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
except ClientError as error:
if error.response["Error"]["Code"] == "InvalidSnapshot.NotFound":
logger.warning(
f"{snapshot.region} --"
f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:"
f" {error}"
)
except Exception as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def __describe_public_network_interfaces__(self, regional_client):
logger.info("EC2 - Describing Network Interfaces...")
try:
# Get Network Interfaces with Public IPs
describe_network_interfaces_paginator = regional_client.get_paginator(
@@ -274,7 +265,6 @@ class EC2(AWSService):
)
def __describe_sg_network_interfaces__(self, regional_client):
logger.info("EC2 - Describing Network Interfaces...")
try:
# Get Network Interfaces for Security Groups
for sg in self.security_groups:
@@ -299,30 +289,25 @@ class EC2(AWSService):
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def __get_instance_user_data__(self):
logger.info("EC2 - Getting instance user data...")
for instance in self.instances:
try:
regional_client = self.regional_clients[instance.region]
user_data = regional_client.describe_instance_attribute(
Attribute="userData", InstanceId=instance.id
)["UserData"]
if "Value" in user_data:
instance.user_data = user_data["Value"]
except ClientError as error:
if error.response["Error"]["Code"] == "InvalidInstanceID.NotFound":
logger.warning(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
continue
except Exception as error:
logger.error(
def __get_instance_user_data__(self, instance):
try:
regional_client = self.regional_clients[instance.region]
user_data = regional_client.describe_instance_attribute(
Attribute="userData", InstanceId=instance.id
)["UserData"]
if "Value" in user_data:
instance.user_data = user_data["Value"]
except ClientError as error:
if error.response["Error"]["Code"] == "InvalidInstanceID.NotFound":
logger.warning(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
except Exception as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def __describe_images__(self, regional_client):
logger.info("EC2 - Describing Images...")
try:
for image in regional_client.describe_images(Owners=["self"])["Images"]:
arn = f"arn:{self.audited_partition}:ec2:{regional_client.region}:{self.audited_account}:image/{image['ImageId']}"
@@ -345,7 +330,6 @@ class EC2(AWSService):
)
def __describe_volumes__(self, regional_client):
logger.info("EC2 - Describing Volumes...")
try:
describe_volumes_paginator = regional_client.get_paginator(
"describe_volumes"
@@ -370,8 +354,7 @@ class EC2(AWSService):
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def __describe_addresses__(self, regional_client):
logger.info("EC2 - Describing Elastic IPs...")
def __describe_ec2_addresses__(self, regional_client):
try:
for address in regional_client.describe_addresses()["Addresses"]:
public_ip = None
@@ -402,8 +385,7 @@ class EC2(AWSService):
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def __get_ebs_encryption_by_default__(self, regional_client):
logger.info("EC2 - Get EBS Encryption By Default...")
def __get_ebs_encryption_settings__(self, regional_client):
try:
volumes_in_region = False
for volume in self.volumes:

View File

@@ -28,6 +28,7 @@ class S3(AWSService):
self.__threading_call__(self.__get_bucket_tagging__)
# In the S3 service we override the "__threading_call__" method because we spawn a process per bucket instead of per region
# TODO: Replace the above function with the service __threading_call__ using the buckets as the iterator
def __threading_call__(self, call):
threads = []
for bucket in self.buckets: