diff --git a/prowler/providers/aws/aws_provider.py b/prowler/providers/aws/aws_provider.py index 9d034b58..0a1b5833 100644 --- a/prowler/providers/aws/aws_provider.py +++ b/prowler/providers/aws/aws_provider.py @@ -12,6 +12,7 @@ from prowler.lib.logger import logger from prowler.lib.utils.utils import open_file, parse_json_file from prowler.providers.aws.config import AWS_STS_GLOBAL_ENDPOINT_REGION from prowler.providers.aws.lib.audit_info.models import AWS_Assume_Role, AWS_Audit_Info +from prowler.providers.aws.lib.credentials.credentials import create_sts_session ################## AWS PROVIDER @@ -131,7 +132,7 @@ def assume_role( if sts_endpoint_region is None: sts_endpoint_region = AWS_STS_GLOBAL_ENDPOINT_REGION - sts_client = session.client("sts", sts_endpoint_region) + sts_client = create_sts_session(session, sts_endpoint_region) assumed_credentials = sts_client.assume_role(**assume_role_arguments) except Exception as error: logger.critical( diff --git a/prowler/providers/aws/lib/credentials/credentials.py b/prowler/providers/aws/lib/credentials/credentials.py index 66eb915e..8d7984d8 100644 --- a/prowler/providers/aws/lib/credentials/credentials.py +++ b/prowler/providers/aws/lib/credentials/credentials.py @@ -29,7 +29,7 @@ def validate_aws_credentials( # Get the first region passed to the -f/--region aws_region = input_regions[0] - validate_credentials_client = session.client("sts", aws_region) + validate_credentials_client = create_sts_session(session, aws_region) caller_identity = validate_credentials_client.get_caller_identity() # Include the region where the caller_identity has validated the credentials caller_identity["region"] = aws_region @@ -64,3 +64,11 @@ Caller Identity ARN: {Fore.YELLOW}[{audit_info.audited_identity_arn}]{Style.RESE report += f"""Assumed Role ARN: {Fore.YELLOW}[{audit_info.assumed_role_info.role_arn}]{Style.RESET_ALL} """ print(report) + + +def create_sts_session( + session: session.Session, aws_region: str +) -> session.Session.client: + return session.client( + "sts", aws_region, endpoint_url=f"https://sts.{aws_region}.amazonaws.com" + ) diff --git a/tests/providers/aws/lib/credentials/credentials_test.py b/tests/providers/aws/lib/credentials/credentials_test.py index f5481df7..84bbd372 100644 --- a/tests/providers/aws/lib/credentials/credentials_test.py +++ b/tests/providers/aws/lib/credentials/credentials_test.py @@ -6,7 +6,10 @@ from mock import patch from moto import mock_iam, mock_sts from prowler.providers.aws.lib.arn.arn import parse_iam_credentials_arn -from prowler.providers.aws.lib.credentials.credentials import validate_aws_credentials +from prowler.providers.aws.lib.credentials.credentials import ( + create_sts_session, + validate_aws_credentials, +) AWS_ACCOUNT_NUMBER = "123456789012" @@ -446,3 +449,75 @@ class Test_AWS_Credentials: assert caller_identity_arn.resource_type == "user" assert re.match("[0-9a-zA-Z]{20}", get_caller_identity["UserId"]) assert get_caller_identity["Account"] == AWS_ACCOUNT_NUMBER + + @mock_iam + @mock_sts + def test_create_sts_session(self): + aws_region = "eu-west-1" + # Create a mock IAM user + iam_client = boto3.client("iam", region_name=aws_region) + iam_user = iam_client.create_user(UserName="test-user")["User"] + # Create a mock IAM access keys + access_key = iam_client.create_access_key(UserName=iam_user["UserName"])[ + "AccessKey" + ] + access_key_id = access_key["AccessKeyId"] + secret_access_key = access_key["SecretAccessKey"] + # Create AWS session to validate + session = boto3.session.Session( + aws_access_key_id=access_key_id, + aws_secret_access_key=secret_access_key, + region_name=aws_region, + ) + sts_client = create_sts_session(session, aws_region) + + assert sts_client._endpoint._endpoint_prefix == "sts" + assert sts_client._endpoint.host == f"https://sts.{aws_region}.amazonaws.com" + + @mock_iam + @mock_sts + def test_create_sts_session_gov_cloud(self): + aws_region = "us-gov-east-1" + # Create a mock IAM user + iam_client = boto3.client("iam", region_name=aws_region) + iam_user = iam_client.create_user(UserName="test-user")["User"] + # Create a mock IAM access keys + access_key = iam_client.create_access_key(UserName=iam_user["UserName"])[ + "AccessKey" + ] + access_key_id = access_key["AccessKeyId"] + secret_access_key = access_key["SecretAccessKey"] + # Create AWS session to validate + session = boto3.session.Session( + aws_access_key_id=access_key_id, + aws_secret_access_key=secret_access_key, + region_name=aws_region, + ) + sts_client = create_sts_session(session, aws_region) + + assert sts_client._endpoint._endpoint_prefix == "sts" + assert sts_client._endpoint.host == f"https://sts.{aws_region}.amazonaws.com" + + @mock_iam + @mock_sts + def test_create_sts_session_china(self): + aws_region = "cn-north-1" + # Create a mock IAM user + iam_client = boto3.client("iam", region_name=aws_region) + iam_user = iam_client.create_user(UserName="test-user")["User"] + # Create a mock IAM access keys + access_key = iam_client.create_access_key(UserName=iam_user["UserName"])[ + "AccessKey" + ] + access_key_id = access_key["AccessKeyId"] + secret_access_key = access_key["SecretAccessKey"] + # Create AWS session to validate + session = boto3.session.Session( + aws_access_key_id=access_key_id, + aws_secret_access_key=secret_access_key, + region_name=aws_region, + ) + sts_client = create_sts_session(session, aws_region) + + assert sts_client._endpoint._endpoint_prefix == "sts" + assert sts_client._endpoint.host == f"https://sts.{aws_region}.amazonaws.com"