diff --git a/prowler/__main__.py b/prowler/__main__.py index 81ba3be1..d64db093 100644 --- a/prowler/__main__.py +++ b/prowler/__main__.py @@ -38,13 +38,12 @@ from prowler.lib.outputs.outputs import ( display_summary_table, send_to_s3_bucket, ) -from prowler.providers.aws.aws_provider import aws_provider_set_session from prowler.providers.aws.lib.allowlist.allowlist import parse_allowlist_file from prowler.providers.aws.lib.quick_inventory.quick_inventory import quick_inventory from prowler.providers.aws.lib.security_hub.security_hub import ( resolve_security_hub_previous_findings, ) -from prowler.providers.azure.azure_provider import azure_provider_set_session +from prowler.providers.common.common import set_provider_audit_info def prowler(): @@ -298,17 +297,6 @@ def prowler(): sp_env_auth = args.sp_env_auth browser_auth = args.browser_auth managed_entity_auth = args.managed_identity_auth - if provider == "azure": - if ( - not az_cli_auth - and not sp_env_auth - and not browser_auth - and not managed_entity_auth - ): - logger.critical( - "If you are using Azure provider you need to set one of the following options: --az-cli-auth, --sp-env-auth, --browser-auth, --managed-identity-auth" - ) - sys.exit() # We treat the compliance framework as another output format if compliance_framework: @@ -417,20 +405,20 @@ def prowler(): if output_modes: mkdir(output_directory) - if provider == "aws": - # Set global session - audit_info = aws_provider_set_session( - args.profile, - args.role, - args.session_duration, - args.external_id, - args.filter_region, - args.organizations_role, - ) - elif provider == "azure": - audit_info = azure_provider_set_session( - subscriptions, az_cli_auth, sp_env_auth, browser_auth, managed_entity_auth - ) + arguments = { + "profile": args.profile, + "role": args.role, + "session_duration": args.session_duration, + "external_id": args.external_id, + "regions": args.filter_region, + "organizations_role": args.organizations_role, + "subscriptions": subscriptions, + "az_cli_auth": az_cli_auth, + "sp_env_auth": sp_env_auth, + "browser_auth": browser_auth, + "managed_entity_auth": managed_entity_auth, + } + audit_info = set_provider_audit_info(provider, arguments) # Check if custom output filename was input, if not, set the default if not output_filename: diff --git a/prowler/providers/aws/aws_provider.py b/prowler/providers/aws/aws_provider.py index c0a98de2..6bdb2ff4 100644 --- a/prowler/providers/aws/aws_provider.py +++ b/prowler/providers/aws/aws_provider.py @@ -1,22 +1,14 @@ import os import sys -from arnparse import arnparse -from boto3 import client, session +from boto3 import session from botocore.credentials import RefreshableCredentials from botocore.session import get_session -from colorama import Fore, Style from prowler.config.config import aws_services_json_file from prowler.lib.logger import logger from prowler.lib.utils.utils import open_file, parse_json_file -from prowler.providers.aws.lib.arn.arn import arn_parsing -from prowler.providers.aws.lib.audit_info.audit_info import current_audit_info -from prowler.providers.aws.lib.audit_info.models import ( - AWS_Audit_Info, - AWS_Credentials, - AWS_Organizations_Info, -) +from prowler.providers.aws.lib.audit_info.models import AWS_Audit_Info ################## AWS PROVIDER @@ -84,147 +76,6 @@ class AWS_Provider: return refreshed_credentials -def aws_provider_set_session( - input_profile, - input_role, - input_session_duration, - input_external_id, - input_regions, - organizations_role_arn, -): - # Assumed AWS session - assumed_session = None - - # Setting session - current_audit_info.profile = input_profile - current_audit_info.audited_regions = input_regions - - logger.info("Generating original session ...") - # Create an global original session using only profile/basic credentials info - current_audit_info.original_session = AWS_Provider(current_audit_info).get_session() - logger.info("Validating credentials ...") - # Verificate if we have valid credentials - caller_identity = validate_credentials(current_audit_info.original_session) - - logger.info("Credentials validated") - logger.info(f"Original caller identity UserId : {caller_identity['UserId']}") - logger.info(f"Original caller identity ARN : {caller_identity['Arn']}") - - current_audit_info.audited_account = caller_identity["Account"] - current_audit_info.audited_identity_arn = caller_identity["Arn"] - current_audit_info.audited_user_id = caller_identity["UserId"] - current_audit_info.audited_partition = arnparse(caller_identity["Arn"]).partition - - logger.info("Checking if organizations role assumption is needed ...") - if organizations_role_arn: - current_audit_info.assumed_role_info.role_arn = organizations_role_arn - current_audit_info.assumed_role_info.session_duration = input_session_duration - - # Check if role arn is valid - try: - # this returns the arn already parsed, calls arnparse, into a dict to be used when it is needed to access its fields - role_arn_parsed = arn_parsing(current_audit_info.assumed_role_info.role_arn) - - except Exception as error: - logger.critical(f"{error.__class__.__name__} -- {error}") - sys.exit() - - else: - logger.info( - f"Getting organizations metadata for account {organizations_role_arn}" - ) - assumed_credentials = assume_role(current_audit_info) - current_audit_info.organizations_metadata = get_organizations_metadata( - current_audit_info.audited_account, assumed_credentials - ) - logger.info("Organizations metadata retrieved") - - logger.info("Checking if role assumption is needed ...") - if input_role: - current_audit_info.assumed_role_info.role_arn = input_role - current_audit_info.assumed_role_info.session_duration = input_session_duration - current_audit_info.assumed_role_info.external_id = input_external_id - - # Check if role arn is valid - try: - # this returns the arn already parsed, calls arnparse, into a dict to be used when it is needed to access its fields - role_arn_parsed = arn_parsing(current_audit_info.assumed_role_info.role_arn) - - except Exception as error: - logger.critical(f"{error.__class__.__name__} -- {error}") - sys.exit() - - else: - logger.info( - f"Assuming role {current_audit_info.assumed_role_info.role_arn}" - ) - # Assume the role - assumed_role_response = assume_role(current_audit_info) - logger.info("Role assumed") - # Set the info needed to create a session with an assumed role - current_audit_info.credentials = AWS_Credentials( - aws_access_key_id=assumed_role_response["Credentials"]["AccessKeyId"], - aws_session_token=assumed_role_response["Credentials"]["SessionToken"], - aws_secret_access_key=assumed_role_response["Credentials"][ - "SecretAccessKey" - ], - expiration=assumed_role_response["Credentials"]["Expiration"], - ) - assumed_session = AWS_Provider(current_audit_info).get_session() - - if assumed_session: - logger.info("Audit session is the new session created assuming role") - current_audit_info.audit_session = assumed_session - current_audit_info.audited_account = role_arn_parsed.account_id - current_audit_info.audited_partition = role_arn_parsed.partition - else: - logger.info("Audit session is the original one") - current_audit_info.audit_session = current_audit_info.original_session - - # Setting default region of session - if current_audit_info.audit_session.region_name: - current_audit_info.profile_region = current_audit_info.audit_session.region_name - else: - current_audit_info.profile_region = "us-east-1" - - print_audit_credentials(current_audit_info) - return current_audit_info - - -def print_audit_credentials(audit_info: AWS_Audit_Info): - # Beautify audited regions, set "all" if there is no filter region - regions = ( - ", ".join(audit_info.audited_regions) - if audit_info.audited_regions is not None - else "all" - ) - # Beautify audited profile, set "default" if there is no profile set - profile = audit_info.profile if audit_info.profile is not None else "default" - - report = f""" -This report is being generated using credentials below: - -AWS-CLI Profile: {Fore.YELLOW}[{profile}]{Style.RESET_ALL} AWS Filter Region: {Fore.YELLOW}[{regions}]{Style.RESET_ALL} -AWS Account: {Fore.YELLOW}[{audit_info.audited_account}]{Style.RESET_ALL} UserId: {Fore.YELLOW}[{audit_info.audited_user_id}]{Style.RESET_ALL} -Caller Identity ARN: {Fore.YELLOW}[{audit_info.audited_identity_arn}]{Style.RESET_ALL} -""" - # If -A is set, print Assumed Role ARN - if audit_info.assumed_role_info.role_arn is not None: - report += f"Assumed Role ARN: {Fore.YELLOW}[{audit_info.assumed_role_info.role_arn}]{Style.RESET_ALL}" - print(report) - - -def validate_credentials(validate_session: session) -> dict: - try: - validate_credentials_client = validate_session.client("sts") - caller_identity = validate_credentials_client.get_caller_identity() - except Exception as error: - logger.critical(f"{error.__class__.__name__} -- {error}") - sys.exit() - else: - return caller_identity - - def assume_role(audit_info: AWS_Audit_Info) -> dict: try: # set the info to assume the role from the partition, account and role name @@ -252,40 +103,6 @@ def assume_role(audit_info: AWS_Audit_Info) -> dict: return assumed_credentials -def get_organizations_metadata( - metadata_account: str, assumed_credentials: dict -) -> AWS_Organizations_Info: - try: - organizations_client = client( - "organizations", - aws_access_key_id=assumed_credentials["Credentials"]["AccessKeyId"], - aws_secret_access_key=assumed_credentials["Credentials"]["SecretAccessKey"], - aws_session_token=assumed_credentials["Credentials"]["SessionToken"], - ) - organizations_metadata = organizations_client.describe_account( - AccountId=metadata_account - ) - list_tags_for_resource = organizations_client.list_tags_for_resource( - ResourceId=metadata_account - ) - except Exception as error: - logger.critical(f"{error.__class__.__name__} -- {error}") - sys.exit() - else: - # Convert Tags dictionary to String - account_details_tags = "" - for tag in list_tags_for_resource["Tags"]: - account_details_tags += tag["Key"] + ":" + tag["Value"] + "," - organizations_info = AWS_Organizations_Info( - account_details_email=organizations_metadata["Account"]["Email"], - account_details_name=organizations_metadata["Account"]["Name"], - account_details_arn=organizations_metadata["Account"]["Arn"], - account_details_org=organizations_metadata["Account"]["Arn"].split("/")[1], - account_details_tags=account_details_tags, - ) - return organizations_info - - def generate_regional_clients(service: str, audit_info: AWS_Audit_Info) -> dict: regional_clients = {} # Get json locally diff --git a/prowler/providers/azure/azure_provider.py b/prowler/providers/azure/azure_provider.py index e9f47db2..a0509e78 100644 --- a/prowler/providers/azure/azure_provider.py +++ b/prowler/providers/azure/azure_provider.py @@ -6,11 +6,7 @@ from azure.mgmt.subscription import SubscriptionClient from msgraph.core import GraphClient from prowler.lib.logger import logger -from prowler.providers.azure.lib.audit_info.audit_info import azure_audit_info -from prowler.providers.azure.lib.audit_info.models import ( - Azure_Audit_Info, - Azure_Identity_Info, -) +from prowler.providers.azure.lib.audit_info.models import Azure_Identity_Info class Azure_Provider: @@ -183,21 +179,3 @@ class Azure_Provider: def get_identity(self): return self.identity - - -def azure_provider_set_session( - subscription_ids: list, - az_cli_auth: bool, - sp_env_auth: bool, - browser_auth: bool, - managed_entity_auth: bool, -) -> Azure_Audit_Info: - logger.info("Setting Azure session ...") - - azure_provider = Azure_Provider( - az_cli_auth, sp_env_auth, browser_auth, managed_entity_auth, subscription_ids - ) - azure_audit_info.credentials = azure_provider.get_credentials() - azure_audit_info.identity = azure_provider.get_identity() - - return azure_audit_info diff --git a/prowler/providers/azure/lib/audit_info/models.py b/prowler/providers/azure/lib/audit_info/models.py index 87a89e2d..10a3ef21 100644 --- a/prowler/providers/azure/lib/audit_info/models.py +++ b/prowler/providers/azure/lib/audit_info/models.py @@ -20,4 +20,3 @@ class Azure_Audit_Info: def __init__(self, credentials, identity): self.credentials = credentials self.identity = identity - self.is_azure = True diff --git a/prowler/providers/common/__init__.py b/prowler/providers/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/prowler/providers/common/common.py b/prowler/providers/common/common.py new file mode 100644 index 00000000..f5d5d537 --- /dev/null +++ b/prowler/providers/common/common.py @@ -0,0 +1,255 @@ +import importlib +import sys + +from arnparse import arnparse +from boto3 import client, session +from colorama import Fore, Style + +from prowler.lib.logger import logger +from prowler.providers.aws.aws_provider import AWS_Provider, assume_role +from prowler.providers.aws.lib.arn.arn import arn_parsing +from prowler.providers.aws.lib.audit_info.audit_info import current_audit_info +from prowler.providers.aws.lib.audit_info.models import ( + AWS_Audit_Info, + AWS_Credentials, + AWS_Organizations_Info, +) +from prowler.providers.azure.azure_provider import Azure_Provider +from prowler.providers.azure.lib.audit_info.audit_info import azure_audit_info +from prowler.providers.azure.lib.audit_info.models import Azure_Audit_Info + + +class Audit_Info: + def __init__(self): + logger.info("Instantiating audit info") + + def validate_credentials(self, validate_session: session) -> dict: + try: + validate_credentials_client = validate_session.client("sts") + caller_identity = validate_credentials_client.get_caller_identity() + except Exception as error: + logger.critical(f"{error.__class__.__name__} -- {error}") + sys.exit() + else: + return caller_identity + + def print_audit_credentials(self, audit_info: AWS_Audit_Info): + # Beautify audited regions, set "all" if there is no filter region + regions = ( + ", ".join(audit_info.audited_regions) + if audit_info.audited_regions is not None + else "all" + ) + # Beautify audited profile, set "default" if there is no profile set + profile = audit_info.profile if audit_info.profile is not None else "default" + + report = f""" + This report is being generated using credentials below: + + AWS-CLI Profile: {Fore.YELLOW}[{profile}]{Style.RESET_ALL} AWS Filter Region: {Fore.YELLOW}[{regions}]{Style.RESET_ALL} + AWS Account: {Fore.YELLOW}[{audit_info.audited_account}]{Style.RESET_ALL} UserId: {Fore.YELLOW}[{audit_info.audited_user_id}]{Style.RESET_ALL} + Caller Identity ARN: {Fore.YELLOW}[{audit_info.audited_identity_arn}]{Style.RESET_ALL} + """ + # If -A is set, print Assumed Role ARN + if audit_info.assumed_role_info.role_arn is not None: + report += f"Assumed Role ARN: {Fore.YELLOW}[{audit_info.assumed_role_info.role_arn}]{Style.RESET_ALL}" + print(report) + + def get_organizations_metadata( + self, metadata_account: str, assumed_credentials: dict + ) -> AWS_Organizations_Info: + try: + organizations_client = client( + "organizations", + aws_access_key_id=assumed_credentials["Credentials"]["AccessKeyId"], + aws_secret_access_key=assumed_credentials["Credentials"][ + "SecretAccessKey" + ], + aws_session_token=assumed_credentials["Credentials"]["SessionToken"], + ) + organizations_metadata = organizations_client.describe_account( + AccountId=metadata_account + ) + list_tags_for_resource = organizations_client.list_tags_for_resource( + ResourceId=metadata_account + ) + except Exception as error: + logger.critical(f"{error.__class__.__name__} -- {error}") + sys.exit() + else: + # Convert Tags dictionary to String + account_details_tags = "" + for tag in list_tags_for_resource["Tags"]: + account_details_tags += tag["Key"] + ":" + tag["Value"] + "," + organizations_info = AWS_Organizations_Info( + account_details_email=organizations_metadata["Account"]["Email"], + account_details_name=organizations_metadata["Account"]["Name"], + account_details_arn=organizations_metadata["Account"]["Arn"], + account_details_org=organizations_metadata["Account"]["Arn"].split("/")[ + 1 + ], + account_details_tags=account_details_tags, + ) + return organizations_info + + def set_aws_audit_info(self, arguments): + input_profile = arguments["profile"] + input_role = arguments["role"] + input_session_duration = arguments["session_duration"] + input_external_id = arguments["external_id"] + input_regions = arguments["regions"] + organizations_role_arn = arguments["organizations_role"] + # Assumed AWS session + assumed_session = None + + # Setting session + current_audit_info.profile = input_profile + current_audit_info.audited_regions = input_regions + + logger.info("Generating original session ...") + # Create an global original session using only profile/basic credentials info + current_audit_info.original_session = AWS_Provider( + current_audit_info + ).get_session() + logger.info("Validating credentials ...") + # Verificate if we have valid credentials + caller_identity = self.validate_credentials(current_audit_info.original_session) + + logger.info("Credentials validated") + logger.info(f"Original caller identity UserId : {caller_identity['UserId']}") + logger.info(f"Original caller identity ARN : {caller_identity['Arn']}") + + current_audit_info.audited_account = caller_identity["Account"] + current_audit_info.audited_identity_arn = caller_identity["Arn"] + current_audit_info.audited_user_id = caller_identity["UserId"] + current_audit_info.audited_partition = arnparse( + caller_identity["Arn"] + ).partition + + logger.info("Checking if organizations role assumption is needed ...") + if organizations_role_arn: + current_audit_info.assumed_role_info.role_arn = organizations_role_arn + current_audit_info.assumed_role_info.session_duration = ( + input_session_duration + ) + + # Check if role arn is valid + try: + # this returns the arn already parsed, calls arnparse, into a dict to be used when it is needed to access its fields + role_arn_parsed = arn_parsing( + current_audit_info.assumed_role_info.role_arn + ) + + except Exception as error: + logger.critical(f"{error.__class__.__name__} -- {error}") + sys.exit() + + else: + logger.info( + f"Getting organizations metadata for account {organizations_role_arn}" + ) + assumed_credentials = assume_role(current_audit_info) + current_audit_info.organizations_metadata = ( + self.get_organizations_metadata( + current_audit_info.audited_account, assumed_credentials + ) + ) + logger.info("Organizations metadata retrieved") + + logger.info("Checking if role assumption is needed ...") + if input_role: + current_audit_info.assumed_role_info.role_arn = input_role + current_audit_info.assumed_role_info.session_duration = ( + input_session_duration + ) + current_audit_info.assumed_role_info.external_id = input_external_id + + # Check if role arn is valid + try: + # this returns the arn already parsed, calls arnparse, into a dict to be used when it is needed to access its fields + role_arn_parsed = arn_parsing( + current_audit_info.assumed_role_info.role_arn + ) + + except Exception as error: + logger.critical(f"{error.__class__.__name__} -- {error}") + sys.exit() + + else: + logger.info( + f"Assuming role {current_audit_info.assumed_role_info.role_arn}" + ) + # Assume the role + assumed_role_response = self.assume_role(current_audit_info) + logger.info("Role assumed") + # Set the info needed to create a session with an assumed role + current_audit_info.credentials = AWS_Credentials( + aws_access_key_id=assumed_role_response["Credentials"][ + "AccessKeyId" + ], + aws_session_token=assumed_role_response["Credentials"][ + "SessionToken" + ], + aws_secret_access_key=assumed_role_response["Credentials"][ + "SecretAccessKey" + ], + expiration=assumed_role_response["Credentials"]["Expiration"], + ) + assumed_session = AWS_Provider(current_audit_info).get_session() + + if assumed_session: + logger.info("Audit session is the new session created assuming role") + current_audit_info.audit_session = assumed_session + current_audit_info.audited_account = role_arn_parsed.account_id + current_audit_info.audited_partition = role_arn_parsed.partition + else: + logger.info("Audit session is the original one") + current_audit_info.audit_session = current_audit_info.original_session + + # Setting default region of session + if current_audit_info.audit_session.region_name: + current_audit_info.profile_region = ( + current_audit_info.audit_session.region_name + ) + else: + current_audit_info.profile_region = "us-east-1" + + self.print_audit_credentials(current_audit_info) + return current_audit_info + + def set_azure_audit_info(self, arguments) -> Azure_Audit_Info: + logger.info("Setting Azure session ...") + subscription_ids = arguments["subscriptions"] + az_cli_auth = arguments["az_cli_auth"] + sp_env_auth = arguments["sp_env_auth"] + browser_auth = arguments["browser_auth"] + managed_entity_auth = arguments["managed_entity_auth"] + + azure_provider = Azure_Provider( + az_cli_auth, + sp_env_auth, + browser_auth, + managed_entity_auth, + subscription_ids, + ) + azure_audit_info.credentials = azure_provider.get_credentials() + azure_audit_info.identity = azure_provider.get_identity() + + return azure_audit_info + + +def set_provider_audit_info(provider: str, arguments: dict): + try: + provider_set_audit_info = f"set_{provider}_audit_info" + provider_audit_info = getattr(Audit_Info(), provider_set_audit_info)(arguments) + except Exception as error: + logger.error( + f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" + ) + else: + return provider_audit_info + + +def import_lib(path: str): + lib = importlib.import_module(path) + return lib diff --git a/tests/providers/aws/aws_provider_test.py b/tests/providers/aws/aws_provider_test.py index 4f573bb4..3d1efafc 100644 --- a/tests/providers/aws/aws_provider_test.py +++ b/tests/providers/aws/aws_provider_test.py @@ -1,46 +1,14 @@ -import json - import boto3 import sure # noqa -from moto import mock_iam, mock_organizations, mock_sts +from moto import mock_iam, mock_sts -from prowler.providers.aws.aws_provider import ( - assume_role, - get_organizations_metadata, - get_region_global_service, - validate_credentials, -) +from prowler.providers.aws.aws_provider import assume_role, get_region_global_service from prowler.providers.aws.lib.audit_info.models import AWS_Assume_Role, AWS_Audit_Info ACCOUNT_ID = 123456789012 class Test_AWS_Provider: - @mock_sts - @mock_iam - def test_validate_credentials(self): - # Create a mock IAM user - iam_client = boto3.client("iam", region_name="us-east-1") - iam_user = iam_client.create_user(UserName="test-user")["User"] - # Create a mock IAM access keys - access_key = iam_client.create_access_key(UserName=iam_user["UserName"])[ - "AccessKey" - ] - access_key_id = access_key["AccessKeyId"] - secret_access_key = access_key["SecretAccessKey"] - # Create AWS session to validate - session = boto3.session.Session( - aws_access_key_id=access_key_id, - aws_secret_access_key=secret_access_key, - region_name="us-east-1", - ) - # Validate AWS session - get_caller_identity = validate_credentials(session) - - get_caller_identity["Arn"].should.equal(iam_user["Arn"]) - get_caller_identity["UserId"].should.equal(iam_user["UserId"]) - # assert get_caller_identity["UserId"] == str(ACCOUNT_ID) - @mock_iam @mock_sts def test_assume_role(self): @@ -114,60 +82,6 @@ class Test_AWS_Provider: 21 + 1 + len(sessionName) ) - @mock_organizations - @mock_sts - @mock_iam - def test_organizations(self): - client = boto3.client("organizations", region_name="us-east-1") - iam_client = boto3.client("iam", region_name="us-east-1") - sts_client = boto3.client("sts", region_name="us-east-1") - - mockname = "mock-account" - mockdomain = "moto-example.org" - mockemail = "@".join([mockname, mockdomain]) - - org_id = client.create_organization(FeatureSet="ALL")["Organization"]["Id"] - account_id = client.create_account(AccountName=mockname, Email=mockemail)[ - "CreateAccountStatus" - ]["AccountId"] - - client.tag_resource( - ResourceId=account_id, Tags=[{"Key": "key", "Value": "value"}] - ) - - trust_policy_document = { - "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Principal": { - "AWS": "arn:aws:iam::{account_id}:root".format( - account_id=ACCOUNT_ID - ) - }, - "Action": "sts:AssumeRole", - }, - } - iam_role_arn = iam_client.role_arn = iam_client.create_role( - RoleName="test-role", - AssumeRolePolicyDocument=json.dumps(trust_policy_document), - )["Role"]["Arn"] - session_name = "new-session" - assumed_role = sts_client.assume_role( - RoleArn=iam_role_arn, RoleSessionName=session_name - ) - - org = get_organizations_metadata(account_id, assumed_role) - - org.account_details_email.should.equal(mockemail) - org.account_details_name.should.equal(mockname) - org.account_details_arn.should.equal( - "arn:aws:organizations::{0}:account/{1}/{2}".format( - ACCOUNT_ID, org_id, account_id - ) - ) - org.account_details_org.should.equal(org_id) - org.account_details_tags.should.equal("key:value,") - def test_get_region_global_service(self): # Create mock audit_info input_audit_info = AWS_Audit_Info( diff --git a/tests/providers/common/common_test.py b/tests/providers/common/common_test.py new file mode 100644 index 00000000..70e2db9c --- /dev/null +++ b/tests/providers/common/common_test.py @@ -0,0 +1,198 @@ +import json + +import boto3 +import sure # noqa +from mock import patch +from moto import mock_iam, mock_organizations, mock_sts + +from prowler.providers.aws.lib.audit_info.models import AWS_Audit_Info +from prowler.providers.azure.azure_provider import Azure_Provider +from prowler.providers.azure.lib.audit_info.models import ( + Azure_Audit_Info, + Azure_Identity_Info, +) +from prowler.providers.common.common import Audit_Info, set_provider_audit_info + +ACCOUNT_ID = 123456789012 +mock_current_audit_info = AWS_Audit_Info( + original_session=None, + audit_session=None, + audited_account="123456789012", + audited_identity_arn="arn:aws:iam::123456789012:user/test", + audited_user_id="test", + audited_partition="aws", + profile="default", + profile_region="eu-west-1", + credentials=None, + assumed_role_info=None, + audited_regions=["eu-west-2", "eu-west-1"], + organizations_metadata=None, +) + +mock_azure_audit_info = Azure_Audit_Info( + credentials=None, identity=Azure_Identity_Info() +) + +mock_set_audit_info = Audit_Info() + + +def mock_validate_credentials(*_): + caller_identity = { + "Arn": "arn:aws:iam::123456789012:user/test", + "Account": "123456789012", + "UserId": "test", + } + return caller_identity + + +def mock_print_audit_credentials(*_): + pass + + +def mock_set_identity_info(*_): + return Azure_Identity_Info() + + +def mock_set_credentials(*_): + return {} + + +class Test_Set_Audit_Info: + @patch( + "prowler.providers.common.common.current_audit_info", + new=mock_current_audit_info, + ) + @mock_sts + @mock_iam + def test_validate_credentials(self): + # Create a mock IAM user + iam_client = boto3.client("iam", region_name="us-east-1") + iam_user = iam_client.create_user(UserName="test-user")["User"] + # Create a mock IAM access keys + access_key = iam_client.create_access_key(UserName=iam_user["UserName"])[ + "AccessKey" + ] + access_key_id = access_key["AccessKeyId"] + secret_access_key = access_key["SecretAccessKey"] + # Create AWS session to validate + session = boto3.session.Session( + aws_access_key_id=access_key_id, + aws_secret_access_key=secret_access_key, + region_name="us-east-1", + ) + audit_info = Audit_Info() + get_caller_identity = audit_info.validate_credentials(session) + + get_caller_identity["Arn"].should.equal(iam_user["Arn"]) + get_caller_identity["UserId"].should.equal(iam_user["UserId"]) + # assert get_caller_identity["UserId"] == str(ACCOUNT_ID) + + @patch( + "prowler.providers.common.common.current_audit_info", + new=mock_current_audit_info, + ) + @mock_organizations + @mock_sts + @mock_iam + def test_organizations(self): + client = boto3.client("organizations", region_name="us-east-1") + iam_client = boto3.client("iam", region_name="us-east-1") + sts_client = boto3.client("sts", region_name="us-east-1") + + mockname = "mock-account" + mockdomain = "moto-example.org" + mockemail = "@".join([mockname, mockdomain]) + + org_id = client.create_organization(FeatureSet="ALL")["Organization"]["Id"] + account_id = client.create_account(AccountName=mockname, Email=mockemail)[ + "CreateAccountStatus" + ]["AccountId"] + + client.tag_resource( + ResourceId=account_id, Tags=[{"Key": "key", "Value": "value"}] + ) + + trust_policy_document = { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Principal": { + "AWS": "arn:aws:iam::{account_id}:root".format( + account_id=ACCOUNT_ID + ) + }, + "Action": "sts:AssumeRole", + }, + } + iam_role_arn = iam_client.role_arn = iam_client.create_role( + RoleName="test-role", + AssumeRolePolicyDocument=json.dumps(trust_policy_document), + )["Role"]["Arn"] + session_name = "new-session" + assumed_role = sts_client.assume_role( + RoleArn=iam_role_arn, RoleSessionName=session_name + ) + + audit_info = Audit_Info() + org = audit_info.get_organizations_metadata(account_id, assumed_role) + + org.account_details_email.should.equal(mockemail) + org.account_details_name.should.equal(mockname) + org.account_details_arn.should.equal( + "arn:aws:organizations::{0}:account/{1}/{2}".format( + ACCOUNT_ID, org_id, account_id + ) + ) + org.account_details_org.should.equal(org_id) + org.account_details_tags.should.equal("key:value,") + + @patch( + "prowler.providers.common.common.current_audit_info", + new=mock_current_audit_info, + ) + @patch.object(Audit_Info, "validate_credentials", new=mock_validate_credentials) + @patch.object( + Audit_Info, "print_audit_credentials", new=mock_print_audit_credentials + ) + def test_set_audit_info_aws(self): + provider = "aws" + arguments = { + "profile": None, + "role": None, + "session_duration": None, + "external_id": None, + "regions": None, + "organizations_role": None, + "subscriptions": None, + "az_cli_auth": None, + "sp_env_auth": None, + "browser_auth": None, + "managed_entity_auth": None, + } + + audit_info = set_provider_audit_info(provider, arguments) + assert isinstance(audit_info, AWS_Audit_Info) + + @patch( + "prowler.providers.common.common.azure_audit_info", new=mock_azure_audit_info + ) + @patch.object(Azure_Provider, "__set_credentials__", new=mock_set_credentials) + @patch.object(Azure_Provider, "__set_identity_info__", new=mock_set_identity_info) + def test_set_audit_info_azure(self): + provider = "azure" + arguments = { + "profile": None, + "role": None, + "session_duration": None, + "external_id": None, + "regions": None, + "organizations_role": None, + "subscriptions": None, + "az_cli_auth": None, + "sp_env_auth": None, + "browser_auth": None, + "managed_entity_auth": None, + } + + audit_info = set_provider_audit_info(provider, arguments) + assert isinstance(audit_info, Azure_Audit_Info)