feat: Add Gemma 3 12B with QLoRA fine-tuning

- Added PEFT, bitsandbytes, TRL for LoRA training
- 4-bit QLoRA quantization for 48GB GPU fit
- Instruction-tuning format for Gemma chat template
- Auto-detect model type (BERT vs LLM)
- Updated GPU tier to ADA_24/AMPERE_48
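
Example job input for the default QLoRA path (field names and default values are the ones read by the handler below; endpoint details are deployment-specific):

    {
      "input": {
        "model_name": "google/gemma-3-12b-it",
        "max_samples": 2000,
        "epochs": 1,
        "learning_rate": 2e-4,
        "batch_size": 4,
        "lora_r": 16,
        "lora_alpha": 32,
        "max_seq_length": 512
      }
    }

Passing a BERT-family model_name (or "use_lora": false) routes the job to the classifier path instead.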
commit 39922e8d2e
parent 4ff491f847
Date:   2026-02-03 03:58:25 +00:00

2 changed files with 264 additions and 151 deletions


@@ -1,8 +1,8 @@
 """
 RunPod Serverless Handler for DDI Model Training
-This runs on RunPod GPU instances and trains the Bio_ClinicalBERT model
-for drug-drug interaction detection using real DDI data.
+Supports both BERT-style classification and LLM fine-tuning with LoRA.
+Default: Gemma 3 12B with QLoRA for DDI severity classification.
 """
 import os
 import json
@@ -12,86 +12,93 @@ from typing import Dict, Any, List, Optional
 # DDI severity labels
 DDI_LABELS = {
-    0: "none",            # No significant interaction
-    1: "minor",           # Minor interaction
-    2: "moderate",        # Moderate interaction
-    3: "major",           # Major interaction
-    4: "contraindicated"  # Contraindicated
+    0: "no_interaction",
+    1: "minor",
+    2: "moderate",
+    3: "major",
+    4: "contraindicated"
+}
+
+LABEL_DESCRIPTIONS = {
+    0: "No clinically significant interaction",
+    1: "Minor interaction - minimal clinical significance",
+    2: "Moderate interaction - may require monitoring or dose adjustment",
+    3: "Major interaction - avoid combination if possible, high risk",
+    4: "Contraindicated - do not use together, life-threatening risk"
 }
-def get_real_ddi_data(max_samples: int = 10000) -> List[Dict[str, Any]]:
-    """
-    Generate real DDI training data from DrugBank patterns.
-    Uses curated drug interaction patterns based on clinical guidelines.
-    """
+def get_ddi_training_data(max_samples: int = 5000) -> List[Dict[str, Any]]:
+    """Generate DDI training data formatted for instruction tuning."""
     import random
     random.seed(42)
-    # Real drug pairs with known interactions (based on clinical data)
+    # Real drug interaction patterns based on clinical data
     ddi_patterns = [
         # Contraindicated (4)
         {"drugs": ["fluoxetine", "tramadol"], "type": "serotonin syndrome risk", "label": 4},
-        {"drugs": ["fluoxetine", "monoamine oxidase inhibitor"], "type": "serotonin syndrome risk", "label": 4},
+        {"drugs": ["fluoxetine", "phenelzine"], "type": "serotonin syndrome risk", "label": 4},
         {"drugs": ["simvastatin", "itraconazole"], "type": "rhabdomyolysis risk", "label": 4},
         {"drugs": ["methotrexate", "trimethoprim"], "type": "severe bone marrow suppression", "label": 4},
         {"drugs": ["warfarin", "miconazole"], "type": "severe bleeding risk", "label": 4},
         {"drugs": ["cisapride", "erythromycin"], "type": "QT prolongation cardiac arrest", "label": 4},
         {"drugs": ["pimozide", "clarithromycin"], "type": "QT prolongation risk", "label": 4},
         {"drugs": ["ergotamine", "ritonavir"], "type": "ergot toxicity risk", "label": 4},
-        {"drugs": ["sildenafil", "nitrates"], "type": "severe hypotension", "label": 4},
-        {"drugs": ["linezolid", "serotonergic agents"], "type": "serotonin syndrome", "label": 4},
+        {"drugs": ["sildenafil", "nitroglycerin"], "type": "severe hypotension", "label": 4},
+        {"drugs": ["linezolid", "sertraline"], "type": "serotonin syndrome", "label": 4},
+        {"drugs": ["maoi", "meperidine"], "type": "hypertensive crisis", "label": 4},
+        {"drugs": ["metronidazole", "disulfiram"], "type": "psychosis risk", "label": 4},
         # Major (3)
         {"drugs": ["warfarin", "aspirin"], "type": "increased bleeding risk", "label": 3},
         {"drugs": ["digoxin", "amiodarone"], "type": "digoxin toxicity elevated", "label": 3},
         {"drugs": ["lithium", "ibuprofen"], "type": "lithium toxicity risk", "label": 3},
-        {"drugs": ["metformin", "contrast media"], "type": "lactic acidosis risk", "label": 3},
-        {"drugs": ["potassium", "ACE inhibitor"], "type": "hyperkalemia risk", "label": 3},
-        {"drugs": ["opioid", "benzodiazepine"], "type": "respiratory depression", "label": 3},
+        {"drugs": ["metformin", "iodinated contrast"], "type": "lactic acidosis risk", "label": 3},
+        {"drugs": ["potassium chloride", "lisinopril"], "type": "hyperkalemia risk", "label": 3},
+        {"drugs": ["oxycodone", "alprazolam"], "type": "respiratory depression", "label": 3},
         {"drugs": ["theophylline", "ciprofloxacin"], "type": "theophylline toxicity", "label": 3},
         {"drugs": ["phenytoin", "fluconazole"], "type": "phenytoin toxicity", "label": 3},
         {"drugs": ["carbamazepine", "verapamil"], "type": "carbamazepine toxicity", "label": 3},
         {"drugs": ["cyclosporine", "ketoconazole"], "type": "nephrotoxicity risk", "label": 3},
-        {"drugs": ["methotrexate", "NSAIDs"], "type": "methotrexate toxicity", "label": 3},
+        {"drugs": ["methotrexate", "ibuprofen"], "type": "methotrexate toxicity", "label": 3},
         {"drugs": ["quinidine", "digoxin"], "type": "digoxin toxicity", "label": 3},
         {"drugs": ["clopidogrel", "omeprazole"], "type": "reduced antiplatelet effect", "label": 3},
-        {"drugs": ["warfarin", "vitamin K"], "type": "reduced anticoagulation", "label": 3},
+        {"drugs": ["warfarin", "rifampin"], "type": "reduced anticoagulation", "label": 3},
         {"drugs": ["dabigatran", "rifampin"], "type": "reduced anticoagulant effect", "label": 3},
         # Moderate (2)
         {"drugs": ["simvastatin", "amlodipine"], "type": "increased statin exposure", "label": 2},
         {"drugs": ["metformin", "cimetidine"], "type": "increased metformin levels", "label": 2},
-        {"drugs": ["levothyroxine", "calcium"], "type": "reduced thyroid absorption", "label": 2},
-        {"drugs": ["gabapentin", "antacids"], "type": "reduced gabapentin absorption", "label": 2},
+        {"drugs": ["levothyroxine", "calcium carbonate"], "type": "reduced thyroid absorption", "label": 2},
+        {"drugs": ["gabapentin", "aluminum hydroxide"], "type": "reduced gabapentin absorption", "label": 2},
         {"drugs": ["furosemide", "gentamicin"], "type": "ototoxicity risk", "label": 2},
-        {"drugs": ["prednisone", "NSAIDs"], "type": "GI bleeding risk", "label": 2},
+        {"drugs": ["prednisone", "naproxen"], "type": "GI bleeding risk", "label": 2},
         {"drugs": ["metoprolol", "verapamil"], "type": "bradycardia risk", "label": 2},
         {"drugs": ["sertraline", "tramadol"], "type": "seizure threshold lowered", "label": 2},
         {"drugs": ["losartan", "potassium supplements"], "type": "hyperkalemia risk", "label": 2},
         {"drugs": ["alprazolam", "ketoconazole"], "type": "increased sedation", "label": 2},
-        {"drugs": ["atorvastatin", "grapefruit"], "type": "increased statin levels", "label": 2},
-        {"drugs": ["ciprofloxacin", "iron"], "type": "reduced antibiotic absorption", "label": 2},
+        {"drugs": ["atorvastatin", "grapefruit juice"], "type": "increased statin levels", "label": 2},
+        {"drugs": ["ciprofloxacin", "ferrous sulfate"], "type": "reduced antibiotic absorption", "label": 2},
         {"drugs": ["warfarin", "acetaminophen"], "type": "slight INR increase", "label": 2},
-        {"drugs": ["insulin", "beta blocker"], "type": "masked hypoglycemia", "label": 2},
+        {"drugs": ["insulin", "propranolol"], "type": "masked hypoglycemia", "label": 2},
         {"drugs": ["digoxin", "spironolactone"], "type": "increased digoxin levels", "label": 2},
         # Minor (1)
         {"drugs": ["aspirin", "ibuprofen"], "type": "reduced cardioprotection", "label": 1},
         {"drugs": ["metformin", "vitamin B12"], "type": "reduced B12 absorption long-term", "label": 1},
-        {"drugs": ["amoxicillin", "oral contraceptives"], "type": "theoretical reduced efficacy", "label": 1},
-        {"drugs": ["proton pump inhibitor", "vitamin B12"], "type": "reduced absorption", "label": 1},
-        {"drugs": ["caffeine", "fluoroquinolones"], "type": "increased caffeine effect", "label": 1},
-        {"drugs": ["antacids", "iron"], "type": "timing interaction", "label": 1},
-        {"drugs": ["statin", "niacin"], "type": "monitoring recommended", "label": 1},
-        {"drugs": ["ACE inhibitor", "aspirin"], "type": "possible reduced effect", "label": 1},
-        {"drugs": ["thiazide", "calcium"], "type": "hypercalcemia monitoring", "label": 1},
-        {"drugs": ["beta blocker", "clonidine"], "type": "withdrawal monitoring", "label": 1},
+        {"drugs": ["amoxicillin", "ethinyl estradiol"], "type": "theoretical reduced efficacy", "label": 1},
+        {"drugs": ["omeprazole", "vitamin B12"], "type": "reduced absorption", "label": 1},
+        {"drugs": ["caffeine", "ciprofloxacin"], "type": "increased caffeine effect", "label": 1},
+        {"drugs": ["calcium carbonate", "ferrous sulfate"], "type": "timing interaction", "label": 1},
+        {"drugs": ["atorvastatin", "niacin"], "type": "monitoring recommended", "label": 1},
+        {"drugs": ["lisinopril", "aspirin"], "type": "possible reduced effect", "label": 1},
+        {"drugs": ["hydrochlorothiazide", "calcium"], "type": "hypercalcemia monitoring", "label": 1},
+        {"drugs": ["metoprolol", "clonidine"], "type": "withdrawal monitoring", "label": 1},
         # No interaction (0)
         {"drugs": ["amlodipine", "atorvastatin"], "type": "safe combination", "label": 0},
         {"drugs": ["metformin", "lisinopril"], "type": "complementary therapy", "label": 0},
-        {"drugs": ["omeprazole", "levothyroxine"], "type": "can be used together", "label": 0},
+        {"drugs": ["omeprazole", "levothyroxine"], "type": "can be used together with spacing", "label": 0},
         {"drugs": ["aspirin", "atorvastatin"], "type": "standard combination", "label": 0},
         {"drugs": ["metoprolol", "lisinopril"], "type": "common combination", "label": 0},
         {"drugs": ["gabapentin", "acetaminophen"], "type": "no interaction", "label": 0},
@@ -101,52 +108,201 @@ def get_real_ddi_data(max_samples: int = 10000) -> List[Dict[str, Any]]:
{"drugs": ["pantoprazole", "amlodipine"], "type": "no known interaction", "label": 0}, {"drugs": ["pantoprazole", "amlodipine"], "type": "no known interaction", "label": 0},
] ]
# Expand with variations
training_data = [] training_data = []
for pattern in ddi_patterns: for pattern in ddi_patterns:
drug1, drug2 = pattern["drugs"] drug1, drug2 = pattern["drugs"]
interaction_type = pattern["type"] interaction_type = pattern["type"]
label = pattern["label"] label = pattern["label"]
label_name = DDI_LABELS[label]
label_desc = LABEL_DESCRIPTIONS[label]
# Create multiple text variations # Create instruction-tuning format
variations = [ prompts = [
f"{drug1} and {drug2} interaction: {interaction_type}", f"Analyze the drug-drug interaction between {drug1} and {drug2}.",
f"{drug2} combined with {drug1} causes {interaction_type}", f"What is the severity of combining {drug1} with {drug2}?",
f"Patient taking {drug1} with {drug2}: {interaction_type}", f"A patient is taking {drug1}. They need to start {drug2}. Assess the interaction risk.",
f"Concomitant use of {drug1} and {drug2} leads to {interaction_type}", f"Evaluate the interaction: {drug1} + {drug2}",
f"{drug1} {drug2} drug-drug interaction {interaction_type}", f"Drug interaction check: {drug1} and {drug2}",
] ]
for text in variations: for prompt in prompts:
training_data.append({"text": text, "label": label}) response = f"Severity: {label_name.upper()}\nInteraction: {interaction_type}\nRecommendation: {label_desc}"
training_data.append({
"instruction": prompt,
"response": response,
"label": label,
"label_name": label_name
})
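+    # Five prompt phrasings per pattern, so the unique pool is roughly 5x the
+    # pattern count; the replication below only duplicates examples to reach
+    # max_samples and adds no new signal past that point.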
-    # Shuffle and limit
+    # Shuffle and replicate to reach target size
     random.shuffle(training_data)
-    # Replicate to reach target size
     while len(training_data) < max_samples:
         training_data.extend(training_data[:min(len(training_data), max_samples - len(training_data))])
     return training_data[:max_samples]
-def train_ddi_model(job_input: Dict[str, Any]) -> Dict[str, Any]:
-    """
-    Train DDI detection model.
+def format_for_gemma(example: Dict) -> str:
+    """Format example for Gemma instruction tuning."""
+    return f"""<start_of_turn>user
+{example['instruction']}<end_of_turn>
+<start_of_turn>model
+{example['response']}<end_of_turn>"""
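+# The <start_of_turn>/<end_of_turn> markers and the "user"/"model" role names
+# above are the delimiters Gemma's chat template expects, so the SFT text
+# matches what the instruction-tuned checkpoint saw during its own training.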
+
+def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
+    """Train Gemma 3 with QLoRA for DDI classification."""
+    import torch
+    from transformers import (
+        AutoTokenizer,
+        AutoModelForCausalLM,
+        BitsAndBytesConfig,
+    )
+    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+    # SFTConfig instead of TrainingArguments: TRL >= 0.13 reads max_seq_length
+    # from the config rather than accepting it as an SFTTrainer kwarg
+    from trl import SFTTrainer, SFTConfig
+    from datasets import Dataset
+    import tempfile
+    import shutil
-    Expected input:
-    {
-        "model_name": "emilyalsentzer/Bio_ClinicalBERT",
-        "use_real_data": true,
-        "max_samples": 5000,
-        "training_data": [...],  # Or provide inline data
-        "epochs": 3,
-        "learning_rate": 2e-5,
-        "batch_size": 16,
-        "eval_split": 0.1
-    }
-    """
+    # Parameters
+    model_name = job_input.get('model_name', 'google/gemma-3-12b-it')
+    max_samples = job_input.get('max_samples', 2000)
+    epochs = job_input.get('epochs', 1)
+    learning_rate = job_input.get('learning_rate', 2e-4)
+    batch_size = job_input.get('batch_size', 4)
+    lora_r = job_input.get('lora_r', 16)
+    lora_alpha = job_input.get('lora_alpha', 32)
+    max_seq_length = job_input.get('max_seq_length', 512)
+
+    work_dir = tempfile.mkdtemp()
+
+    try:
+        print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
+        print(f"Model: {model_name}")
+        print(f"LoRA r={lora_r}, alpha={lora_alpha}")
+        print(f"Samples: {max_samples}, Epochs: {epochs}")
+
+        # Load training data
+        print("Loading DDI training data...")
+        training_data = get_ddi_training_data(max_samples=max_samples)
+
+        # Format for Gemma
+        formatted_data = [{"text": format_for_gemma(ex)} for ex in training_data]
+        dataset = Dataset.from_list(formatted_data)
+        print(f"Dataset size: {len(dataset)}")
+
+        # QLoRA config - 4-bit quantization
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.bfloat16,
+            bnb_4bit_use_double_quant=True,
+        )
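+        # NF4 packs weights into 4 bits and double quantization compresses the
+        # per-block scales too, so ~12B params come to roughly 6-7 GB of weight
+        # memory - comfortable on the 48 GB tier with activations and the LoRA
+        # optimizer state on top.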
+
+        # Load model
+        print(f"Loading {model_name} with 4-bit quantization...")
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            quantization_config=bnb_config,
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=True,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "right"
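+        # Common SFT convention: reuse EOS as the pad token and right-pad so
+        # batched sequences align (left padding only matters for generation).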
+
+        # Prepare for k-bit training
+        model = prepare_model_for_kbit_training(model)
+
+        # LoRA config
+        lora_config = LoraConfig(
+            r=lora_r,
+            lora_alpha=lora_alpha,
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+            lora_dropout=0.05,
+            bias="none",
+            task_type="CAUSAL_LM",
+        )
+        model = get_peft_model(model, lora_config)
+        model.print_trainable_parameters()
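+        # Adapter updates are scaled by lora_alpha / lora_r (here 32/16 = 2);
+        # targeting every attention and MLP projection follows the QLoRA
+        # all-linear-layers recipe and trains well under 1% of total params.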
+
+        # Training arguments
+        training_args = SFTConfig(
+            output_dir=work_dir,
+            num_train_epochs=epochs,
+            per_device_train_batch_size=batch_size,
+            gradient_accumulation_steps=4,
+            learning_rate=learning_rate,
+            weight_decay=0.01,
+            warmup_ratio=0.1,
+            logging_steps=10,
+            save_strategy="no",
+            bf16=True,
+            gradient_checkpointing=True,
+            optim="paged_adamw_8bit",
+            report_to="none",
+            max_grad_norm=0.3,
+            max_seq_length=max_seq_length,
+        )
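+        # Effective batch size is 4 x 4 = 16 sequences per optimizer step;
+        # paged_adamw_8bit keeps optimizer state in paged 8-bit buffers, the
+        # QLoRA trick for riding out GPU memory spikes.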
+
+        # SFT Trainer (the model is already a PeftModel via get_peft_model
+        # above, so peft_config is not passed again)
+        trainer = SFTTrainer(
+            model=model,
+            train_dataset=dataset,
+            args=training_args,
+            processing_class=tokenizer,
+        )
+
+        # Train
+        print("Starting LoRA fine-tuning...")
+        train_result = trainer.train()
+
+        # Metrics
+        metrics = {
+            'train_loss': float(train_result.training_loss),
+            'epochs': epochs,
+            'model_name': model_name,
+            'samples': len(training_data),
+            'lora_r': lora_r,
+            'lora_alpha': lora_alpha,
+            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
+            'trainable_params': sum(p.numel() for p in model.parameters() if p.requires_grad),
+            'quantization': '4-bit QLoRA',
+        }
+        print(f"Training complete! Loss: {metrics['train_loss']:.4f}")
+
+        return {
+            'status': 'success',
+            'metrics': metrics,
+            'message': 'Gemma 3 12B fine-tuned with QLoRA on DDI data'
+        }
+    except Exception as e:
+        import traceback
+        return {
+            'status': 'error',
+            'error': str(e),
+            'traceback': traceback.format_exc()
+        }
+    finally:
+        shutil.rmtree(work_dir, ignore_errors=True)
+        # Clear GPU memory
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
+    """Train BERT-style classifier (original approach)."""
     import torch
     from transformers import (
         AutoTokenizer,
@@ -159,165 +315,119 @@ def train_ddi_model(job_input: Dict[str, Any]) -> Dict[str, Any]:
     import tempfile
     import shutil
-    # Extract parameters
     model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
-    use_real_data = job_input.get('use_real_data', True)
     max_samples = job_input.get('max_samples', 5000)
-    training_data = job_input.get('training_data', None)
     epochs = job_input.get('epochs', 3)
     learning_rate = job_input.get('learning_rate', 2e-5)
     batch_size = job_input.get('batch_size', 16)
     eval_split = job_input.get('eval_split', 0.1)
-    # Load data
-    if use_real_data and not training_data:
-        print("Loading curated DDI dataset...")
-        training_data = get_real_ddi_data(max_samples=max_samples)
-    elif not training_data:
-        print("No training data provided, using sample DDI dataset...")
-        training_data = get_real_ddi_data(max_samples=150)
-    # Create temp directory
     work_dir = tempfile.mkdtemp()
-    model_dir = os.path.join(work_dir, 'model')
     try:
-        print(f"Training samples: {len(training_data)}")
-        print(f"Model: {model_name}")
-        print(f"Epochs: {epochs}, Batch size: {batch_size}")
         print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
-        # Count label distribution
-        label_counts = {}
-        for item in training_data:
-            label_counts[item['label']] = label_counts.get(item['label'], 0) + 1
-        print(f"Label distribution: {label_counts}")
+        print(f"Model: {model_name}")
+
+        # Get data in BERT format
+        raw_data = get_ddi_training_data(max_samples=max_samples)
+        training_data = [{"text": d["instruction"], "label": d["label"]} for d in raw_data]
-        # Split into train/eval
-        if eval_split > 0 and len(training_data) > 100:
+        # Split
+        if eval_split > 0:
             train_data, eval_data = train_test_split(
-                training_data,
-                test_size=eval_split,
-                random_state=42,
+                training_data, test_size=eval_split, random_state=42,
                 stratify=[d['label'] for d in training_data]
             )
-            print(f"Train: {len(train_data)}, Eval: {len(eval_data)}")
         else:
-            train_data = training_data
-            eval_data = None
-        # Create datasets
+            train_data, eval_data = training_data, None
         train_dataset = Dataset.from_list(train_data)
         eval_dataset = Dataset.from_list(eval_data) if eval_data else None
-        # Load model and tokenizer
-        print(f"Loading model: {model_name}")
+        # Load model
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model = AutoModelForSequenceClassification.from_pretrained(
-            model_name,
-            num_labels=5  # DDI severity: none(0), minor(1), moderate(2), major(3), contraindicated(4)
+            model_name, num_labels=5
         )
-        # Tokenize datasets
-        def tokenize_function(examples):
-            return tokenizer(
-                examples['text'],
-                padding='max_length',
-                truncation=True,
-                max_length=256
-            )
-        tokenized_train = train_dataset.map(tokenize_function, batched=True)
-        tokenized_eval = eval_dataset.map(tokenize_function, batched=True) if eval_dataset else None
+        def tokenize(examples):
+            return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=256)
+        train_dataset = train_dataset.map(tokenize, batched=True)
+        if eval_dataset:
+            eval_dataset = eval_dataset.map(tokenize, batched=True)
-        # Training arguments
         training_args = TrainingArguments(
-            output_dir=model_dir,
+            output_dir=work_dir,
             num_train_epochs=epochs,
             learning_rate=learning_rate,
             per_device_train_batch_size=batch_size,
-            per_device_eval_batch_size=batch_size,
-            warmup_ratio=0.1,
-            weight_decay=0.01,
-            logging_steps=50,
-            eval_strategy='epoch' if tokenized_eval else 'no',
+            eval_strategy='epoch' if eval_dataset else 'no',
             save_strategy='no',
             fp16=torch.cuda.is_available(),
             report_to='none',
-            load_best_model_at_end=False,
         )
-        # Compute metrics function
         def compute_metrics(eval_pred):
             from sklearn.metrics import accuracy_score, f1_score
-            predictions, labels = eval_pred
-            predictions = predictions.argmax(-1)
+            preds = eval_pred.predictions.argmax(-1)
             return {
-                'accuracy': accuracy_score(labels, predictions),
-                'f1_macro': f1_score(labels, predictions, average='macro'),
-                'f1_weighted': f1_score(labels, predictions, average='weighted'),
+                'accuracy': accuracy_score(eval_pred.label_ids, preds),
+                'f1_weighted': f1_score(eval_pred.label_ids, preds, average='weighted'),
             }
-        # Initialize trainer
         trainer = Trainer(
             model=model,
             args=training_args,
-            train_dataset=tokenized_train,
-            eval_dataset=tokenized_eval,
-            compute_metrics=compute_metrics if tokenized_eval else None,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            compute_metrics=compute_metrics if eval_dataset else None,
         )
-        # Train
-        print("Starting training...")
         train_result = trainer.train()
-        # Get metrics
         metrics = {
             'train_loss': float(train_result.training_loss),
             'epochs': epochs,
             'model_name': model_name,
             'samples': len(train_data),
             'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
-            'data_source': 'curated_ddi'
         }
-        # Run evaluation if we have eval data
-        if tokenized_eval:
-            print("Running evaluation...")
+        if eval_dataset:
             eval_result = trainer.evaluate()
             metrics.update({
-                'eval_loss': float(eval_result['eval_loss']),
                 'eval_accuracy': float(eval_result['eval_accuracy']),
-                'eval_f1_macro': float(eval_result['eval_f1_macro']),
                 'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
             })
-        print(f"Training complete! Loss: {metrics['train_loss']:.4f}")
-        if 'eval_accuracy' in metrics:
-            print(f"Eval accuracy: {metrics['eval_accuracy']:.4f}, F1: {metrics['eval_f1_weighted']:.4f}")
-        return {
-            'status': 'success',
-            'metrics': metrics,
-            'message': 'Model trained successfully on curated DDI data'
-        }
+        return {'status': 'success', 'metrics': metrics, 'message': 'BERT classifier trained'}
     except Exception as e:
         import traceback
-        return {
-            'status': 'error',
-            'error': str(e),
-            'traceback': traceback.format_exc()
-        }
+        return {'status': 'error', 'error': str(e), 'traceback': traceback.format_exc()}
     finally:
-        # Cleanup
         shutil.rmtree(work_dir, ignore_errors=True)
 def handler(job):
     """RunPod serverless handler."""
     job_input = job.get('input', {})
-    return train_ddi_model(job_input)
+
+    # Choose training mode
+    model_name = job_input.get('model_name', 'google/gemma-3-12b-it')
+    use_lora = job_input.get('use_lora', True)
+
+    # Auto-detect model type (BERT vs LLM), but let an explicit use_lora win
+    if 'use_lora' not in job_input:
+        if 'gemma' in model_name.lower() or 'llama' in model_name.lower() or 'mistral' in model_name.lower():
+            use_lora = True
+        elif 'bert' in model_name.lower():
+            use_lora = False
+
+    if use_lora:
+        return train_gemma_lora(job_input)
+    else:
+        return train_bert_classifier(job_input)
# RunPod serverless entrypoint # RunPod serverless entrypoint


@@ -1,5 +1,5 @@
 runpod>=1.7.0
-transformers==4.44.0
+transformers>=4.48.0
 datasets>=2.16.0
 accelerate>=0.30.0
 boto3>=1.34.0
@@ -7,3 +7,6 @@ scikit-learn>=1.3.0
 scipy>=1.11.0
 safetensors>=0.4.0
 requests>=2.31.0
+peft>=0.14.0
+bitsandbytes>=0.45.0
+trl>=0.14.0