From 60c0b79b107374629ef0452f6eb4babc25f238a4 Mon Sep 17 00:00:00 2001 From: Nacho Rivera Date: Tue, 21 Nov 2023 16:05:26 +0100 Subject: [PATCH] fix(outputs): initialize_file_descriptor is called dynamically (#3050) --- prowler/lib/outputs/file_descriptors.py | 37 ++++++------------- prowler/providers/common/outputs.py | 15 ++++++++ tests/providers/common/common_outputs_test.py | 14 +++++++ 3 files changed, 40 insertions(+), 26 deletions(-) diff --git a/prowler/lib/outputs/file_descriptors.py b/prowler/lib/outputs/file_descriptors.py index a2339e12..9b5def4d 100644 --- a/prowler/lib/outputs/file_descriptors.py +++ b/prowler/lib/outputs/file_descriptors.py @@ -12,8 +12,6 @@ from prowler.config.config import ( from prowler.lib.logger import logger from prowler.lib.outputs.html import add_html_header from prowler.lib.outputs.models import ( - Aws_Check_Output_CSV, - Azure_Check_Output_CSV, Check_Output_CSV_AWS_CIS, Check_Output_CSV_AWS_ISO27001_2013, Check_Output_CSV_AWS_Well_Architected, @@ -21,19 +19,18 @@ from prowler.lib.outputs.models import ( Check_Output_CSV_GCP_CIS, Check_Output_CSV_Generic_Compliance, Check_Output_MITRE_ATTACK, - Gcp_Check_Output_CSV, generate_csv_fields, ) from prowler.lib.utils.utils import file_exists, open_file from prowler.providers.aws.lib.audit_info.models import AWS_Audit_Info -from prowler.providers.azure.lib.audit_info.models import Azure_Audit_Info +from prowler.providers.common.outputs import get_provider_output_model from prowler.providers.gcp.lib.audit_info.models import GCP_Audit_Info def initialize_file_descriptor( filename: str, output_mode: str, - audit_info: AWS_Audit_Info, + audit_info: Any, format: Any = None, ) -> TextIOWrapper: """Open/Create the output file. If needed include headers or the required format""" @@ -75,27 +72,15 @@ def fill_file_descriptors(output_modes, output_directory, output_filename, audit for output_mode in output_modes: if output_mode == "csv": filename = f"{output_directory}/{output_filename}{csv_file_suffix}" - if isinstance(audit_info, AWS_Audit_Info): - file_descriptor = initialize_file_descriptor( - filename, - output_mode, - audit_info, - Aws_Check_Output_CSV, - ) - if isinstance(audit_info, Azure_Audit_Info): - file_descriptor = initialize_file_descriptor( - filename, - output_mode, - audit_info, - Azure_Check_Output_CSV, - ) - if isinstance(audit_info, GCP_Audit_Info): - file_descriptor = initialize_file_descriptor( - filename, - output_mode, - audit_info, - Gcp_Check_Output_CSV, - ) + output_model = get_provider_output_model( + audit_info.__class__.__name__ + ) + file_descriptor = initialize_file_descriptor( + filename, + output_mode, + audit_info, + output_model, + ) file_descriptors.update({output_mode: file_descriptor}) elif output_mode == "json": diff --git a/prowler/providers/common/outputs.py b/prowler/providers/common/outputs.py index 13607645..58567df1 100644 --- a/prowler/providers/common/outputs.py +++ b/prowler/providers/common/outputs.py @@ -29,6 +29,21 @@ def set_provider_output_options( return provider_output_options +def get_provider_output_model(audit_info_class_name): + """ + get_provider_output_model returns the model _Check_Output_CSV for each provider + """ + # from AWS_Audit_Info -> AWS -> aws -> Aws + output_provider = audit_info_class_name.split("_", 1)[0].lower().capitalize() + output_provider_model_name = f"{output_provider}_Check_Output_CSV" + output_provider_models_path = "prowler.lib.outputs.models" + output_provider_model = getattr( + importlib.import_module(output_provider_models_path), output_provider_model_name + ) + + return output_provider_model + + @dataclass class Provider_Output_Options: is_quiet: bool diff --git a/tests/providers/common/common_outputs_test.py b/tests/providers/common/common_outputs_test.py index 86e16048..3e24091b 100644 --- a/tests/providers/common/common_outputs_test.py +++ b/tests/providers/common/common_outputs_test.py @@ -16,6 +16,7 @@ from prowler.providers.common.outputs import ( Aws_Output_Options, Azure_Output_Options, Gcp_Output_Options, + get_provider_output_model, set_provider_output_options, ) from prowler.providers.gcp.lib.audit_info.models import GCP_Audit_Info @@ -393,3 +394,16 @@ class Test_Common_Output_Options: """ ) + + def test_get_provider_output_model(self): + audit_info_class_names = [ + "AWS_Audit_Info", + "GCP_Audit_Info", + "Azure_Audit_Info", + ] + for class_name in audit_info_class_names: + provider_prefix = class_name.split("_", 1)[0].lower().capitalize() + assert ( + get_provider_output_model(class_name).__name__ + == f"{provider_prefix}_Check_Output_CSV" + )