Files
kubeflow-pipelines/components/runpod_trainer/handler.py
Greg Hendrickson 45b96e2094 feat: Switch to Llama 3.1 8B (Bedrock-compatible)
- Default model now meta-llama/Llama-3.1-8B-Instruct
- Added multi-model chat format support:
  - Llama 3 format
  - Mistral/Mixtral format
  - Qwen format
  - Gemma format
- Trained model can be imported to AWS Bedrock
2026-02-03 04:38:54 +00:00

390 lines
14 KiB
Python

"""
RunPod Serverless Handler for DDI Model Training
Supports both BERT-style classification and LLM fine-tuning with LoRA.
Uses 176K real DrugBank DDI samples with drug names.
"""
import os
import json
import runpod
from typing import Dict, Any, List, Optional
# DrugBank DDI severity code (1-4) -> human-readable label, in ascending severity.
DDI_SEVERITY = dict(
    enumerate(("minor", "moderate", "major", "contraindicated"), start=1)
)
def load_drugbank_data(max_samples: Optional[int] = None, severity_filter: Optional[List[int]] = None) -> List[Dict]:
    """Load DrugBank DDI records from the bundled JSONL file.

    The path is read from the ``DDI_DATA_PATH`` environment variable (default
    ``/app/data/drugbank_ddi_complete.jsonl``). If that file does not exist,
    the records from ``get_curated_fallback()`` are used instead, and the same
    severity filter and sample cap are applied to them.

    Args:
        max_samples: Maximum number of records to return. ``None`` or 0 means
            no limit.
        severity_filter: Severity codes to keep (e.g. ``[3, 4]``). ``None`` or
            an empty list keeps every record.

    Returns:
        The selected records in source order. Each non-blank JSONL line is
        parsed as one JSON object, and each record must have a ``severity`` key.
    """
    data_path = os.environ.get('DDI_DATA_PATH', '/app/data/drugbank_ddi_complete.jsonl')
    allowed = set(severity_filter) if severity_filter else None

    def _select(records) -> List[Dict]:
        """Apply the severity filter and the sample cap, stopping once full."""
        selected = []
        for item in records:
            if allowed is not None and item['severity'] not in allowed:
                continue
            selected.append(item)
            if max_samples and len(selected) >= max_samples:
                break
        return selected

    if not os.path.exists(data_path):
        print(f"WARNING: Data file not found at {data_path}, using curated fallback")
        # Honour the caller's filter/limit on the fallback too; otherwise a
        # severity-filtered job would silently train on every severity.
        return _select(get_curated_fallback())

    with open(data_path, encoding='utf-8') as f:
        # Skip blank/whitespace-only lines (e.g. a trailing empty line), which
        # json.loads would otherwise reject.
        return _select(json.loads(line) for line in f if line.strip())
def get_curated_fallback() -> List[Dict]:
    """Return a small built-in DDI dataset for when the DrugBank file is absent.

    Four hand-written interaction records (severity 4, 3, 2, 1 in that order)
    are repeated 50 times, giving 200 entries. Repeated entries are the same
    dict objects, so callers should not mutate them in place.
    """
    fields = ("drug1", "drug2", "interaction_text", "severity")
    rows = (
        ("fluoxetine", "tramadol", "fluoxetine may increase the risk of serotonin syndrome when combined with tramadol", 4),
        ("warfarin", "aspirin", "warfarin may increase the risk of bleeding when combined with aspirin", 3),
        ("simvastatin", "amlodipine", "The serum concentration of simvastatin can be increased when combined with amlodipine", 2),
        ("metformin", "lisinopril", "metformin and lisinopril have no significant interaction", 1),
    )
    base_records = [dict(zip(fields, row)) for row in rows]
    return base_records * 50  # 4 records x 50 = 200 samples
