feat(global_aws_session): Global data structure for the current AWS audit (#1212)

* fix(audit info): Common data structure for current audit

* fix(iam): iam session audit fixed

* feat(aws_session): Include else block

Co-authored-by: Pepe Fagoaga <pepe@verica.io>
This commit is contained in:
Nacho Rivera
2022-06-21 07:53:49 +02:00
committed by GitHub
parent b89b883741
commit e52ab12696
5 changed files with 172 additions and 176 deletions

View File

@@ -1,6 +1,3 @@
from dataclasses import dataclass
from datetime import datetime
from arnparse import arnparse from arnparse import arnparse
from boto3 import session from boto3 import session
from botocore.credentials import RefreshableCredentials from botocore.credentials import RefreshableCredentials
@@ -8,63 +5,32 @@ from botocore.session import get_session
from lib.arn.arn import arn_parsing from lib.arn.arn import arn_parsing
from lib.logger import logger from lib.logger import logger
from providers.aws.models import AWS_Assume_Role, AWS_Audit_Info, AWS_Credentials
@dataclass
class AWS_Credentials:
aws_access_key_id: str
aws_session_token: str
aws_secret_access_key: str
expiration: datetime
@dataclass
class Input_Data:
profile: str
role_arn: str
session_duration: int
external_id: str
regions: list
@dataclass
class AWS_Assume_Role:
role_arn: str
session_duration: int
external_id: str
sts_session: session
partition: str
@dataclass
class AWS_Session_Info:
profile: str
credentials: AWS_Credentials
role_info: AWS_Assume_Role
################## AWS PROVIDER ################## AWS PROVIDER
class AWS_Provider: class AWS_Provider:
def __init__(self, session_info): def __init__(self, audit_info):
self.aws_session = self.set_session(session_info) logger.info("Instantiating aws provider ...")
self.role_info = session_info.role_info self.aws_session = self.set_session(audit_info)
self.role_info = audit_info.assumed_role_info
def get_session(self): def get_session(self):
return self.aws_session return self.aws_session
def set_session(self, session_info): def set_session(self, audit_info):
try: try:
if session_info.credentials: if audit_info.credentials:
# If we receive a credentials object filled is coming form an assumed role, so renewal is needed # If we receive a credentials object filled is coming form an assumed role, so renewal is needed
logger.info("Creating session for assumed role ...") logger.info("Creating session for assumed role ...")
# From botocore we can use RefreshableCredentials class, which has an attribute (refresh_using) # From botocore we can use RefreshableCredentials class, which has an attribute (refresh_using)
# that needs to be a method without arguments that retrieves a new set of fresh credentials # that needs to be a method without arguments that retrieves a new set of fresh credentials
# asuming the role again. -> https://github.com/boto/botocore/blob/098cc255f81a25b852e1ecdeb7adebd94c7b1b73/botocore/credentials.py#L395 # asuming the role again. -> https://github.com/boto/botocore/blob/098cc255f81a25b852e1ecdeb7adebd94c7b1b73/botocore/credentials.py#L395
assumed_refreshable_credentials = RefreshableCredentials( assumed_refreshable_credentials = RefreshableCredentials(
access_key=session_info.credentials.aws_access_key_id, access_key=audit_info.credentials.aws_access_key_id,
secret_key=session_info.credentials.aws_secret_access_key, secret_key=audit_info.credentials.aws_secret_access_key,
token=session_info.credentials.aws_session_token, token=audit_info.credentials.aws_session_token,
expiry_time=session_info.credentials.expiration, expiry_time=audit_info.credentials.expiration,
refresh_using=self.refresh, refresh_using=self.refresh,
method="sts-assume-role", method="sts-assume-role",
) )
@@ -74,13 +40,13 @@ class AWS_Provider:
assumed_botocore_session.set_config_variable("region", "us-east-1") assumed_botocore_session.set_config_variable("region", "us-east-1")
return session.Session( return session.Session(
profile_name=session_info.profile, profile_name=audit_info.profile,
botocore_session=assumed_botocore_session, botocore_session=assumed_botocore_session,
) )
# If we do not receive credentials start the session using the profile # If we do not receive credentials start the session using the profile
else: else:
logger.info("Creating session for not assumed identity ...") logger.info("Creating session for not assumed identity ...")
return session.Session(profile_name=session_info.profile) return session.Session(profile_name=audit_info.profile)
except Exception as error: except Exception as error:
logger.critical(f"{error.__class__.__name__} -- {error}") logger.critical(f"{error.__class__.__name__} -- {error}")
quit() quit()
@@ -105,6 +71,84 @@ class AWS_Provider:
return refreshed_credentials return refreshed_credentials
def provider_set_session(
input_profile, input_role, input_session_duration, input_external_id, input_regions
):
# Mark variable that stores all the info about the audit as global
global current_audit_info
assumed_session = None
# Setting session
current_audit_info = AWS_Audit_Info(
original_session=None,
audit_session=None,
audited_account=None,
audited_partition=None,
profile=input_profile,
credentials=None,
assumed_role_info=AWS_Assume_Role(
role_arn=input_role,
session_duration=input_session_duration,
external_id=input_external_id,
),
audited_regions=input_regions,
)
logger.info("Generating original session ...")
# Create an global original session using only profile/basic credentials info
current_audit_info.original_session = AWS_Provider(current_audit_info).get_session()
logger.info("Validating credentials ...")
# Verificate if we have valid credentials
caller_identity = validate_credentials(current_audit_info.original_session)
logger.info("Credentials validated")
logger.info(f"Original caller identity UserId : {caller_identity['UserId']}")
logger.info(f"Original caller identity ARN : {caller_identity['Arn']}")
current_audit_info.audited_account = caller_identity["Account"]
current_audit_info.audited_partition = arnparse(caller_identity["Arn"]).partition
logger.info("Checking if role assumption is needed ...")
if current_audit_info.assumed_role_info.role_arn:
# Check if role arn is valid
try:
# this returns the arn already parsed, calls arnparse, into a dict to be used when it is needed to access its fields
role_arn_parsed = arn_parsing(current_audit_info.assumed_role_info.role_arn)
except Exception as error:
logger.critical(f"{error.__class__.__name__} -- {error}")
quit()
else:
logger.info(
f"Assuming role {current_audit_info.assumed_role_info.role_arn}"
)
# Assume the role
assumed_role_response = assume_role()
logger.info("Role assumed")
# Set the info needed to create a session with an assumed role
current_audit_info.credentials = AWS_Credentials(
aws_access_key_id=assumed_role_response["Credentials"]["AccessKeyId"],
aws_session_token=assumed_role_response["Credentials"]["SessionToken"],
aws_secret_access_key=assumed_role_response["Credentials"][
"SecretAccessKey"
],
expiration=assumed_role_response["Credentials"]["Expiration"],
)
assumed_session = AWS_Provider(current_audit_info).get_session()
if assumed_session:
logger.info("Audit session is the new session created assuming role")
current_audit_info.audit_session = assumed_session
current_audit_info.audited_account = role_arn_parsed.account_id
current_audit_info.audited_partition = role_arn_parsed.partition
else:
logger.info("Audit session is the original one")
current_audit_info.audit_session = current_audit_info.original_session
def validate_credentials(validate_session): def validate_credentials(validate_session):
try: try:
validate_credentials_client = validate_session.client("sts") validate_credentials_client = validate_session.client("sts")
@@ -116,103 +160,25 @@ def validate_credentials(validate_session):
return caller_identity return caller_identity
def provider_set_session(session_input): def assume_role():
# global variables that are going to be shared accross the project
global aws_session
global original_session
global audited_regions
global audited_partition
global audited_account
assumed_session = None
# Initialize a session info dataclass only with info about the profile
session_info = AWS_Session_Info(
session_input.profile,
None,
None,
)
# Create an global original session using only profile/basic credentials info
original_session = AWS_Provider(session_info).get_session()
logger.info("Validating credentials ...")
# Verificate if we have valid credentials
caller_identity = validate_credentials(original_session)
logger.info("Credentials validated")
logger.info(f"Original caller identity UserId : {caller_identity['UserId']}")
logger.info(f"Original caller identity ARN : {caller_identity['Arn']}")
# Set some global values for original session
audited_regions = session_input.regions
audited_account = caller_identity["Account"]
audited_partition = arnparse(caller_identity["Arn"]).partition
if session_input.role_arn:
# Check if role arn is valid
try:
# this returns the arn already parsed, calls arnparse, into a dict to be used when it is needed to access its fields
role_arn_parsed = arn_parsing(session_input.role_arn)
except Exception as error:
logger.critical(f"{error.__class__.__name__} -- {error}")
quit()
# Set info for role assumption if needed
role_info = AWS_Assume_Role(
session_input.role_arn,
session_input.session_duration,
session_input.external_id,
original_session,
audited_partition,
)
logger.info(f"Assuming role {role_info.role_arn}")
# Assume the role
assumed_role_response = assume_role(role_info)
logger.info("Role assumed")
# Set the info needed to create a session with an assumed role
session_info = AWS_Session_Info(
session_input.profile,
AWS_Credentials(
aws_access_key_id=assumed_role_response["Credentials"]["AccessKeyId"],
aws_session_token=assumed_role_response["Credentials"]["SessionToken"],
aws_secret_access_key=assumed_role_response["Credentials"][
"SecretAccessKey"
],
expiration=assumed_role_response["Credentials"]["Expiration"],
),
role_info,
)
assumed_session = AWS_Provider(session_info).get_session()
if assumed_session:
aws_session = assumed_session
audited_account = role_arn_parsed.account_id
audited_partition = role_arn_parsed.partition
else:
aws_session = original_session
def assume_role(role_info):
try: try:
# set the info to assume the role from the partition, account and role name # set the info to assume the role from the partition, account and role name
sts_client = role_info.sts_session.client("sts") sts_client = current_audit_info.original_session.client("sts")
# If external id, set it to the assume role api call # If external id, set it to the assume role api call
if role_info.external_id: if current_audit_info.assumed_role_info.external_id:
assumed_credentials = sts_client.assume_role( assumed_credentials = sts_client.assume_role(
RoleArn=role_info.role_arn, RoleArn=current_audit_info.assumed_role_info.role_arn,
RoleSessionName="ProwlerProSession", RoleSessionName="ProwlerProAsessmentSession",
DurationSeconds=role_info.session_duration, DurationSeconds=current_audit_info.assumed_role_info.session_duration,
ExternalId=role_info.external_id, ExternalId=current_audit_info.assumed_role_info.external_id,
) )
# else assume the role without the external id # else assume the role without the external id
else: else:
assumed_credentials = sts_client.assume_role( assumed_credentials = sts_client.assume_role(
RoleArn=role_info.role_arn, RoleArn=current_audit_info.assumed_role_info.role_arn,
RoleSessionName="ProwlerProSession", RoleSessionName="ProwlerProAsessmentSession",
DurationSeconds=role_info.session_duration, DurationSeconds=current_audit_info.assumed_role_info.session_duration,
) )
except Exception as error: except Exception as error:
logger.critical(f"{error.__class__.__name__} -- {error}") logger.critical(f"{error.__class__.__name__} -- {error}")

31
providers/aws/models.py Normal file
View File

@@ -0,0 +1,31 @@
from dataclasses import dataclass
from datetime import datetime
from boto3 import session
@dataclass
class AWS_Credentials:
aws_access_key_id: str
aws_session_token: str
aws_secret_access_key: str
expiration: datetime
@dataclass
class AWS_Assume_Role:
role_arn: str
session_duration: int
external_id: str
@dataclass
class AWS_Audit_Info:
original_session: session.Session
audit_session: session.Session
audited_account: int
audited_partition: str
profile: str
credentials: AWS_Credentials
assumed_role_info: AWS_Assume_Role
audited_regions: list

View File

@@ -5,31 +5,24 @@ import urllib.request
from config.config import aws_services_json_file, aws_services_json_url from config.config import aws_services_json_file, aws_services_json_url
from lib.logger import logger from lib.logger import logger
from lib.utils.utils import open_file, parse_json_file from lib.utils.utils import open_file, parse_json_file
from providers.aws.aws_provider import ( from providers.aws.aws_provider import current_audit_info
audited_account,
audited_partition,
audited_regions,
aws_session,
)
################## EC2 ################## EC2
class EC2: class EC2:
def __init__(self, aws_session, audited_regions): def __init__(self, audit_info):
self.service = "ec2" self.service = "ec2"
self.aws_session = aws_session self.session = audit_info.audit_session
self.audited_account = audit_info.audited_account
self.regional_clients = self.__generate_regional_clients__( self.regional_clients = self.__generate_regional_clients__(
self.service, audited_regions self.service, audit_info
) )
self.__threading_call__(self.__describe_snapshots__) self.__threading_call__(self.__describe_snapshots__)
def __get_clients__(self):
return self.clients
def __get_session__(self): def __get_session__(self):
return self.aws_session return self.session
def __generate_regional_clients__(self, service, audited_regions): def __generate_regional_clients__(self, service, audit_info):
regional_clients = [] regional_clients = []
try: # Try to get the list online try: # Try to get the list online
with urllib.request.urlopen(aws_services_json_url) as url: with urllib.request.urlopen(aws_services_json_url) as url:
@@ -40,46 +33,50 @@ class EC2:
data = parse_json_file(f) data = parse_json_file(f)
for att in data["prices"]: for att in data["prices"]:
if audited_regions: # Check for input aws audited_regions if (
audit_info.audited_regions
): # Check for input aws audit_info.audited_regions
if ( if (
service in att["id"].split(":")[0] service in att["id"].split(":")[0]
and att["attributes"]["aws:region"] in audited_regions and att["attributes"]["aws:region"] in audit_info.audited_regions
): # Check if service has this region ): # Check if service has this region
region = att["attributes"]["aws:region"] region = att["attributes"]["aws:region"]
regional_client = aws_session.client(service, region_name=region) regional_client = audit_info.audit_session.client(
service, region_name=region
)
regional_client.region = region regional_client.region = region
regional_clients.append(regional_client) regional_clients.append(regional_client)
else: else:
if audited_partition in "aws": if audit_info.audited_partition in "aws":
if ( if (
service in att["id"].split(":")[0] service in att["id"].split(":")[0]
and "gov" not in att["attributes"]["aws:region"] and "gov" not in att["attributes"]["aws:region"]
and "cn" not in att["attributes"]["aws:region"] and "cn" not in att["attributes"]["aws:region"]
): ):
region = att["attributes"]["aws:region"] region = att["attributes"]["aws:region"]
regional_client = aws_session.client( regional_client = audit_info.audit_session.client(
service, region_name=region service, region_name=region
) )
regional_client.region = region regional_client.region = region
regional_clients.append(regional_client) regional_clients.append(regional_client)
elif audited_partition in "cn": elif audit_info.audited_partition in "cn":
if ( if (
service in att["id"].split(":")[0] service in att["id"].split(":")[0]
and "cn" in att["attributes"]["aws:region"] and "cn" in att["attributes"]["aws:region"]
): ):
region = att["attributes"]["aws:region"] region = att["attributes"]["aws:region"]
regional_client = aws_session.client( regional_client = audit_info.audit_session.client(
service, region_name=region service, region_name=region
) )
regional_client.region = region regional_client.region = region
regional_clients.append(regional_client) regional_clients.append(regional_client)
elif audited_partition in "gov": elif audit_info.audited_partition in "gov":
if ( if (
service in att["id"].split(":")[0] service in att["id"].split(":")[0]
and "gov" in att["attributes"]["aws:region"] and "gov" in att["attributes"]["aws:region"]
): ):
region = att["attributes"]["aws:region"] region = att["attributes"]["aws:region"]
regional_client = aws_session.client( regional_client = audit_info.audit_session.client(
service, region_name=region service, region_name=region
) )
regional_client.region = region regional_client.region = region
@@ -90,13 +87,17 @@ class EC2:
def __threading_call__(self, call): def __threading_call__(self, call):
threads = [] threads = []
for regional_client in self.regional_clients: for regional_client in self.regional_clients:
threads.append(threading.Thread(target=call, args=(regional_client,))) threads.append(
threading.Thread(
target=call, args=(regional_client, self.audited_account)
)
)
for t in threads: for t in threads:
t.start() t.start()
for t in threads: for t in threads:
t.join() t.join()
def __describe_snapshots__(self, regional_client): def __describe_snapshots__(self, regional_client, audited_account):
logger.info("EC2 - Describing Snapshots...") logger.info("EC2 - Describing Snapshots...")
try: try:
describe_snapshots_paginator = regional_client.get_paginator( describe_snapshots_paginator = regional_client.get_paginator(
@@ -114,4 +115,4 @@ class EC2:
regional_client.snapshots = snapshots regional_client.snapshots = snapshots
ec2_client = EC2(aws_session, audited_regions) ec2_client = EC2(current_audit_info)

View File

@@ -1,13 +1,13 @@
from lib.logger import logger from lib.logger import logger
from providers.aws.aws_provider import aws_session from providers.aws.aws_provider import current_audit_info
################## IAM ################## IAM
class IAM: class IAM:
def __init__(self, session): def __init__(self, audit_info):
self.service = "iam" self.service = "iam"
self.session = session self.session = audit_info.audit_session
self.client = session.client(self.service) self.client = self.session.client(self.service)
self.users = self.__get_users__() self.users = self.__get_users__()
self.roles = self.__get_roles__() self.roles = self.__get_roles__()
self.customer_managed_policies = self.__get_customer_managed_policies__() self.customer_managed_policies = self.__get_customer_managed_policies__()
@@ -89,7 +89,7 @@ class IAM:
try: try:
iam_client = IAM(aws_session) iam_client = IAM(current_audit_info)
except Exception as error: except Exception as error:
logger.critical(f"{error.__class__.__name__} -- {error}") logger.critical(f"{error.__class__.__name__} -- {error}")
quit() quit()

View File

@@ -11,7 +11,7 @@ from lib.check.check import (
run_check, run_check,
) )
from lib.logger import logger, logging_levels from lib.logger import logger, logging_levels
from providers.aws.aws_provider import Input_Data, provider_set_session from providers.aws.aws_provider import provider_set_session
if __name__ == "__main__": if __name__ == "__main__":
# CLI Arguments # CLI Arguments
@@ -103,17 +103,15 @@ if __name__ == "__main__":
if args.no_banner: if args.no_banner:
print_banner() print_banner()
# Setting session # Set global session
session_input = Input_Data( provider_set_session(
profile=args.profile, args.profile,
role_arn=args.role, args.role,
session_duration=args.session_duration, args.session_duration,
external_id=args.external_id, args.external_id,
regions=args.filter_region, args.filter_region,
) )
provider_set_session(session_input)
# Load checks to execute # Load checks to execute
logger.debug("Loading checks") logger.debug("Loading checks")
checks_to_execute = load_checks_to_execute( checks_to_execute = load_checks_to_execute(