mirror of
https://github.com/ghndrx/kubeflow-pipelines.git
synced 2026-02-10 06:45:13 +00:00
fix: add aws_session_token support for SSO credentials
This commit is contained in:
@@ -32,6 +32,7 @@ def upload_to_s3(local_path: str, s3_bucket: str, s3_prefix: str, aws_credential
|
|||||||
's3',
|
's3',
|
||||||
aws_access_key_id=aws_credentials.get('aws_access_key_id'),
|
aws_access_key_id=aws_credentials.get('aws_access_key_id'),
|
||||||
aws_secret_access_key=aws_credentials.get('aws_secret_access_key'),
|
aws_secret_access_key=aws_credentials.get('aws_secret_access_key'),
|
||||||
|
aws_session_token=aws_credentials.get('aws_session_token'),
|
||||||
region_name=aws_credentials.get('aws_region', 'us-east-1')
|
region_name=aws_credentials.get('aws_region', 'us-east-1')
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -277,6 +278,7 @@ def train_llm_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
aws_creds = {
|
aws_creds = {
|
||||||
'aws_access_key_id': job_input.get('aws_access_key_id'),
|
'aws_access_key_id': job_input.get('aws_access_key_id'),
|
||||||
'aws_secret_access_key': job_input.get('aws_secret_access_key'),
|
'aws_secret_access_key': job_input.get('aws_secret_access_key'),
|
||||||
|
'aws_session_token': job_input.get('aws_session_token'),
|
||||||
'aws_region': job_input.get('aws_region', 'us-east-1'),
|
'aws_region': job_input.get('aws_region', 'us-east-1'),
|
||||||
}
|
}
|
||||||
model_short = model_name.split('/')[-1]
|
model_short = model_name.split('/')[-1]
|
||||||
@@ -426,6 +428,7 @@ def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
aws_creds = {
|
aws_creds = {
|
||||||
'aws_access_key_id': job_input.get('aws_access_key_id'),
|
'aws_access_key_id': job_input.get('aws_access_key_id'),
|
||||||
'aws_secret_access_key': job_input.get('aws_secret_access_key'),
|
'aws_secret_access_key': job_input.get('aws_secret_access_key'),
|
||||||
|
'aws_session_token': job_input.get('aws_session_token'),
|
||||||
'aws_region': job_input.get('aws_region', 'us-east-1'),
|
'aws_region': job_input.get('aws_region', 'us-east-1'),
|
||||||
}
|
}
|
||||||
s3_prefix = job_input.get('s3_prefix', 'ddi-models/bert')
|
s3_prefix = job_input.get('s3_prefix', 'ddi-models/bert')
|
||||||
|
|||||||
Reference in New Issue
Block a user