def format_for_llm(item: Dict, model_name: str = "") -> str:
    """Render one DDI record as a single-turn chat transcript for instruction tuning.

    The chat template is chosen by a case-insensitive substring match on
    ``model_name``: Llama 3, Mistral/Mixtral, Qwen (ChatML), Gemma, or a plain
    ``### User:`` / ``### Assistant:`` fallback for anything else.

    Args:
        item: DDI record with ``drug1``, ``drug2``, ``interaction_text`` and an
            integer ``severity`` (looked up in ``DDI_SEVERITY``; unknown codes
            are labelled ``UNKNOWN``).
        model_name: Model id used only to pick the template. ``""``/``None``
            selects the generic fallback.

    Returns:
        The full training text: the user question followed by the assistant
        answer (interaction text, severity label, recommendation).
    """
    severity = item['severity']
    severity_name = DDI_SEVERITY.get(severity, 'unknown')
    if severity >= 3:
        recommendation = "Avoid this combination"
    elif severity == 2:
        recommendation = "Monitor patient"
    else:
        recommendation = "Generally safe"

    user_msg = f"Analyze the drug interaction between {item['drug1']} and {item['drug2']}."
    assistant_msg = (
        f"Interaction: {item['interaction_text']}\n"
        f"Severity: {severity_name.upper()}\n"
        f"Recommendation: {recommendation}"
    )

    name = (model_name or "").lower()
    if 'llama' in name:
        # Llama 3 prompt format: every <|end_header_id|> is followed by a blank
        # line ("\n\n") before the message content.
        return (
            "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
            f"{user_msg}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
            f"{assistant_msg}<|eot_id|>"
        )
    if 'mistral' in name or 'mixtral' in name:
        return f"<s>[INST] {user_msg} [/INST] {assistant_msg}</s>"
    if 'qwen' in name:
        # ChatML
        return (
            f"<|im_start|>user\n{user_msg}<|im_end|>\n"
            f"<|im_start|>assistant\n{assistant_msg}<|im_end|>"
        )
    if 'gemma' in name:
        return (
            f"<start_of_turn>user\n{user_msg}<end_of_turn>\n"
            f"<start_of_turn>model\n{assistant_msg}<end_of_turn>"
        )
    # Generic fallback for models without a recognised template
    return f"### User: {user_msg}\n### Assistant: {assistant_msg}"
