Mirror of https://github.com/ghndrx/kubeflow-pipelines.git, synced 2026-02-10 06:45:13 +00:00
feat: add S3 model upload support

- Add upload_to_s3 function to handler
- Save trained BERT models to S3 when credentials provided
- Save LoRA adapters to S3 when credentials provided
- Input params: s3_bucket, s3_prefix, aws_access_key_id, aws_secret_access_key, aws_region
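For reference, a job input exercising the new parameters might look like the sketch below. The S3 field names come from this commit; the bucket name and credential sources are placeholders, and whatever other fields the handler expects are not shown here.

    # Hypothetical job_input for a training run that uploads to S3.
    job_input = {
        's3_bucket': 'my-ddi-models',                 # placeholder bucket
        's3_prefix': 'ddi-models/bert',               # optional; handler supplies a default
        'aws_access_key_id': os.environ['AWS_ACCESS_KEY_ID'],
        'aws_secret_access_key': os.environ['AWS_SECRET_ACCESS_KEY'],
        'aws_region': 'us-east-1',                    # optional; defaults to us-east-1
    }

If s3_bucket is omitted, training still runs and the model is simply not uploaded.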
@@ -3,13 +3,49 @@ RunPod Serverless Handler for DDI Model Training
 
 Supports both BERT-style classification and LLM fine-tuning with LoRA.
 Uses 176K real DrugBank DDI samples with drug names.
+Saves trained models to S3.
 """
 import os
 import json
 import runpod
+import boto3
+from datetime import datetime
 from typing import Dict, Any, List, Optional
 
 
+def upload_to_s3(local_path: str, s3_bucket: str, s3_prefix: str, aws_credentials: Dict) -> str:
+    """Upload a directory to S3 and return the S3 URI."""
+    import tarfile
+    import tempfile
+
+    # Create tar.gz of the model directory
+    timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
+    tar_name = f"model_{timestamp}.tar.gz"
+    tar_path = os.path.join(tempfile.gettempdir(), tar_name)
+
+    print(f"Creating archive: {tar_path}")
+    with tarfile.open(tar_path, "w:gz") as tar:
+        tar.add(local_path, arcname="model")
+
+    # Upload to S3
+    s3_client = boto3.client(
+        's3',
+        aws_access_key_id=aws_credentials.get('aws_access_key_id'),
+        aws_secret_access_key=aws_credentials.get('aws_secret_access_key'),
+        region_name=aws_credentials.get('aws_region', 'us-east-1')
+    )
+
+    s3_key = f"{s3_prefix}/{tar_name}"
+    print(f"Uploading to s3://{s3_bucket}/{s3_key}")
+
+    s3_client.upload_file(tar_path, s3_bucket, s3_key)
+
+    # Cleanup
+    os.remove(tar_path)
+
+    return f"s3://{s3_bucket}/{s3_key}"
+
+
 # DDI severity labels
 DDI_SEVERITY = {
     1: "minor",
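As a review note: the helper tars the whole directory and returns the final URI, so it can be exercised standalone. A minimal sketch, assuming a local directory ./model_out and placeholder credentials:

    creds = {
        'aws_access_key_id': 'AKIA...',        # placeholder
        'aws_secret_access_key': '...',        # placeholder
        'aws_region': 'us-west-2',
    }
    uri = upload_to_s3('./model_out', 'my-ddi-models', 'ddi-models/test', creds)
    print(uri)  # e.g. s3://my-ddi-models/ddi-models/test/model_20260210_064513.tar.gz

Note the archive is rooted at model/ (arcname="model"), so consumers should expect that top-level directory after extraction.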
@@ -229,10 +265,31 @@ def train_llm_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
 
         print(f"Training complete! Loss: {metrics['train_loss']:.4f}")
 
+        # Save LoRA adapter to S3 if credentials provided
+        s3_uri = None
+        s3_bucket = job_input.get('s3_bucket')
+        if s3_bucket:
+            save_dir = os.path.join(work_dir, 'lora_adapter')
+            print(f"Saving LoRA adapter to {save_dir}...")
+            model.save_pretrained(save_dir)
+            tokenizer.save_pretrained(save_dir)
+
+            aws_creds = {
+                'aws_access_key_id': job_input.get('aws_access_key_id'),
+                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
+                'aws_region': job_input.get('aws_region', 'us-east-1'),
+            }
+            model_short = model_name.split('/')[-1]
+            s3_prefix = job_input.get('s3_prefix', f'ddi-models/lora-{model_short}')
+            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
+            metrics['s3_uri'] = s3_uri
+            print(f"LoRA adapter uploaded to {s3_uri}")
+
         return {
             'status': 'success',
             'metrics': metrics,
-            'message': f'Gemma 3 fine-tuned on {len(raw_data):,} real DrugBank DDI samples'
+            'model_uri': s3_uri,
+            'message': f'LLM fine-tuned on {len(raw_data):,} real DrugBank DDI samples'
         }
 
     except Exception as e:
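Worth calling out: when s3_prefix is not supplied, the LoRA path derives its default prefix from the model name. With a hypothetical model_name of 'google/gemma-3-4b-it':

    model_name = 'google/gemma-3-4b-it'                 # hypothetical example
    model_short = model_name.split('/')[-1]             # 'gemma-3-4b-it'
    default_prefix = f'ddi-models/lora-{model_short}'   # 'ddi-models/lora-gemma-3-4b-it'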
@@ -357,7 +414,26 @@ def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
             'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
         })
 
-        return {'status': 'success', 'metrics': metrics, 'message': 'BERT classifier trained on DrugBank data'}
+        # Save model to S3 if credentials provided
+        s3_uri = None
+        s3_bucket = job_input.get('s3_bucket')
+        if s3_bucket:
+            save_dir = os.path.join(work_dir, 'saved_model')
+            print(f"Saving model to {save_dir}...")
+            trainer.save_model(save_dir)
+            tokenizer.save_pretrained(save_dir)
+
+            aws_creds = {
+                'aws_access_key_id': job_input.get('aws_access_key_id'),
+                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
+                'aws_region': job_input.get('aws_region', 'us-east-1'),
+            }
+            s3_prefix = job_input.get('s3_prefix', 'ddi-models/bert')
+            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
+            metrics['s3_uri'] = s3_uri
+            print(f"Model uploaded to {s3_uri}")
+
+        return {'status': 'success', 'metrics': metrics, 'model_uri': s3_uri, 'message': 'BERT classifier trained on DrugBank data'}
 
     except Exception as e:
         import traceback
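For completeness, retrieving an uploaded artifact is the reverse path: download the tarball and unpack it. A sketch with placeholder bucket and key:

    import tarfile
    import boto3

    s3 = boto3.client('s3')  # assumes credentials come from the environment
    s3.download_file('my-ddi-models',
                     'ddi-models/bert/model_20260210_064513.tar.gz',  # placeholder key
                     'model.tar.gz')
    with tarfile.open('model.tar.gz', 'r:gz') as tar:
        tar.extractall('restored')  # files land under restored/model/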