fix: add aws_session_token support for SSO credentials

This commit is contained in:
2026-02-03 15:42:19 +00:00
parent 59c808cb3a
commit 9595ef09fd

View File

@@ -32,6 +32,7 @@ def upload_to_s3(local_path: str, s3_bucket: str, s3_prefix: str, aws_credential
's3', 's3',
aws_access_key_id=aws_credentials.get('aws_access_key_id'), aws_access_key_id=aws_credentials.get('aws_access_key_id'),
aws_secret_access_key=aws_credentials.get('aws_secret_access_key'), aws_secret_access_key=aws_credentials.get('aws_secret_access_key'),
aws_session_token=aws_credentials.get('aws_session_token'),
region_name=aws_credentials.get('aws_region', 'us-east-1') region_name=aws_credentials.get('aws_region', 'us-east-1')
) )
@@ -277,6 +278,7 @@ def train_llm_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
aws_creds = { aws_creds = {
'aws_access_key_id': job_input.get('aws_access_key_id'), 'aws_access_key_id': job_input.get('aws_access_key_id'),
'aws_secret_access_key': job_input.get('aws_secret_access_key'), 'aws_secret_access_key': job_input.get('aws_secret_access_key'),
'aws_session_token': job_input.get('aws_session_token'),
'aws_region': job_input.get('aws_region', 'us-east-1'), 'aws_region': job_input.get('aws_region', 'us-east-1'),
} }
model_short = model_name.split('/')[-1] model_short = model_name.split('/')[-1]
@@ -426,6 +428,7 @@ def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
aws_creds = { aws_creds = {
'aws_access_key_id': job_input.get('aws_access_key_id'), 'aws_access_key_id': job_input.get('aws_access_key_id'),
'aws_secret_access_key': job_input.get('aws_secret_access_key'), 'aws_secret_access_key': job_input.get('aws_secret_access_key'),
'aws_session_token': job_input.get('aws_session_token'),
'aws_region': job_input.get('aws_region', 'us-east-1'), 'aws_region': job_input.get('aws_region', 'us-east-1'),
} }
s3_prefix = job_input.get('s3_prefix', 'ddi-models/bert') s3_prefix = job_input.get('s3_prefix', 'ddi-models/bert')