def train_llm_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """Fine-tune a causal LLM on DrugBank DDI data with QLoRA.

    The base model is loaded with 4-bit NF4 quantization and wrapped with LoRA
    adapters on the attention and MLP projection layers. It is then trained
    with TRL's ``SFTTrainer`` on chat-formatted examples from
    ``format_for_llm``. No checkpoint is kept: the scratch output directory is
    deleted before returning.

    Recognised ``job_input`` keys (defaults in parentheses):
        model_name ('meta-llama/Llama-3.1-8B-Instruct'), max_samples (10000),
        epochs (1), learning_rate (2e-4), batch_size (2), lora_r (16),
        lora_alpha (32), max_seq_length (512), severity_filter (None, e.g.
        [3, 4] for major/contraindicated only).

    Returns:
        ``{'status': 'success', 'metrics': {...}, 'message': str}`` on success,
        or ``{'status': 'error', 'error': str, 'traceback': str}`` if any step
        inside the training section raises.
    """
    import dataclasses
    import shutil
    import tempfile
    from collections import Counter

    import torch
    from datasets import Dataset
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from trl import SFTConfig, SFTTrainer

    # Parameters - default to Llama 3.1 8B (Bedrock-compatible)
    model_name = job_input.get('model_name', 'meta-llama/Llama-3.1-8B-Instruct')
    max_samples = job_input.get('max_samples', 10000)
    epochs = job_input.get('epochs', 1)
    learning_rate = job_input.get('learning_rate', 2e-4)
    batch_size = job_input.get('batch_size', 2)
    lora_r = job_input.get('lora_r', 16)
    lora_alpha = job_input.get('lora_alpha', 32)
    max_seq_length = job_input.get('max_seq_length', 512)
    severity_filter = job_input.get('severity_filter', None)  # e.g. [3, 4] for major/contraindicated only

    cuda_available = torch.cuda.is_available()
    work_dir = tempfile.mkdtemp()
    try:
        # Both device queries are guarded: get_device_properties(0) raises
        # when no CUDA device is present.
        gpu_name = torch.cuda.get_device_name(0) if cuda_available else 'CPU'
        vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 if cuda_available else 0
        print(f"GPU: {gpu_name}")
        print(f"VRAM: {vram_gb:.1f} GB")
        print(f"Model: {model_name}")
        print(f"LoRA r={lora_r}, alpha={lora_alpha}")

        # Load training data and render it in the target model's chat format
        print(f"Loading DrugBank DDI data (max {max_samples})...")
        raw_data = load_drugbank_data(max_samples=max_samples, severity_filter=severity_filter)
        dataset = Dataset.from_list([{"text": format_for_llm(item, model_name)} for item in raw_data])
        print(f"Dataset size: {len(dataset)}")

        sev_dist = Counter(item['severity'] for item in raw_data)
        print(f"Severity distribution: {dict(sev_dist)}")

        # QLoRA config - 4-bit NF4 weights with bf16 compute
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

        print(f"Loading {model_name} with 4-bit quantization...")
        # NOTE(review): trust_remote_code=True executes Python code shipped with
        # the model repo, and model_name comes straight from the job input.
        # Keep it only if job submitters are trusted.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        if tokenizer.pad_token is None:
            # Only borrow EOS as padding when the tokenizer has no pad token.
            # Overwriting an existing, distinct pad token would make EOS
            # positions look like padding.
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"

        # Prepare the quantized model for k-bit training, then attach LoRA adapters
        model = prepare_model_for_kbit_training(model)
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)

        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Trainable params: {trainable_params:,} / {total_params:,} ({100*trainable_params/total_params:.2f}%)")

        # TRL versions that accept `processing_class` take the sequence-length
        # limit on SFTConfig, not on SFTTrainer. Newer releases call the field
        # `max_length` instead of `max_seq_length`, so detect which one exists.
        sft_fields = {field.name for field in dataclasses.fields(SFTConfig)}
        seq_len_arg = 'max_length' if 'max_length' in sft_fields else 'max_seq_length'
        training_args = SFTConfig(
            output_dir=work_dir,
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=8,
            learning_rate=learning_rate,
            weight_decay=0.01,
            warmup_ratio=0.1,
            logging_steps=25,
            save_strategy="no",
            bf16=True,
            gradient_checkpointing=True,
            optim="paged_adamw_8bit",
            report_to="none",
            max_grad_norm=0.3,
            **{seq_len_arg: max_seq_length},
        )

        # `model` is already a PeftModel (get_peft_model above), so no
        # peft_config is passed here. Passing one again can make TRL
        # re-process or merge the adapter we just created and counted.
        trainer = SFTTrainer(
            model=model,
            train_dataset=dataset,
            args=training_args,
            processing_class=tokenizer,
        )

        print("Starting LoRA fine-tuning...")
        train_result = trainer.train()

        metrics = {
            'train_loss': float(train_result.training_loss),
            'epochs': epochs,
            'model_name': model_name,
            'samples': len(raw_data),
            'lora_r': lora_r,
            'lora_alpha': lora_alpha,
            'trainable_params': trainable_params,
            'gpu': gpu_name,
            'vram_gb': vram_gb,
            'data_source': 'drugbank_176k',
            'severity_dist': dict(sev_dist),
        }
        print(f"Training complete! Loss: {metrics['train_loss']:.4f}")
        return {
            'status': 'success',
            'metrics': metrics,
            # Report the model that was actually trained (previously hard-coded "Gemma 3").
            'message': f'{model_name} fine-tuned on {len(raw_data):,} real DrugBank DDI samples'
        }
    except Exception as e:
        import traceback
        return {
            'status': 'error',
            'error': str(e),
            'traceback': traceback.format_exc()
        }
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)
        if cuda_available:
            torch.cuda.empty_cache()
