mirror of
https://github.com/ghndrx/prowler.git
synced 2026-02-10 06:45:08 +00:00
fix(threading): Improved threading for the AWS Service (#3175)
Co-authored-by: Pepe Fagoaga <pepe@verica.io>
This commit is contained in:
@@ -1,17 +1,21 @@
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.aws.aws_provider import (
|
||||
generate_regional_clients,
|
||||
get_default_region,
|
||||
)
|
||||
from prowler.providers.aws.lib.audit_info.models import AWS_Audit_Info
|
||||
|
||||
MAX_WORKERS = 10
|
||||
|
||||
|
||||
class AWSService:
|
||||
"""The AWSService class offers a parent class for each AWS Service to generate:
|
||||
- AWS Regional Clients
|
||||
- Shared information like the account ID and ARN, the the AWS partition and the checks audited
|
||||
- AWS Session
|
||||
- Thread pool for the __threading_call__
|
||||
- Also handles if the AWS Service is Global
|
||||
"""
|
||||
|
||||
@@ -42,14 +46,40 @@ class AWSService:
|
||||
self.region = get_default_region(self.service, audit_info)
|
||||
self.client = self.session.client(self.service, self.region)
|
||||
|
||||
# Thread pool for __threading_call__
|
||||
self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS)
|
||||
|
||||
def __get_session__(self):
|
||||
return self.session
|
||||
|
||||
def __threading_call__(self, call):
|
||||
threads = []
|
||||
for regional_client in self.regional_clients.values():
|
||||
threads.append(threading.Thread(target=call, args=(regional_client,)))
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
def __threading_call__(self, call, iterator=None):
|
||||
# Use the provided iterator, or default to self.regional_clients
|
||||
items = iterator if iterator is not None else self.regional_clients.values()
|
||||
# Determine the total count for logging
|
||||
item_count = len(items)
|
||||
|
||||
# Trim leading and trailing underscores from the call's name
|
||||
call_name = call.__name__.strip("_")
|
||||
# Add Capitalization
|
||||
call_name = " ".join([x.capitalize() for x in call_name.split("_")])
|
||||
|
||||
# Print a message based on the call's name, and if its regional or processing a list of items
|
||||
if iterator is None:
|
||||
logger.info(
|
||||
f"{self.service.upper()} - Starting threads for '{call_name}' function across {item_count} regions..."
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"{self.service.upper()} - Starting threads for '{call_name}' function to process {item_count} items..."
|
||||
)
|
||||
|
||||
# Submit tasks to the thread pool
|
||||
futures = [self.thread_pool.submit(call, item) for item in items]
|
||||
|
||||
# Wait for all tasks to complete
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
future.result() # Raises exceptions from the thread, if any
|
||||
except Exception:
|
||||
# Handle exceptions if necessary
|
||||
pass # Replace 'pass' with any additional exception handling logic. Currently handled within the called function
|
||||
|
||||
@@ -17,7 +17,7 @@ class EC2(AWSService):
|
||||
super().__init__(__class__.__name__, audit_info)
|
||||
self.instances = []
|
||||
self.__threading_call__(self.__describe_instances__)
|
||||
self.__get_instance_user_data__()
|
||||
self.__threading_call__(self.__get_instance_user_data__, self.instances)
|
||||
self.security_groups = []
|
||||
self.regions_with_sgs = []
|
||||
self.__threading_call__(self.__describe_security_groups__)
|
||||
@@ -27,7 +27,7 @@ class EC2(AWSService):
|
||||
self.volumes_with_snapshots = {}
|
||||
self.regions_with_snapshots = {}
|
||||
self.__threading_call__(self.__describe_snapshots__)
|
||||
self.__get_snapshot_public__()
|
||||
self.__threading_call__(self.__determine_public_snapshots__, self.snapshots)
|
||||
self.network_interfaces = []
|
||||
self.__threading_call__(self.__describe_public_network_interfaces__)
|
||||
self.__threading_call__(self.__describe_sg_network_interfaces__)
|
||||
@@ -36,12 +36,11 @@ class EC2(AWSService):
|
||||
self.volumes = []
|
||||
self.__threading_call__(self.__describe_volumes__)
|
||||
self.ebs_encryption_by_default = []
|
||||
self.__threading_call__(self.__get_ebs_encryption_by_default__)
|
||||
self.__threading_call__(self.__get_ebs_encryption_settings__)
|
||||
self.elastic_ips = []
|
||||
self.__threading_call__(self.__describe_addresses__)
|
||||
self.__threading_call__(self.__describe_ec2_addresses__)
|
||||
|
||||
def __describe_instances__(self, regional_client):
|
||||
logger.info("EC2 - Describing EC2 Instances...")
|
||||
try:
|
||||
describe_instances_paginator = regional_client.get_paginator(
|
||||
"describe_instances"
|
||||
@@ -106,7 +105,6 @@ class EC2(AWSService):
|
||||
)
|
||||
|
||||
def __describe_security_groups__(self, regional_client):
|
||||
logger.info("EC2 - Describing Security Groups...")
|
||||
try:
|
||||
describe_security_groups_paginator = regional_client.get_paginator(
|
||||
"describe_security_groups"
|
||||
@@ -155,7 +153,6 @@ class EC2(AWSService):
|
||||
)
|
||||
|
||||
def __describe_network_acls__(self, regional_client):
|
||||
logger.info("EC2 - Describing Network ACLs...")
|
||||
try:
|
||||
describe_network_acls_paginator = regional_client.get_paginator(
|
||||
"describe_network_acls"
|
||||
@@ -186,7 +183,6 @@ class EC2(AWSService):
|
||||
)
|
||||
|
||||
def __describe_snapshots__(self, regional_client):
|
||||
logger.info("EC2 - Describing Snapshots...")
|
||||
try:
|
||||
snapshots_in_region = False
|
||||
describe_snapshots_paginator = regional_client.get_paginator(
|
||||
@@ -219,35 +215,30 @@ class EC2(AWSService):
|
||||
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
|
||||
def __get_snapshot_public__(self):
|
||||
logger.info("EC2 - Getting snapshot volume attribute permissions...")
|
||||
for snapshot in self.snapshots:
|
||||
try:
|
||||
regional_client = self.regional_clients[snapshot.region]
|
||||
snapshot_public = regional_client.describe_snapshot_attribute(
|
||||
Attribute="createVolumePermission", SnapshotId=snapshot.id
|
||||
)
|
||||
for permission in snapshot_public["CreateVolumePermissions"]:
|
||||
if "Group" in permission:
|
||||
if permission["Group"] == "all":
|
||||
snapshot.public = True
|
||||
def __determine_public_snapshots__(self, snapshot):
|
||||
try:
|
||||
regional_client = self.regional_clients[snapshot.region]
|
||||
snapshot_public = regional_client.describe_snapshot_attribute(
|
||||
Attribute="createVolumePermission", SnapshotId=snapshot.id
|
||||
)
|
||||
for permission in snapshot_public["CreateVolumePermissions"]:
|
||||
if "Group" in permission:
|
||||
if permission["Group"] == "all":
|
||||
snapshot.public = True
|
||||
|
||||
except ClientError as error:
|
||||
if error.response["Error"]["Code"] == "InvalidSnapshot.NotFound":
|
||||
logger.warning(
|
||||
f"{snapshot.region} --"
|
||||
f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:"
|
||||
f" {error}"
|
||||
)
|
||||
continue
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
except ClientError as error:
|
||||
if error.response["Error"]["Code"] == "InvalidSnapshot.NotFound":
|
||||
logger.warning(
|
||||
f"{snapshot.region} --"
|
||||
f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:"
|
||||
f" {error}"
|
||||
)
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
|
||||
def __describe_public_network_interfaces__(self, regional_client):
|
||||
logger.info("EC2 - Describing Network Interfaces...")
|
||||
try:
|
||||
# Get Network Interfaces with Public IPs
|
||||
describe_network_interfaces_paginator = regional_client.get_paginator(
|
||||
@@ -274,7 +265,6 @@ class EC2(AWSService):
|
||||
)
|
||||
|
||||
def __describe_sg_network_interfaces__(self, regional_client):
|
||||
logger.info("EC2 - Describing Network Interfaces...")
|
||||
try:
|
||||
# Get Network Interfaces for Security Groups
|
||||
for sg in self.security_groups:
|
||||
@@ -299,30 +289,25 @@ class EC2(AWSService):
|
||||
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
|
||||
def __get_instance_user_data__(self):
|
||||
logger.info("EC2 - Getting instance user data...")
|
||||
for instance in self.instances:
|
||||
try:
|
||||
regional_client = self.regional_clients[instance.region]
|
||||
user_data = regional_client.describe_instance_attribute(
|
||||
Attribute="userData", InstanceId=instance.id
|
||||
)["UserData"]
|
||||
if "Value" in user_data:
|
||||
instance.user_data = user_data["Value"]
|
||||
|
||||
except ClientError as error:
|
||||
if error.response["Error"]["Code"] == "InvalidInstanceID.NotFound":
|
||||
logger.warning(
|
||||
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
continue
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
def __get_instance_user_data__(self, instance):
|
||||
try:
|
||||
regional_client = self.regional_clients[instance.region]
|
||||
user_data = regional_client.describe_instance_attribute(
|
||||
Attribute="userData", InstanceId=instance.id
|
||||
)["UserData"]
|
||||
if "Value" in user_data:
|
||||
instance.user_data = user_data["Value"]
|
||||
except ClientError as error:
|
||||
if error.response["Error"]["Code"] == "InvalidInstanceID.NotFound":
|
||||
logger.warning(
|
||||
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
|
||||
def __describe_images__(self, regional_client):
|
||||
logger.info("EC2 - Describing Images...")
|
||||
try:
|
||||
for image in regional_client.describe_images(Owners=["self"])["Images"]:
|
||||
arn = f"arn:{self.audited_partition}:ec2:{regional_client.region}:{self.audited_account}:image/{image['ImageId']}"
|
||||
@@ -345,7 +330,6 @@ class EC2(AWSService):
|
||||
)
|
||||
|
||||
def __describe_volumes__(self, regional_client):
|
||||
logger.info("EC2 - Describing Volumes...")
|
||||
try:
|
||||
describe_volumes_paginator = regional_client.get_paginator(
|
||||
"describe_volumes"
|
||||
@@ -370,8 +354,7 @@ class EC2(AWSService):
|
||||
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
|
||||
def __describe_addresses__(self, regional_client):
|
||||
logger.info("EC2 - Describing Elastic IPs...")
|
||||
def __describe_ec2_addresses__(self, regional_client):
|
||||
try:
|
||||
for address in regional_client.describe_addresses()["Addresses"]:
|
||||
public_ip = None
|
||||
@@ -402,8 +385,7 @@ class EC2(AWSService):
|
||||
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
|
||||
def __get_ebs_encryption_by_default__(self, regional_client):
|
||||
logger.info("EC2 - Get EBS Encryption By Default...")
|
||||
def __get_ebs_encryption_settings__(self, regional_client):
|
||||
try:
|
||||
volumes_in_region = False
|
||||
for volume in self.volumes:
|
||||
|
||||
@@ -28,6 +28,7 @@ class S3(AWSService):
|
||||
self.__threading_call__(self.__get_bucket_tagging__)
|
||||
|
||||
# In the S3 service we override the "__threading_call__" method because we spawn a process per bucket instead of per region
|
||||
# TODO: Replace the above function with the service __threading_call__ using the buckets as the iterator
|
||||
def __threading_call__(self, call):
|
||||
threads = []
|
||||
for bucket in self.buckets:
|
||||
|
||||
Reference in New Issue
Block a user