fix(outputs): initialize_file_descriptor is called dynamically (#3050)

This commit is contained in:
Nacho Rivera
2023-11-21 16:05:26 +01:00
committed by GitHub
parent f9d2e7aa93
commit 60c0b79b10
3 changed files with 40 additions and 26 deletions

View File

@@ -12,8 +12,6 @@ from prowler.config.config import (
from prowler.lib.logger import logger
from prowler.lib.outputs.html import add_html_header
from prowler.lib.outputs.models import (
Aws_Check_Output_CSV,
Azure_Check_Output_CSV,
Check_Output_CSV_AWS_CIS,
Check_Output_CSV_AWS_ISO27001_2013,
Check_Output_CSV_AWS_Well_Architected,
@@ -21,19 +19,18 @@ from prowler.lib.outputs.models import (
Check_Output_CSV_GCP_CIS,
Check_Output_CSV_Generic_Compliance,
Check_Output_MITRE_ATTACK,
Gcp_Check_Output_CSV,
generate_csv_fields,
)
from prowler.lib.utils.utils import file_exists, open_file
from prowler.providers.aws.lib.audit_info.models import AWS_Audit_Info
from prowler.providers.azure.lib.audit_info.models import Azure_Audit_Info
from prowler.providers.common.outputs import get_provider_output_model
from prowler.providers.gcp.lib.audit_info.models import GCP_Audit_Info
def initialize_file_descriptor(
filename: str,
output_mode: str,
audit_info: AWS_Audit_Info,
audit_info: Any,
format: Any = None,
) -> TextIOWrapper:
"""Open/Create the output file. If needed include headers or the required format"""
@@ -75,27 +72,15 @@ def fill_file_descriptors(output_modes, output_directory, output_filename, audit
for output_mode in output_modes:
if output_mode == "csv":
filename = f"{output_directory}/{output_filename}{csv_file_suffix}"
if isinstance(audit_info, AWS_Audit_Info):
file_descriptor = initialize_file_descriptor(
filename,
output_mode,
audit_info,
Aws_Check_Output_CSV,
)
if isinstance(audit_info, Azure_Audit_Info):
file_descriptor = initialize_file_descriptor(
filename,
output_mode,
audit_info,
Azure_Check_Output_CSV,
)
if isinstance(audit_info, GCP_Audit_Info):
file_descriptor = initialize_file_descriptor(
filename,
output_mode,
audit_info,
Gcp_Check_Output_CSV,
)
output_model = get_provider_output_model(
audit_info.__class__.__name__
)
file_descriptor = initialize_file_descriptor(
filename,
output_mode,
audit_info,
output_model,
)
file_descriptors.update({output_mode: file_descriptor})
elif output_mode == "json":

View File

@@ -29,6 +29,21 @@ def set_provider_output_options(
return provider_output_options
def get_provider_output_model(audit_info_class_name):
"""
get_provider_output_model returns the model _Check_Output_CSV for each provider
"""
# from AWS_Audit_Info -> AWS -> aws -> Aws
output_provider = audit_info_class_name.split("_", 1)[0].lower().capitalize()
output_provider_model_name = f"{output_provider}_Check_Output_CSV"
output_provider_models_path = "prowler.lib.outputs.models"
output_provider_model = getattr(
importlib.import_module(output_provider_models_path), output_provider_model_name
)
return output_provider_model
@dataclass
class Provider_Output_Options:
is_quiet: bool

View File

@@ -16,6 +16,7 @@ from prowler.providers.common.outputs import (
Aws_Output_Options,
Azure_Output_Options,
Gcp_Output_Options,
get_provider_output_model,
set_provider_output_options,
)
from prowler.providers.gcp.lib.audit_info.models import GCP_Audit_Info
@@ -393,3 +394,16 @@ class Test_Common_Output_Options:
</div>
"""
)
def test_get_provider_output_model(self):
audit_info_class_names = [
"AWS_Audit_Info",
"GCP_Audit_Info",
"Azure_Audit_Info",
]
for class_name in audit_info_class_names:
provider_prefix = class_name.split("_", 1)[0].lower().capitalize()
assert (
get_provider_output_model(class_name).__name__
== f"{provider_prefix}_Check_Output_CSV"
)