def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """Fine-tune a BERT-style encoder to predict DDI severity (4 classes).

    Each example's text is ``"<drug1> and <drug2>: <interaction_text>"`` and its
    label is ``severity - 1``, so severities 1-4 become class ids 0-3. When
    ``eval_split`` > 0, a held-out split is evaluated each epoch and once more
    after training. No checkpoint is kept: the scratch output directory is
    deleted before returning.

    Recognised ``job_input`` keys (defaults in parentheses):
        model_name ('emilyalsentzer/Bio_ClinicalBERT'), max_samples (50000),
        epochs (3), learning_rate (2e-5), batch_size (16), eval_split (0.1).

    Returns:
        ``{'status': 'success', 'metrics': {...}, 'message': str}`` on success,
        or ``{'status': 'error', 'error': str, 'traceback': str}`` if any step
        inside the training section raises.
    """
    import shutil
    import tempfile

    import torch
    from datasets import Dataset
    from sklearn.model_selection import train_test_split
    from transformers import (
        AutoModelForSequenceClassification,
        AutoTokenizer,
        Trainer,
        TrainingArguments,
    )

    model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
    max_samples = job_input.get('max_samples', 50000)
    epochs = job_input.get('epochs', 3)
    learning_rate = job_input.get('learning_rate', 2e-5)
    batch_size = job_input.get('batch_size', 16)
    eval_split = job_input.get('eval_split', 0.1)

    work_dir = tempfile.mkdtemp()
    try:
        print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
        print(f"Model: {model_name}")

        raw_data = load_drugbank_data(max_samples=max_samples)

        # Text + label format; shift severity to 0-indexed (1-4 -> 0-3)
        training_data = [{
            "text": f"{d['drug1']} and {d['drug2']}: {d['interaction_text']}",
            "label": d['severity'] - 1,
        } for d in raw_data]
        print(f"Loaded {len(training_data)} samples")

        if eval_split > 0:
            labels = [d['label'] for d in training_data]
            try:
                # Stratify so the eval set mirrors the class balance.
                train_data, eval_data = train_test_split(
                    training_data, test_size=eval_split, random_state=42,
                    stratify=labels,
                )
            except ValueError:
                # Stratification is impossible when a class has fewer than two
                # samples or the eval split has fewer rows than classes; fall
                # back to a plain random split instead of failing the job.
                print("WARNING: stratified split not possible, using unstratified split")
                train_data, eval_data = train_test_split(
                    training_data, test_size=eval_split, random_state=42,
                )
        else:
            train_data, eval_data = training_data, None

        train_dataset = Dataset.from_list(train_data)
        eval_dataset = Dataset.from_list(eval_data) if eval_data else None

        # 4 classes: minor, moderate, major, contraindicated
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=4
        )

        def tokenize(examples):
            return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=256)

        train_dataset = train_dataset.map(tokenize, batched=True)
        if eval_dataset:
            eval_dataset = eval_dataset.map(tokenize, batched=True)

        training_args = TrainingArguments(
            output_dir=work_dir,
            num_train_epochs=epochs,
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            eval_strategy='epoch' if eval_dataset else 'no',
            save_strategy='no',
            fp16=torch.cuda.is_available(),
            report_to='none',
        )

        def compute_metrics(eval_pred):
            from sklearn.metrics import accuracy_score, f1_score
            preds = eval_pred.predictions.argmax(-1)
            return {
                'accuracy': accuracy_score(eval_pred.label_ids, preds),
                'f1_weighted': f1_score(eval_pred.label_ids, preds, average='weighted'),
            }

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics if eval_dataset else None,
        )
        train_result = trainer.train()

        metrics = {
            'train_loss': float(train_result.training_loss),
            'epochs': epochs,
            'model_name': model_name,
            'samples': len(train_data),
            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
            'data_source': 'drugbank_176k',
        }
        if eval_dataset:
            eval_result = trainer.evaluate()
            metrics.update({
                'eval_accuracy': float(eval_result['eval_accuracy']),
                'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
            })
        return {'status': 'success', 'metrics': metrics, 'message': 'BERT classifier trained on DrugBank data'}
    except Exception as e:
        import traceback
        return {'status': 'error', 'error': str(e), 'traceback': traceback.format_exc()}
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)
        # Release cached GPU memory between jobs, matching train_llm_lora.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
def handler(job):
    """RunPod serverless handler: route a job to the LoRA or BERT trainer.

    ``job['input']`` is passed unchanged to the chosen trainer, and its result
    dict is returned. Routing is decided by ``model_name`` (case-insensitive
    substring match):
      * names containing gemma/llama/mistral/qwen -> ``train_llm_lora``
      * otherwise names containing 'bert' -> ``train_bert_classifier``
      * anything else follows the ``use_lora`` flag (default True).
    """
    # `or {}` also covers a request that sends `"input": null`, which
    # `.get('input', {})` would return as None.
    job_input = job.get('input') or {}
    model_name = job_input.get('model_name', 'meta-llama/Llama-3.1-8B-Instruct')
    use_lora = job_input.get('use_lora', True)

    # Auto-detect: known LLM families always use LoRA; BERT-style models never do.
    name = model_name.lower()
    if any(family in name for family in ('gemma', 'llama', 'mistral', 'qwen')):
        use_lora = True
    elif 'bert' in name:
        use_lora = False

    if use_lora:
        return train_llm_lora(job_input)
    return train_bert_classifier(job_input)
# RunPod serverless entrypoint: registers `handler` with the RunPod worker.
# This runs at import time, as a module-level statement.
runpod.serverless.start(dict(handler=handler))