Mirror of https://github.com/ghndrx/kubeflow-pipelines.git, synced 2026-02-10 06:45:13 +00:00
feat: Add Gemma 3 12B with QLoRA fine-tuning
- Added PEFT, bitsandbytes, TRL for LoRA training
- 4-bit QLoRA quantization for 48GB GPU fit
- Instruction-tuning format for Gemma chat template
- Auto-detect model type (BERT vs LLM)
- Updated GPU tier to ADA_24/AMPERE_48
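For context, the new handler reads its hyperparameters straight from the RunPod job payload. Below is a minimal sketch of a job input exercising the QLoRA path; the values mirror the defaults read in train_gemma_lora, and the commented-out local call is an assumption for illustration (on RunPod the platform invokes handler(job) itself, and this page does not show the module's file name).

# Example RunPod job payload for the new QLoRA path (illustrative values;
# every key below is read via job_input.get(...) in train_gemma_lora).
job = {
    "input": {
        "model_name": "google/gemma-3-12b-it",
        "use_lora": True,          # also auto-enabled for gemma/llama/mistral model names
        "max_samples": 2000,
        "epochs": 1,
        "learning_rate": 2e-4,
        "batch_size": 4,
        "lora_r": 16,
        "lora_alpha": 32,
        "max_seq_length": 512,
    }
}

# Local smoke test (import path is an assumption):
# from handler import handler
# result = handler(job)
# print(result["status"], result.get("metrics", {}))
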
@@ -1,8 +1,8 @@
 """
 RunPod Serverless Handler for DDI Model Training
 
-This runs on RunPod GPU instances and trains the Bio_ClinicalBERT model
-for drug-drug interaction detection using real DDI data.
+Supports both BERT-style classification and LLM fine-tuning with LoRA.
+Default: Gemma 3 12B with QLoRA for DDI severity classification.
 """
 import os
 import json
@@ -12,86 +12,93 @@ from typing import Dict, Any, List, Optional
 
 # DDI severity labels
 DDI_LABELS = {
-    0: "none",            # No significant interaction
-    1: "minor",           # Minor interaction
-    2: "moderate",        # Moderate interaction
-    3: "major",           # Major interaction
-    4: "contraindicated"  # Contraindicated
+    0: "no_interaction",
+    1: "minor",
+    2: "moderate",
+    3: "major",
+    4: "contraindicated"
+}
+
+LABEL_DESCRIPTIONS = {
+    0: "No clinically significant interaction",
+    1: "Minor interaction - minimal clinical significance",
+    2: "Moderate interaction - may require monitoring or dose adjustment",
+    3: "Major interaction - avoid combination if possible, high risk",
+    4: "Contraindicated - do not use together, life-threatening risk"
 }
 
 
-def get_real_ddi_data(max_samples: int = 10000) -> List[Dict[str, Any]]:
-    """
-    Generate real DDI training data from DrugBank patterns.
-    Uses curated drug interaction patterns based on clinical guidelines.
-    """
+def get_ddi_training_data(max_samples: int = 5000) -> List[Dict[str, Any]]:
+    """Generate DDI training data formatted for instruction tuning."""
     import random
     random.seed(42)
 
-    # Real drug pairs with known interactions (based on clinical data)
+    # Real drug interaction patterns based on clinical data
     ddi_patterns = [
         # Contraindicated (4)
         {"drugs": ["fluoxetine", "tramadol"], "type": "serotonin syndrome risk", "label": 4},
-        {"drugs": ["fluoxetine", "monoamine oxidase inhibitor"], "type": "serotonin syndrome risk", "label": 4},
+        {"drugs": ["fluoxetine", "phenelzine"], "type": "serotonin syndrome risk", "label": 4},
         {"drugs": ["simvastatin", "itraconazole"], "type": "rhabdomyolysis risk", "label": 4},
         {"drugs": ["methotrexate", "trimethoprim"], "type": "severe bone marrow suppression", "label": 4},
         {"drugs": ["warfarin", "miconazole"], "type": "severe bleeding risk", "label": 4},
         {"drugs": ["cisapride", "erythromycin"], "type": "QT prolongation cardiac arrest", "label": 4},
         {"drugs": ["pimozide", "clarithromycin"], "type": "QT prolongation risk", "label": 4},
         {"drugs": ["ergotamine", "ritonavir"], "type": "ergot toxicity risk", "label": 4},
-        {"drugs": ["sildenafil", "nitrates"], "type": "severe hypotension", "label": 4},
-        {"drugs": ["linezolid", "serotonergic agents"], "type": "serotonin syndrome", "label": 4},
+        {"drugs": ["sildenafil", "nitroglycerin"], "type": "severe hypotension", "label": 4},
+        {"drugs": ["linezolid", "sertraline"], "type": "serotonin syndrome", "label": 4},
+        {"drugs": ["maoi", "meperidine"], "type": "hypertensive crisis", "label": 4},
+        {"drugs": ["metronidazole", "disulfiram"], "type": "psychosis risk", "label": 4},
 
         # Major (3)
         {"drugs": ["warfarin", "aspirin"], "type": "increased bleeding risk", "label": 3},
         {"drugs": ["digoxin", "amiodarone"], "type": "digoxin toxicity elevated", "label": 3},
         {"drugs": ["lithium", "ibuprofen"], "type": "lithium toxicity risk", "label": 3},
-        {"drugs": ["metformin", "contrast media"], "type": "lactic acidosis risk", "label": 3},
-        {"drugs": ["potassium", "ACE inhibitor"], "type": "hyperkalemia risk", "label": 3},
-        {"drugs": ["opioid", "benzodiazepine"], "type": "respiratory depression", "label": 3},
+        {"drugs": ["metformin", "iodinated contrast"], "type": "lactic acidosis risk", "label": 3},
+        {"drugs": ["potassium chloride", "lisinopril"], "type": "hyperkalemia risk", "label": 3},
+        {"drugs": ["oxycodone", "alprazolam"], "type": "respiratory depression", "label": 3},
         {"drugs": ["theophylline", "ciprofloxacin"], "type": "theophylline toxicity", "label": 3},
         {"drugs": ["phenytoin", "fluconazole"], "type": "phenytoin toxicity", "label": 3},
         {"drugs": ["carbamazepine", "verapamil"], "type": "carbamazepine toxicity", "label": 3},
         {"drugs": ["cyclosporine", "ketoconazole"], "type": "nephrotoxicity risk", "label": 3},
-        {"drugs": ["methotrexate", "NSAIDs"], "type": "methotrexate toxicity", "label": 3},
+        {"drugs": ["methotrexate", "ibuprofen"], "type": "methotrexate toxicity", "label": 3},
         {"drugs": ["quinidine", "digoxin"], "type": "digoxin toxicity", "label": 3},
         {"drugs": ["clopidogrel", "omeprazole"], "type": "reduced antiplatelet effect", "label": 3},
-        {"drugs": ["warfarin", "vitamin K"], "type": "reduced anticoagulation", "label": 3},
+        {"drugs": ["warfarin", "rifampin"], "type": "reduced anticoagulation", "label": 3},
         {"drugs": ["dabigatran", "rifampin"], "type": "reduced anticoagulant effect", "label": 3},
 
         # Moderate (2)
         {"drugs": ["simvastatin", "amlodipine"], "type": "increased statin exposure", "label": 2},
         {"drugs": ["metformin", "cimetidine"], "type": "increased metformin levels", "label": 2},
-        {"drugs": ["levothyroxine", "calcium"], "type": "reduced thyroid absorption", "label": 2},
-        {"drugs": ["gabapentin", "antacids"], "type": "reduced gabapentin absorption", "label": 2},
+        {"drugs": ["levothyroxine", "calcium carbonate"], "type": "reduced thyroid absorption", "label": 2},
+        {"drugs": ["gabapentin", "aluminum hydroxide"], "type": "reduced gabapentin absorption", "label": 2},
         {"drugs": ["furosemide", "gentamicin"], "type": "ototoxicity risk", "label": 2},
-        {"drugs": ["prednisone", "NSAIDs"], "type": "GI bleeding risk", "label": 2},
+        {"drugs": ["prednisone", "naproxen"], "type": "GI bleeding risk", "label": 2},
         {"drugs": ["metoprolol", "verapamil"], "type": "bradycardia risk", "label": 2},
         {"drugs": ["sertraline", "tramadol"], "type": "seizure threshold lowered", "label": 2},
         {"drugs": ["losartan", "potassium supplements"], "type": "hyperkalemia risk", "label": 2},
         {"drugs": ["alprazolam", "ketoconazole"], "type": "increased sedation", "label": 2},
-        {"drugs": ["atorvastatin", "grapefruit"], "type": "increased statin levels", "label": 2},
-        {"drugs": ["ciprofloxacin", "iron"], "type": "reduced antibiotic absorption", "label": 2},
+        {"drugs": ["atorvastatin", "grapefruit juice"], "type": "increased statin levels", "label": 2},
+        {"drugs": ["ciprofloxacin", "ferrous sulfate"], "type": "reduced antibiotic absorption", "label": 2},
         {"drugs": ["warfarin", "acetaminophen"], "type": "slight INR increase", "label": 2},
-        {"drugs": ["insulin", "beta blocker"], "type": "masked hypoglycemia", "label": 2},
+        {"drugs": ["insulin", "propranolol"], "type": "masked hypoglycemia", "label": 2},
         {"drugs": ["digoxin", "spironolactone"], "type": "increased digoxin levels", "label": 2},
 
         # Minor (1)
         {"drugs": ["aspirin", "ibuprofen"], "type": "reduced cardioprotection", "label": 1},
         {"drugs": ["metformin", "vitamin B12"], "type": "reduced B12 absorption long-term", "label": 1},
-        {"drugs": ["amoxicillin", "oral contraceptives"], "type": "theoretical reduced efficacy", "label": 1},
-        {"drugs": ["proton pump inhibitor", "vitamin B12"], "type": "reduced absorption", "label": 1},
-        {"drugs": ["caffeine", "fluoroquinolones"], "type": "increased caffeine effect", "label": 1},
-        {"drugs": ["antacids", "iron"], "type": "timing interaction", "label": 1},
-        {"drugs": ["statin", "niacin"], "type": "monitoring recommended", "label": 1},
-        {"drugs": ["ACE inhibitor", "aspirin"], "type": "possible reduced effect", "label": 1},
-        {"drugs": ["thiazide", "calcium"], "type": "hypercalcemia monitoring", "label": 1},
-        {"drugs": ["beta blocker", "clonidine"], "type": "withdrawal monitoring", "label": 1},
+        {"drugs": ["amoxicillin", "ethinyl estradiol"], "type": "theoretical reduced efficacy", "label": 1},
+        {"drugs": ["omeprazole", "vitamin B12"], "type": "reduced absorption", "label": 1},
+        {"drugs": ["caffeine", "ciprofloxacin"], "type": "increased caffeine effect", "label": 1},
+        {"drugs": ["calcium carbonate", "ferrous sulfate"], "type": "timing interaction", "label": 1},
+        {"drugs": ["atorvastatin", "niacin"], "type": "monitoring recommended", "label": 1},
+        {"drugs": ["lisinopril", "aspirin"], "type": "possible reduced effect", "label": 1},
+        {"drugs": ["hydrochlorothiazide", "calcium"], "type": "hypercalcemia monitoring", "label": 1},
+        {"drugs": ["metoprolol", "clonidine"], "type": "withdrawal monitoring", "label": 1},
 
         # No interaction (0)
         {"drugs": ["amlodipine", "atorvastatin"], "type": "safe combination", "label": 0},
         {"drugs": ["metformin", "lisinopril"], "type": "complementary therapy", "label": 0},
-        {"drugs": ["omeprazole", "levothyroxine"], "type": "can be used together", "label": 0},
+        {"drugs": ["omeprazole", "levothyroxine"], "type": "can be used together with spacing", "label": 0},
         {"drugs": ["aspirin", "atorvastatin"], "type": "standard combination", "label": 0},
         {"drugs": ["metoprolol", "lisinopril"], "type": "common combination", "label": 0},
         {"drugs": ["gabapentin", "acetaminophen"], "type": "no interaction", "label": 0},
@@ -101,52 +108,201 @@ def get_real_ddi_data(max_samples: int = 10000) -> List[Dict[str, Any]]:
         {"drugs": ["pantoprazole", "amlodipine"], "type": "no known interaction", "label": 0},
     ]
 
-    # Expand with variations
     training_data = []
 
     for pattern in ddi_patterns:
         drug1, drug2 = pattern["drugs"]
         interaction_type = pattern["type"]
         label = pattern["label"]
+        label_name = DDI_LABELS[label]
+        label_desc = LABEL_DESCRIPTIONS[label]
 
-        # Create multiple text variations
-        variations = [
-            f"{drug1} and {drug2} interaction: {interaction_type}",
-            f"{drug2} combined with {drug1} causes {interaction_type}",
-            f"Patient taking {drug1} with {drug2}: {interaction_type}",
-            f"Concomitant use of {drug1} and {drug2} leads to {interaction_type}",
-            f"{drug1} {drug2} drug-drug interaction {interaction_type}",
+        # Create instruction-tuning format
+        prompts = [
+            f"Analyze the drug-drug interaction between {drug1} and {drug2}.",
+            f"What is the severity of combining {drug1} with {drug2}?",
+            f"A patient is taking {drug1}. They need to start {drug2}. Assess the interaction risk.",
+            f"Evaluate the interaction: {drug1} + {drug2}",
+            f"Drug interaction check: {drug1} and {drug2}",
         ]
 
-        for text in variations:
-            training_data.append({"text": text, "label": label})
+        for prompt in prompts:
+            response = f"Severity: {label_name.upper()}\nInteraction: {interaction_type}\nRecommendation: {label_desc}"
+
+            training_data.append({
+                "instruction": prompt,
+                "response": response,
+                "label": label,
+                "label_name": label_name
+            })
 
-    # Shuffle and limit
+    # Shuffle and replicate to reach target size
     random.shuffle(training_data)
 
-    # Replicate to reach target size
     while len(training_data) < max_samples:
         training_data.extend(training_data[:min(len(training_data), max_samples - len(training_data))])
 
     return training_data[:max_samples]
 
 
-def train_ddi_model(job_input: Dict[str, Any]) -> Dict[str, Any]:
-    """
-    Train DDI detection model.
-
-    Expected input:
-    {
-        "model_name": "emilyalsentzer/Bio_ClinicalBERT",
-        "use_real_data": true,
-        "max_samples": 5000,
-        "training_data": [...],  # Or provide inline data
-        "epochs": 3,
-        "learning_rate": 2e-5,
-        "batch_size": 16,
-        "eval_split": 0.1
-    }
-    """
+def format_for_gemma(example: Dict) -> str:
+    """Format example for Gemma instruction tuning."""
+    return f"""<start_of_turn>user
+{example['instruction']}<end_of_turn>
+<start_of_turn>model
+{example['response']}<end_of_turn>"""
+
+
+def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
+    """Train Gemma 3 with QLoRA for DDI classification."""
+    import torch
+    from transformers import (
+        AutoTokenizer,
+        AutoModelForCausalLM,
+        BitsAndBytesConfig,
+        TrainingArguments,
+    )
+    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+    from trl import SFTTrainer
+    from datasets import Dataset
+    import tempfile
+    import shutil
+
+    # Parameters
+    model_name = job_input.get('model_name', 'google/gemma-3-12b-it')
+    max_samples = job_input.get('max_samples', 2000)
+    epochs = job_input.get('epochs', 1)
+    learning_rate = job_input.get('learning_rate', 2e-4)
+    batch_size = job_input.get('batch_size', 4)
+    lora_r = job_input.get('lora_r', 16)
+    lora_alpha = job_input.get('lora_alpha', 32)
+    max_seq_length = job_input.get('max_seq_length', 512)
+
+    work_dir = tempfile.mkdtemp()
+
+    try:
+        print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
+        print(f"Model: {model_name}")
+        print(f"LoRA r={lora_r}, alpha={lora_alpha}")
+        print(f"Samples: {max_samples}, Epochs: {epochs}")
+
+        # Load training data
+        print("Loading DDI training data...")
+        training_data = get_ddi_training_data(max_samples=max_samples)
+
+        # Format for Gemma
+        formatted_data = [{"text": format_for_gemma(ex)} for ex in training_data]
+        dataset = Dataset.from_list(formatted_data)
+
+        print(f"Dataset size: {len(dataset)}")
+
+        # QLoRA config - 4-bit quantization
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.bfloat16,
+            bnb_4bit_use_double_quant=True,
+        )
+
+        # Load model
+        print(f"Loading {model_name} with 4-bit quantization...")
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            quantization_config=bnb_config,
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=True,
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "right"
+
+        # Prepare for k-bit training
+        model = prepare_model_for_kbit_training(model)
+
+        # LoRA config
+        lora_config = LoraConfig(
+            r=lora_r,
+            lora_alpha=lora_alpha,
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+            lora_dropout=0.05,
+            bias="none",
+            task_type="CAUSAL_LM",
+        )
+
+        model = get_peft_model(model, lora_config)
+        model.print_trainable_parameters()
+
+        # Training arguments
+        training_args = TrainingArguments(
+            output_dir=work_dir,
+            num_train_epochs=epochs,
+            per_device_train_batch_size=batch_size,
+            gradient_accumulation_steps=4,
+            learning_rate=learning_rate,
+            weight_decay=0.01,
+            warmup_ratio=0.1,
+            logging_steps=10,
+            save_strategy="no",
+            bf16=True,
+            gradient_checkpointing=True,
+            optim="paged_adamw_8bit",
+            report_to="none",
+            max_grad_norm=0.3,
+        )
+
+        # SFT Trainer
+        trainer = SFTTrainer(
+            model=model,
+            train_dataset=dataset,
+            args=training_args,
+            peft_config=lora_config,
+            processing_class=tokenizer,
+            max_seq_length=max_seq_length,
+        )
+
+        # Train
+        print("Starting LoRA fine-tuning...")
+        train_result = trainer.train()
+
+        # Metrics
+        metrics = {
+            'train_loss': float(train_result.training_loss),
+            'epochs': epochs,
+            'model_name': model_name,
+            'samples': len(training_data),
+            'lora_r': lora_r,
+            'lora_alpha': lora_alpha,
+            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
+            'trainable_params': sum(p.numel() for p in model.parameters() if p.requires_grad),
+            'quantization': '4-bit QLoRA',
+        }
+
+        print(f"Training complete! Loss: {metrics['train_loss']:.4f}")
+
+        return {
+            'status': 'success',
+            'metrics': metrics,
+            'message': f'Gemma 3 12B fine-tuned with QLoRA on DDI data'
+        }
+
+    except Exception as e:
+        import traceback
+        return {
+            'status': 'error',
+            'error': str(e),
+            'traceback': traceback.format_exc()
+        }
+    finally:
+        shutil.rmtree(work_dir, ignore_errors=True)
+        # Clear GPU memory
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+
+def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
+    """Train BERT-style classifier (original approach)."""
     import torch
     from transformers import (
         AutoTokenizer,
@@ -159,165 +315,119 @@ def train_ddi_model(job_input: Dict[str, Any]) -> Dict[str, Any]:
     import tempfile
     import shutil
 
-    # Extract parameters
     model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
-    use_real_data = job_input.get('use_real_data', True)
     max_samples = job_input.get('max_samples', 5000)
-    training_data = job_input.get('training_data', None)
     epochs = job_input.get('epochs', 3)
     learning_rate = job_input.get('learning_rate', 2e-5)
     batch_size = job_input.get('batch_size', 16)
     eval_split = job_input.get('eval_split', 0.1)
 
-    # Load data
-    if use_real_data and not training_data:
-        print("Loading curated DDI dataset...")
-        training_data = get_real_ddi_data(max_samples=max_samples)
-    elif not training_data:
-        print("No training data provided, using sample DDI dataset...")
-        training_data = get_real_ddi_data(max_samples=150)
 
-    # Create temp directory
     work_dir = tempfile.mkdtemp()
-    model_dir = os.path.join(work_dir, 'model')
 
     try:
-        print(f"Training samples: {len(training_data)}")
-        print(f"Model: {model_name}")
-        print(f"Epochs: {epochs}, Batch size: {batch_size}")
         print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
+        print(f"Model: {model_name}")
 
-        # Count label distribution
-        label_counts = {}
-        for item in training_data:
-            label_counts[item['label']] = label_counts.get(item['label'], 0) + 1
-        print(f"Label distribution: {label_counts}")
+        # Get data in BERT format
+        raw_data = get_ddi_training_data(max_samples=max_samples)
+        training_data = [{"text": d["instruction"], "label": d["label"]} for d in raw_data]
 
-        # Split into train/eval
-        if eval_split > 0 and len(training_data) > 100:
+        # Split
+        if eval_split > 0:
             train_data, eval_data = train_test_split(
-                training_data,
-                test_size=eval_split,
-                random_state=42,
+                training_data, test_size=eval_split, random_state=42,
                 stratify=[d['label'] for d in training_data]
             )
-            print(f"Train: {len(train_data)}, Eval: {len(eval_data)}")
         else:
-            train_data = training_data
-            eval_data = None
+            train_data, eval_data = training_data, None
 
-        # Create datasets
         train_dataset = Dataset.from_list(train_data)
         eval_dataset = Dataset.from_list(eval_data) if eval_data else None
 
-        # Load model and tokenizer
-        print(f"Loading model: {model_name}")
+        # Load model
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model = AutoModelForSequenceClassification.from_pretrained(
-            model_name,
-            num_labels=5  # DDI severity: none(0), minor(1), moderate(2), major(3), contraindicated(4)
+            model_name, num_labels=5
         )
 
-        # Tokenize datasets
-        def tokenize_function(examples):
-            return tokenizer(
-                examples['text'],
-                padding='max_length',
-                truncation=True,
-                max_length=256
-            )
+        def tokenize(examples):
+            return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=256)
 
-        tokenized_train = train_dataset.map(tokenize_function, batched=True)
-        tokenized_eval = eval_dataset.map(tokenize_function, batched=True) if eval_dataset else None
+        train_dataset = train_dataset.map(tokenize, batched=True)
+        if eval_dataset:
+            eval_dataset = eval_dataset.map(tokenize, batched=True)
 
-        # Training arguments
         training_args = TrainingArguments(
-            output_dir=model_dir,
+            output_dir=work_dir,
             num_train_epochs=epochs,
             learning_rate=learning_rate,
             per_device_train_batch_size=batch_size,
-            per_device_eval_batch_size=batch_size,
-            warmup_ratio=0.1,
-            weight_decay=0.01,
-            logging_steps=50,
-            eval_strategy='epoch' if tokenized_eval else 'no',
+            eval_strategy='epoch' if eval_dataset else 'no',
             save_strategy='no',
             fp16=torch.cuda.is_available(),
             report_to='none',
-            load_best_model_at_end=False,
         )
 
-        # Compute metrics function
         def compute_metrics(eval_pred):
             from sklearn.metrics import accuracy_score, f1_score
-            predictions, labels = eval_pred
-            predictions = predictions.argmax(-1)
+            preds = eval_pred.predictions.argmax(-1)
             return {
-                'accuracy': accuracy_score(labels, predictions),
-                'f1_macro': f1_score(labels, predictions, average='macro'),
-                'f1_weighted': f1_score(labels, predictions, average='weighted'),
+                'accuracy': accuracy_score(eval_pred.label_ids, preds),
+                'f1_weighted': f1_score(eval_pred.label_ids, preds, average='weighted'),
             }
 
-        # Initialize trainer
         trainer = Trainer(
             model=model,
             args=training_args,
-            train_dataset=tokenized_train,
-            eval_dataset=tokenized_eval,
-            compute_metrics=compute_metrics if tokenized_eval else None,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            compute_metrics=compute_metrics if eval_dataset else None,
         )
 
-        # Train
-        print("Starting training...")
         train_result = trainer.train()
 
-        # Get metrics
         metrics = {
             'train_loss': float(train_result.training_loss),
             'epochs': epochs,
             'model_name': model_name,
             'samples': len(train_data),
             'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
-            'data_source': 'curated_ddi'
         }
 
-        # Run evaluation if we have eval data
-        if tokenized_eval:
-            print("Running evaluation...")
+        if eval_dataset:
            eval_result = trainer.evaluate()
            metrics.update({
-                'eval_loss': float(eval_result['eval_loss']),
                 'eval_accuracy': float(eval_result['eval_accuracy']),
-                'eval_f1_macro': float(eval_result['eval_f1_macro']),
                 'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
             })
 
-        print(f"Training complete! Loss: {metrics['train_loss']:.4f}")
-        if 'eval_accuracy' in metrics:
-            print(f"Eval accuracy: {metrics['eval_accuracy']:.4f}, F1: {metrics['eval_f1_weighted']:.4f}")
-
-        return {
-            'status': 'success',
-            'metrics': metrics,
-            'message': 'Model trained successfully on curated DDI data'
-        }
+        return {'status': 'success', 'metrics': metrics, 'message': 'BERT classifier trained'}
 
     except Exception as e:
         import traceback
-        return {
-            'status': 'error',
-            'error': str(e),
-            'traceback': traceback.format_exc()
-        }
+        return {'status': 'error', 'error': str(e), 'traceback': traceback.format_exc()}
     finally:
-        # Cleanup
         shutil.rmtree(work_dir, ignore_errors=True)
 
 
 def handler(job):
     """RunPod serverless handler."""
     job_input = job.get('input', {})
-    return train_ddi_model(job_input)
+
+    # Choose training mode
+    model_name = job_input.get('model_name', 'google/gemma-3-12b-it')
+    use_lora = job_input.get('use_lora', True)
+
+    # Auto-detect: use LoRA for large models
+    if 'gemma' in model_name.lower() or 'llama' in model_name.lower() or 'mistral' in model_name.lower():
+        use_lora = True
+    elif 'bert' in model_name.lower():
+        use_lora = False
+
+    if use_lora:
+        return train_gemma_lora(job_input)
+    else:
+        return train_bert_classifier(job_input)
 
 
 # RunPod serverless entrypoint
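To make the new instruction-tuning format concrete: each generated record is wrapped in Gemma's chat-template turn markers before supervised fine-tuning. A self-contained sketch of what one record renders to follows; the record itself is a made-up example, but it mirrors the warfarin/aspirin pattern, the Severity/Interaction/Recommendation response layout, and the format_for_gemma template from the diff above.

# One made-up record in the shape produced by get_ddi_training_data()
example = {
    "instruction": "What is the severity of combining warfarin with aspirin?",
    "response": "Severity: MAJOR\n"
                "Interaction: increased bleeding risk\n"
                "Recommendation: Major interaction - avoid combination if possible, high risk",
}

# Same turn markers as format_for_gemma() in the diff above
text = f"""<start_of_turn>user
{example['instruction']}<end_of_turn>
<start_of_turn>model
{example['response']}<end_of_turn>"""

print(text)
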
@@ -1,5 +1,5 @@
 runpod>=1.7.0
-transformers==4.44.0
+transformers>=4.48.0
 datasets>=2.16.0
 accelerate>=0.30.0
 boto3>=1.34.0
@@ -7,3 +7,6 @@ scikit-learn>=1.3.0
 scipy>=1.11.0
 safetensors>=0.4.0
 requests>=2.31.0
+peft>=0.14.0
+bitsandbytes>=0.45.0
+trl>=0.14.0
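The added peft, bitsandbytes, and trl packages back the 4-bit QLoRA setup. A rough, illustrative back-of-envelope for why a 12B model in NF4 fits the ADA_24/AMPERE_48 tiers named in the commit message; the parameter and adapter counts below are assumptions, not measurements.

# Back-of-envelope only; counts are assumed, not measured.
params = 12e9                        # ~12B base parameters
lora_params = 40e6                   # rough size of r=16 adapters over q/k/v/o/gate/up/down projections

nf4_weights_gb = params * 0.5 / 1e9  # NF4 stores ~0.5 bytes per weight -> ~6 GB
adapters_gb = lora_params * 2 / 1e9  # bf16 LoRA weights
grads_gb = lora_params * 2 / 1e9     # bf16 gradients, adapters only
adam_gb = lora_params * 2 / 1e9      # two 8-bit optimizer states (paged_adamw_8bit)

print(f"weights ~{nf4_weights_gb:.1f} GB, trainable state ~{adapters_gb + grads_gb + adam_gb:.2f} GB")
# The remaining budget goes to activations and work buffers, kept small here by
# gradient checkpointing and max_seq_length=512.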