feat: add S3 model upload support
- Add upload_to_s3 function to handler
- Save trained BERT models to S3 when credentials provided
- Save LoRA adapters to S3 when credentials provided
- Input params: s3_bucket, s3_prefix, aws_access_key_id, aws_secret_access_key, aws_region
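For reference, a job payload exercising the new inputs might look like the sketch below. The endpoint ID, bucket, and credentials are placeholders, and the "input" wrapper follows RunPod's serverless request format; the field names match the params listed above.

import runpod

runpod.api_key = "YOUR_RUNPOD_API_KEY"          # placeholder
endpoint = runpod.Endpoint("YOUR_ENDPOINT_ID")  # placeholder

# Submit a training job that uploads its result to S3.
run = endpoint.run({
    "input": {
        "s3_bucket": "my-ddi-models",           # placeholder bucket
        "s3_prefix": "ddi-models/bert",
        "aws_access_key_id": "AKIA...",         # placeholder credentials
        "aws_secret_access_key": "...",
        "aws_region": "us-east-1",
    }
})
print(run.status())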
@@ -3,13 +3,49 @@ RunPod Serverless Handler for DDI Model Training
 Supports both BERT-style classification and LLM fine-tuning with LoRA.
 Uses 176K real DrugBank DDI samples with drug names.
+Saves trained models to S3.
 """
 import os
 import json
 import runpod
+import boto3
+from datetime import datetime
 from typing import Dict, Any, List, Optional
 
 
+def upload_to_s3(local_path: str, s3_bucket: str, s3_prefix: str, aws_credentials: Dict) -> str:
+    """Upload a directory to S3 and return the S3 URI."""
+    import tarfile
+    import tempfile
+
+    # Create tar.gz of the model directory
+    timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
+    tar_name = f"model_{timestamp}.tar.gz"
+    tar_path = os.path.join(tempfile.gettempdir(), tar_name)
+
+    print(f"Creating archive: {tar_path}")
+    with tarfile.open(tar_path, "w:gz") as tar:
+        tar.add(local_path, arcname="model")
+
+    # Upload to S3
+    s3_client = boto3.client(
+        's3',
+        aws_access_key_id=aws_credentials.get('aws_access_key_id'),
+        aws_secret_access_key=aws_credentials.get('aws_secret_access_key'),
+        region_name=aws_credentials.get('aws_region', 'us-east-1')
+    )
+
+    s3_key = f"{s3_prefix}/{tar_name}"
+    print(f"Uploading to s3://{s3_bucket}/{s3_key}")
+    s3_client.upload_file(tar_path, s3_bucket, s3_key)
+
+    # Cleanup
+    os.remove(tar_path)
+
+    return f"s3://{s3_bucket}/{s3_key}"
+
+
 # DDI severity labels
 DDI_SEVERITY = {
     1: "minor",
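A consumer-side sketch (not part of this commit) for retrieving what upload_to_s3 produces: since the archive is built with arcname="model", the files unpack under <dest>/model. The helper name and destination path are illustrative.

import tarfile
import boto3

def download_model(s3_uri: str, dest: str) -> str:
    """Fetch a model archive written by upload_to_s3 and unpack it."""
    bucket, key = s3_uri.removeprefix("s3://").split("/", 1)
    local_tar = "/tmp/model.tar.gz"
    boto3.client("s3").download_file(bucket, key, local_tar)
    with tarfile.open(local_tar, "r:gz") as tar:
        tar.extractall(dest)   # contents land under <dest>/model
    return f"{dest}/model"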
@@ -229,10 +265,31 @@ def train_llm_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
         print(f"Training complete! Loss: {metrics['train_loss']:.4f}")
 
+        # Save LoRA adapter to S3 if credentials provided
+        s3_uri = None
+        s3_bucket = job_input.get('s3_bucket')
+        if s3_bucket:
+            save_dir = os.path.join(work_dir, 'lora_adapter')
+            print(f"Saving LoRA adapter to {save_dir}...")
+            model.save_pretrained(save_dir)
+            tokenizer.save_pretrained(save_dir)
+
+            aws_creds = {
+                'aws_access_key_id': job_input.get('aws_access_key_id'),
+                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
+                'aws_region': job_input.get('aws_region', 'us-east-1'),
+            }
+            model_short = model_name.split('/')[-1]
+            s3_prefix = job_input.get('s3_prefix', f'ddi-models/lora-{model_short}')
+            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
+            metrics['s3_uri'] = s3_uri
+            print(f"LoRA adapter uploaded to {s3_uri}")
+
         return {
             'status': 'success',
             'metrics': metrics,
-            'message': f'Gemma 3 fine-tuned on {len(raw_data):,} real DrugBank DDI samples'
+            'model_uri': s3_uri,
+            'message': f'LLM fine-tuned on {len(raw_data):,} real DrugBank DDI samples'
         }
 
     except Exception as e:
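To use the uploaded adapter later, it can be re-attached to its base model with peft. A sketch, assuming the archive has been unpacked as in the download helper above; the Gemma model id is a placeholder and must match the base model actually used for training.

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

adapter_dir = "/tmp/ddi/model"  # unpacked lora_adapter directory (placeholder path)
base = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it")  # placeholder base model
model = PeftModel.from_pretrained(base, adapter_dir)  # attach the LoRA weights
tokenizer = AutoTokenizer.from_pretrained(adapter_dir)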
@@ -357,7 +414,26 @@ def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
             'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
         })
 
-        return {'status': 'success', 'metrics': metrics, 'message': 'BERT classifier trained on DrugBank data'}
+        # Save model to S3 if credentials provided
+        s3_uri = None
+        s3_bucket = job_input.get('s3_bucket')
+        if s3_bucket:
+            save_dir = os.path.join(work_dir, 'saved_model')
+            print(f"Saving model to {save_dir}...")
+            trainer.save_model(save_dir)
+            tokenizer.save_pretrained(save_dir)
+
+            aws_creds = {
+                'aws_access_key_id': job_input.get('aws_access_key_id'),
+                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
+                'aws_region': job_input.get('aws_region', 'us-east-1'),
+            }
+            s3_prefix = job_input.get('s3_prefix', 'ddi-models/bert')
+            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
+            metrics['s3_uri'] = s3_uri
+            print(f"Model uploaded to {s3_uri}")
+
+        return {'status': 'success', 'metrics': metrics, 'model_uri': s3_uri, 'message': 'BERT classifier trained on DrugBank data'}
 
     except Exception as e:
         import traceback
components/runpod_trainer/runpod.toml (new file, 12 lines)
@@ -0,0 +1,12 @@
+[project]
+name = "ddi-trainer"
+base_image = "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04"
+gpu_types = ["NVIDIA RTX A4000", "NVIDIA RTX A5000", "NVIDIA RTX A6000", "NVIDIA GeForce RTX 4090"]
+gpu_count = 1
+volume_mount_path = "/runpod-volume"
+
+[project.env_vars]
+
+[runtime]
+handler_path = "handler.py"
+requirements_path = "requirements.txt"
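Note that boto3 must appear in the requirements.txt referenced by requirements_path for the new import to resolve. Before deploying, the handler can also be exercised locally: RunPod's serverless worker accepts a --test_input JSON blob when the handler file is run directly (bucket and region below are placeholders):

    python handler.py --test_input '{"input": {"s3_bucket": "my-ddi-models", "s3_prefix": "ddi-models/bert", "aws_region": "us-east-1"}}'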