feat(global_aws_session): Global data structure for the current AWS audit (#1212)

* fix(audit info): Common data structure for current audit

* fix(iam): iam session audit fixed

* feat(aws_session): Include else block

Co-authored-by: Pepe Fagoaga <pepe@verica.io>
This commit is contained in:
Nacho Rivera
2022-06-21 07:53:49 +02:00
committed by GitHub
parent b89b883741
commit e52ab12696
5 changed files with 172 additions and 176 deletions

View File

@@ -1,6 +1,3 @@
from dataclasses import dataclass
from datetime import datetime
from arnparse import arnparse
from boto3 import session
from botocore.credentials import RefreshableCredentials
@@ -8,63 +5,32 @@ from botocore.session import get_session
from lib.arn.arn import arn_parsing
from lib.logger import logger
@dataclass
class AWS_Credentials:
aws_access_key_id: str
aws_session_token: str
aws_secret_access_key: str
expiration: datetime
@dataclass
class Input_Data:
profile: str
role_arn: str
session_duration: int
external_id: str
regions: list
@dataclass
class AWS_Assume_Role:
role_arn: str
session_duration: int
external_id: str
sts_session: session
partition: str
@dataclass
class AWS_Session_Info:
profile: str
credentials: AWS_Credentials
role_info: AWS_Assume_Role
from providers.aws.models import AWS_Assume_Role, AWS_Audit_Info, AWS_Credentials
################## AWS PROVIDER
class AWS_Provider:
def __init__(self, session_info):
self.aws_session = self.set_session(session_info)
self.role_info = session_info.role_info
def __init__(self, audit_info):
logger.info("Instantiating aws provider ...")
self.aws_session = self.set_session(audit_info)
self.role_info = audit_info.assumed_role_info
def get_session(self):
return self.aws_session
def set_session(self, session_info):
def set_session(self, audit_info):
try:
if session_info.credentials:
if audit_info.credentials:
# If we receive a credentials object filled is coming form an assumed role, so renewal is needed
logger.info("Creating session for assumed role ...")
# From botocore we can use RefreshableCredentials class, which has an attribute (refresh_using)
# that needs to be a method without arguments that retrieves a new set of fresh credentials
# asuming the role again. -> https://github.com/boto/botocore/blob/098cc255f81a25b852e1ecdeb7adebd94c7b1b73/botocore/credentials.py#L395
assumed_refreshable_credentials = RefreshableCredentials(
access_key=session_info.credentials.aws_access_key_id,
secret_key=session_info.credentials.aws_secret_access_key,
token=session_info.credentials.aws_session_token,
expiry_time=session_info.credentials.expiration,
access_key=audit_info.credentials.aws_access_key_id,
secret_key=audit_info.credentials.aws_secret_access_key,
token=audit_info.credentials.aws_session_token,
expiry_time=audit_info.credentials.expiration,
refresh_using=self.refresh,
method="sts-assume-role",
)
@@ -74,13 +40,13 @@ class AWS_Provider:
assumed_botocore_session.set_config_variable("region", "us-east-1")
return session.Session(
profile_name=session_info.profile,
profile_name=audit_info.profile,
botocore_session=assumed_botocore_session,
)
# If we do not receive credentials start the session using the profile
else:
logger.info("Creating session for not assumed identity ...")
return session.Session(profile_name=session_info.profile)
return session.Session(profile_name=audit_info.profile)
except Exception as error:
logger.critical(f"{error.__class__.__name__} -- {error}")
quit()
@@ -105,6 +71,84 @@ class AWS_Provider:
return refreshed_credentials
def provider_set_session(
input_profile, input_role, input_session_duration, input_external_id, input_regions
):
# Mark variable that stores all the info about the audit as global
global current_audit_info
assumed_session = None
# Setting session
current_audit_info = AWS_Audit_Info(
original_session=None,
audit_session=None,
audited_account=None,
audited_partition=None,
profile=input_profile,
credentials=None,
assumed_role_info=AWS_Assume_Role(
role_arn=input_role,
session_duration=input_session_duration,
external_id=input_external_id,
),
audited_regions=input_regions,
)
logger.info("Generating original session ...")
# Create an global original session using only profile/basic credentials info
current_audit_info.original_session = AWS_Provider(current_audit_info).get_session()
logger.info("Validating credentials ...")
# Verificate if we have valid credentials
caller_identity = validate_credentials(current_audit_info.original_session)
logger.info("Credentials validated")
logger.info(f"Original caller identity UserId : {caller_identity['UserId']}")
logger.info(f"Original caller identity ARN : {caller_identity['Arn']}")
current_audit_info.audited_account = caller_identity["Account"]
current_audit_info.audited_partition = arnparse(caller_identity["Arn"]).partition
logger.info("Checking if role assumption is needed ...")
if current_audit_info.assumed_role_info.role_arn:
# Check if role arn is valid
try:
# this returns the arn already parsed, calls arnparse, into a dict to be used when it is needed to access its fields
role_arn_parsed = arn_parsing(current_audit_info.assumed_role_info.role_arn)
except Exception as error:
logger.critical(f"{error.__class__.__name__} -- {error}")
quit()
else:
logger.info(
f"Assuming role {current_audit_info.assumed_role_info.role_arn}"
)
# Assume the role
assumed_role_response = assume_role()
logger.info("Role assumed")
# Set the info needed to create a session with an assumed role
current_audit_info.credentials = AWS_Credentials(
aws_access_key_id=assumed_role_response["Credentials"]["AccessKeyId"],
aws_session_token=assumed_role_response["Credentials"]["SessionToken"],
aws_secret_access_key=assumed_role_response["Credentials"][
"SecretAccessKey"
],
expiration=assumed_role_response["Credentials"]["Expiration"],
)
assumed_session = AWS_Provider(current_audit_info).get_session()
if assumed_session:
logger.info("Audit session is the new session created assuming role")
current_audit_info.audit_session = assumed_session
current_audit_info.audited_account = role_arn_parsed.account_id
current_audit_info.audited_partition = role_arn_parsed.partition
else:
logger.info("Audit session is the original one")
current_audit_info.audit_session = current_audit_info.original_session
def validate_credentials(validate_session):
try:
validate_credentials_client = validate_session.client("sts")
@@ -116,103 +160,25 @@ def validate_credentials(validate_session):
return caller_identity
def provider_set_session(session_input):
# global variables that are going to be shared accross the project
global aws_session
global original_session
global audited_regions
global audited_partition
global audited_account
assumed_session = None
# Initialize a session info dataclass only with info about the profile
session_info = AWS_Session_Info(
session_input.profile,
None,
None,
)
# Create an global original session using only profile/basic credentials info
original_session = AWS_Provider(session_info).get_session()
logger.info("Validating credentials ...")
# Verificate if we have valid credentials
caller_identity = validate_credentials(original_session)
logger.info("Credentials validated")
logger.info(f"Original caller identity UserId : {caller_identity['UserId']}")
logger.info(f"Original caller identity ARN : {caller_identity['Arn']}")
# Set some global values for original session
audited_regions = session_input.regions
audited_account = caller_identity["Account"]
audited_partition = arnparse(caller_identity["Arn"]).partition
if session_input.role_arn:
# Check if role arn is valid
try:
# this returns the arn already parsed, calls arnparse, into a dict to be used when it is needed to access its fields
role_arn_parsed = arn_parsing(session_input.role_arn)
except Exception as error:
logger.critical(f"{error.__class__.__name__} -- {error}")
quit()
# Set info for role assumption if needed
role_info = AWS_Assume_Role(
session_input.role_arn,
session_input.session_duration,
session_input.external_id,
original_session,
audited_partition,
)
logger.info(f"Assuming role {role_info.role_arn}")
# Assume the role
assumed_role_response = assume_role(role_info)
logger.info("Role assumed")
# Set the info needed to create a session with an assumed role
session_info = AWS_Session_Info(
session_input.profile,
AWS_Credentials(
aws_access_key_id=assumed_role_response["Credentials"]["AccessKeyId"],
aws_session_token=assumed_role_response["Credentials"]["SessionToken"],
aws_secret_access_key=assumed_role_response["Credentials"][
"SecretAccessKey"
],
expiration=assumed_role_response["Credentials"]["Expiration"],
),
role_info,
)
assumed_session = AWS_Provider(session_info).get_session()
if assumed_session:
aws_session = assumed_session
audited_account = role_arn_parsed.account_id
audited_partition = role_arn_parsed.partition
else:
aws_session = original_session
def assume_role(role_info):
def assume_role():
try:
# set the info to assume the role from the partition, account and role name
sts_client = role_info.sts_session.client("sts")
sts_client = current_audit_info.original_session.client("sts")
# If external id, set it to the assume role api call
if role_info.external_id:
if current_audit_info.assumed_role_info.external_id:
assumed_credentials = sts_client.assume_role(
RoleArn=role_info.role_arn,
RoleSessionName="ProwlerProSession",
DurationSeconds=role_info.session_duration,
ExternalId=role_info.external_id,
RoleArn=current_audit_info.assumed_role_info.role_arn,
RoleSessionName="ProwlerProAsessmentSession",
DurationSeconds=current_audit_info.assumed_role_info.session_duration,
ExternalId=current_audit_info.assumed_role_info.external_id,
)
# else assume the role without the external id
else:
assumed_credentials = sts_client.assume_role(
RoleArn=role_info.role_arn,
RoleSessionName="ProwlerProSession",
DurationSeconds=role_info.session_duration,
RoleArn=current_audit_info.assumed_role_info.role_arn,
RoleSessionName="ProwlerProAsessmentSession",
DurationSeconds=current_audit_info.assumed_role_info.session_duration,
)
except Exception as error:
logger.critical(f"{error.__class__.__name__} -- {error}")

31
providers/aws/models.py Normal file
View File

@@ -0,0 +1,31 @@
from dataclasses import dataclass
from datetime import datetime
from boto3 import session
@dataclass
class AWS_Credentials:
aws_access_key_id: str
aws_session_token: str
aws_secret_access_key: str
expiration: datetime
@dataclass
class AWS_Assume_Role:
role_arn: str
session_duration: int
external_id: str
@dataclass
class AWS_Audit_Info:
original_session: session.Session
audit_session: session.Session
audited_account: int
audited_partition: str
profile: str
credentials: AWS_Credentials
assumed_role_info: AWS_Assume_Role
audited_regions: list

View File

@@ -5,31 +5,24 @@ import urllib.request
from config.config import aws_services_json_file, aws_services_json_url
from lib.logger import logger
from lib.utils.utils import open_file, parse_json_file
from providers.aws.aws_provider import (
audited_account,
audited_partition,
audited_regions,
aws_session,
)
from providers.aws.aws_provider import current_audit_info
################## EC2
class EC2:
def __init__(self, aws_session, audited_regions):
def __init__(self, audit_info):
self.service = "ec2"
self.aws_session = aws_session
self.session = audit_info.audit_session
self.audited_account = audit_info.audited_account
self.regional_clients = self.__generate_regional_clients__(
self.service, audited_regions
self.service, audit_info
)
self.__threading_call__(self.__describe_snapshots__)
def __get_clients__(self):
return self.clients
def __get_session__(self):
return self.aws_session
return self.session
def __generate_regional_clients__(self, service, audited_regions):
def __generate_regional_clients__(self, service, audit_info):
regional_clients = []
try: # Try to get the list online
with urllib.request.urlopen(aws_services_json_url) as url:
@@ -40,46 +33,50 @@ class EC2:
data = parse_json_file(f)
for att in data["prices"]:
if audited_regions: # Check for input aws audited_regions
if (
audit_info.audited_regions
): # Check for input aws audit_info.audited_regions
if (
service in att["id"].split(":")[0]
and att["attributes"]["aws:region"] in audited_regions
and att["attributes"]["aws:region"] in audit_info.audited_regions
): # Check if service has this region
region = att["attributes"]["aws:region"]
regional_client = aws_session.client(service, region_name=region)
regional_client = audit_info.audit_session.client(
service, region_name=region
)
regional_client.region = region
regional_clients.append(regional_client)
else:
if audited_partition in "aws":
if audit_info.audited_partition in "aws":
if (
service in att["id"].split(":")[0]
and "gov" not in att["attributes"]["aws:region"]
and "cn" not in att["attributes"]["aws:region"]
):
region = att["attributes"]["aws:region"]
regional_client = aws_session.client(
regional_client = audit_info.audit_session.client(
service, region_name=region
)
regional_client.region = region
regional_clients.append(regional_client)
elif audited_partition in "cn":
elif audit_info.audited_partition in "cn":
if (
service in att["id"].split(":")[0]
and "cn" in att["attributes"]["aws:region"]
):
region = att["attributes"]["aws:region"]
regional_client = aws_session.client(
regional_client = audit_info.audit_session.client(
service, region_name=region
)
regional_client.region = region
regional_clients.append(regional_client)
elif audited_partition in "gov":
elif audit_info.audited_partition in "gov":
if (
service in att["id"].split(":")[0]
and "gov" in att["attributes"]["aws:region"]
):
region = att["attributes"]["aws:region"]
regional_client = aws_session.client(
regional_client = audit_info.audit_session.client(
service, region_name=region
)
regional_client.region = region
@@ -90,13 +87,17 @@ class EC2:
def __threading_call__(self, call):
threads = []
for regional_client in self.regional_clients:
threads.append(threading.Thread(target=call, args=(regional_client,)))
threads.append(
threading.Thread(
target=call, args=(regional_client, self.audited_account)
)
)
for t in threads:
t.start()
for t in threads:
t.join()
def __describe_snapshots__(self, regional_client):
def __describe_snapshots__(self, regional_client, audited_account):
logger.info("EC2 - Describing Snapshots...")
try:
describe_snapshots_paginator = regional_client.get_paginator(
@@ -114,4 +115,4 @@ class EC2:
regional_client.snapshots = snapshots
ec2_client = EC2(aws_session, audited_regions)
ec2_client = EC2(current_audit_info)

View File

@@ -1,13 +1,13 @@
from lib.logger import logger
from providers.aws.aws_provider import aws_session
from providers.aws.aws_provider import current_audit_info
################## IAM
class IAM:
def __init__(self, session):
def __init__(self, audit_info):
self.service = "iam"
self.session = session
self.client = session.client(self.service)
self.session = audit_info.audit_session
self.client = self.session.client(self.service)
self.users = self.__get_users__()
self.roles = self.__get_roles__()
self.customer_managed_policies = self.__get_customer_managed_policies__()
@@ -89,7 +89,7 @@ class IAM:
try:
iam_client = IAM(aws_session)
iam_client = IAM(current_audit_info)
except Exception as error:
logger.critical(f"{error.__class__.__name__} -- {error}")
quit()

View File

@@ -11,7 +11,7 @@ from lib.check.check import (
run_check,
)
from lib.logger import logger, logging_levels
from providers.aws.aws_provider import Input_Data, provider_set_session
from providers.aws.aws_provider import provider_set_session
if __name__ == "__main__":
# CLI Arguments
@@ -103,17 +103,15 @@ if __name__ == "__main__":
if args.no_banner:
print_banner()
# Setting session
session_input = Input_Data(
profile=args.profile,
role_arn=args.role,
session_duration=args.session_duration,
external_id=args.external_id,
regions=args.filter_region,
# Set global session
provider_set_session(
args.profile,
args.role,
args.session_duration,
args.external_id,
args.filter_region,
)
provider_set_session(session_input)
# Load checks to execute
logger.debug("Loading checks")
checks_to_execute = load_checks_to_execute(