Files
kubeflow-pipelines/components/runpod_trainer/handler.py

159 lines
5.4 KiB
Python

"""
RunPod Serverless Handler for DDI Model Training
This runs on RunPod GPU instances and trains the Bio_ClinicalBERT model
for drug-drug interaction detection.
"""
import os
import json
import runpod
from typing import Dict, Any, List, Optional
def train_ddi_model(job_input: Dict[str, Any]) -> Dict[str, Any]:
"""
Train DDI detection model.
Expected input:
{
"model_name": "emilyalsentzer/Bio_ClinicalBERT",
"training_data": [{"text": "...", "label": 0}, ...], # Inline data
"epochs": 3,
"learning_rate": 2e-5,
"batch_size": 16
}
"""
import torch
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
TrainingArguments,
Trainer
)
from datasets import Dataset
import tempfile
import shutil
# Extract parameters
model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
training_data = job_input.get('training_data', None)
epochs = job_input.get('epochs', 3)
learning_rate = job_input.get('learning_rate', 2e-5)
batch_size = job_input.get('batch_size', 16)
# Use sample data if none provided
if not training_data:
print("No training data provided, using sample DDI dataset...")
training_data = [
{"text": "warfarin and aspirin interaction causes bleeding risk", "label": 3},
{"text": "metformin with lisinopril is safe combination", "label": 0},
{"text": "fluoxetine tramadol causes serotonin syndrome", "label": 4},
{"text": "simvastatin amiodarone increases myopathy risk", "label": 3},
{"text": "omeprazole reduces clopidogrel efficacy", "label": 2},
{"text": "digoxin amiodarone toxicity risk elevated", "label": 3},
{"text": "lithium NSAIDs increases lithium levels", "label": 3},
{"text": "benzodiazepines opioids respiratory depression", "label": 4},
{"text": "metronidazole alcohol disulfiram reaction", "label": 4},
{"text": "ACE inhibitors potassium hyperkalemia risk", "label": 2},
{"text": "amlodipine atorvastatin safe combination", "label": 0},
{"text": "gabapentin pregabalin CNS depression additive", "label": 2},
{"text": "warfarin vitamin K antagonism reduced effect", "label": 2},
{"text": "insulin metformin hypoglycemia risk", "label": 1},
{"text": "aspirin ibuprofen GI bleeding increased", "label": 3},
] * 10 # 150 samples
# Create temp directory
work_dir = tempfile.mkdtemp()
model_dir = os.path.join(work_dir, 'model')
try:
print(f"Training samples: {len(training_data)}")
print(f"Model: {model_name}")
print(f"Epochs: {epochs}, Batch size: {batch_size}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
# Load dataset
dataset = Dataset.from_list(training_data)
# Load model and tokenizer
print(f"Loading model: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
model_name,
num_labels=5 # DDI severity: none, minor, moderate, major, contraindicated
)
# Tokenize dataset
def tokenize_function(examples):
return tokenizer(
examples['text'],
padding='max_length',
truncation=True,
max_length=128
)
tokenized_dataset = dataset.map(tokenize_function, batched=True)
# Training arguments
training_args = TrainingArguments(
output_dir=model_dir,
num_train_epochs=epochs,
learning_rate=learning_rate,
per_device_train_batch_size=batch_size,
warmup_steps=50,
weight_decay=0.01,
logging_steps=10,
save_strategy='no', # Don't save checkpoints (avoids tensor contiguity issues)
fp16=torch.cuda.is_available(),
report_to='none',
)
# Initialize trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset,
)
# Train
print("Starting training...")
train_result = trainer.train()
# Get metrics
metrics = {
'train_loss': float(train_result.training_loss),
'epochs': epochs,
'model_name': model_name,
'samples': len(training_data),
'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'
}
print(f"Training complete! Loss: {metrics['train_loss']:.4f}")
return {
'status': 'success',
'metrics': metrics,
'message': 'Model trained successfully'
}
except Exception as e:
import traceback
return {
'status': 'error',
'error': str(e),
'traceback': traceback.format_exc()
}
finally:
# Cleanup
shutil.rmtree(work_dir, ignore_errors=True)
def handler(job):
"""RunPod serverless handler."""
job_input = job.get('input', {})
return train_ddi_model(job_input)
# RunPod serverless entrypoint
runpod.serverless.start({'handler': handler})