mirror of
https://github.com/ghndrx/kubeflow-pipelines.git
synced 2026-02-10 06:45:13 +00:00
feat: Add 176K real DrugBank DDI samples with drug names
- Downloaded 191K DDI pairs from TDC DrugBank - Fetched 1,634 drug names from PubChem API (96% hit rate) - Created complete training dataset with: - Real drug names (not just IDs) - 86 interaction type descriptions - Severity labels (minor/moderate/major/contraindicated) - Bundled 34MB data file in Docker image - Handler loads real data instead of curated samples
This commit is contained in:
@@ -8,11 +8,13 @@ COPY requirements.txt /app/requirements.txt
|
||||
# Install dependencies
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy handler
|
||||
# Copy handler and data
|
||||
COPY handler.py /app/handler.py
|
||||
COPY data/ /app/data/
|
||||
|
||||
# Set environment variables
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV HF_HOME=/tmp/huggingface
|
||||
ENV DDI_DATA_PATH=/app/data/drugbank_ddi_complete.jsonl
|
||||
|
||||
CMD ["python", "-u", "handler.py"]
|
||||
|
||||
176075
components/runpod_trainer/data/drugbank_ddi_complete.jsonl
Normal file
176075
components/runpod_trainer/data/drugbank_ddi_complete.jsonl
Normal file
File diff suppressed because it is too large
Load Diff
88
components/runpod_trainer/data/drugbank_label_map.json
Normal file
88
components/runpod_trainer/data/drugbank_label_map.json
Normal file
@@ -0,0 +1,88 @@
|
||||
{
|
||||
"1": "#Drug1 may increase the photosensitizing activities of #Drug2.",
|
||||
"2": "#Drug1 may increase the anticholinergic activities of #Drug2.",
|
||||
"3": "The bioavailability of #Drug2 can be decreased when combined with #Drug1.",
|
||||
"4": "The metabolism of #Drug2 can be increased when combined with #Drug1.",
|
||||
"5": "#Drug1 may decrease the vasoconstricting activities of #Drug2.",
|
||||
"6": "#Drug1 may increase the anticoagulant activities of #Drug2.",
|
||||
"7": "#Drug1 may increase the ototoxic activities of #Drug2.",
|
||||
"8": "The therapeutic efficacy of #Drug2 can be increased when used in combination with #Drug1.",
|
||||
"9": "#Drug1 may increase the hypoglycemic activities of #Drug2.",
|
||||
"10": "#Drug1 may increase the antihypertensive activities of #Drug2.",
|
||||
"11": "The serum concentration of the active metabolites of #Drug2 can be reduced when #Drug2 is used in combination with #Drug1 resulting in a loss in efficacy.",
|
||||
"12": "#Drug1 may decrease the anticoagulant activities of #Drug2.",
|
||||
"13": "The absorption of #Drug2 can be decreased when combined with #Drug1.",
|
||||
"14": "#Drug1 may decrease the bronchodilatory activities of #Drug2.",
|
||||
"15": "#Drug1 may increase the cardiotoxic activities of #Drug2.",
|
||||
"16": "#Drug1 may increase the central nervous system depressant (CNS depressant) activities of #Drug2.",
|
||||
"17": "#Drug1 may decrease the neuromuscular blocking activities of #Drug2.",
|
||||
"18": "#Drug1 can cause an increase in the absorption of #Drug2 resulting in an increased serum concentration and potentially a worsening of adverse effects.",
|
||||
"19": "#Drug1 may increase the vasoconstricting activities of #Drug2.",
|
||||
"20": "#Drug1 may increase the QTc-prolonging activities of #Drug2.",
|
||||
"21": "#Drug1 may increase the neuromuscular blocking activities of #Drug2.",
|
||||
"22": "#Drug1 may increase the adverse neuromuscular activities of #Drug2.",
|
||||
"23": "#Drug1 may increase the stimulatory activities of #Drug2.",
|
||||
"24": "#Drug1 may increase the hypocalcemic activities of #Drug2.",
|
||||
"25": "#Drug1 may increase the atrioventricular blocking (AV block) activities of #Drug2.",
|
||||
"26": "#Drug1 may decrease the antiplatelet activities of #Drug2.",
|
||||
"27": "#Drug1 may increase the neuroexcitatory activities of #Drug2.",
|
||||
"28": "#Drug1 may increase the dermatologic adverse activities of #Drug2.",
|
||||
"29": "#Drug1 may decrease the diuretic activities of #Drug2.",
|
||||
"30": "#Drug1 may increase the orthostatic hypotensive activities of #Drug2.",
|
||||
"31": "The risk or severity of hypertension can be increased when #Drug2 is combined with #Drug1.",
|
||||
"32": "#Drug1 may increase the sedative activities of #Drug2.",
|
||||
"33": "The risk or severity of QTc prolongation can be increased when #Drug1 is combined with #Drug2.",
|
||||
"34": "#Drug1 may increase the immunosuppressive activities of #Drug2.",
|
||||
"35": "#Drug1 may increase the neurotoxic activities of #Drug2.",
|
||||
"36": "#Drug1 may increase the antipsychotic activities of #Drug2.",
|
||||
"37": "#Drug1 may decrease the antihypertensive activities of #Drug2.",
|
||||
"38": "#Drug1 may increase the vasodilatory activities of #Drug2.",
|
||||
"39": "#Drug1 may increase the constipating activities of #Drug2.",
|
||||
"40": "#Drug1 may increase the respiratory depressant activities of #Drug2.",
|
||||
"41": "#Drug1 may increase the hypotensive and central nervous system depressant (CNS depressant) activities of #Drug2.",
|
||||
"42": "The risk or severity of hyperkalemia can be increased when #Drug1 is combined with #Drug2.",
|
||||
"43": "The protein binding of #Drug2 can be decreased when combined with #Drug1.",
|
||||
"44": "#Drug1 may increase the central neurotoxic activities of #Drug2.",
|
||||
"45": "#Drug1 may decrease effectiveness of #Drug2 as a diagnostic agent.",
|
||||
"46": "#Drug1 may increase the bronchoconstrictory activities of #Drug2.",
|
||||
"47": "The metabolism of #Drug2 can be decreased when combined with #Drug1.",
|
||||
"48": "#Drug1 may increase the myopathic rhabdomyolysis activities of #Drug2.",
|
||||
"49": "The risk or severity of adverse effects can be increased when #Drug1 is combined with #Drug2.",
|
||||
"50": "The risk or severity of heart failure can be increased when #Drug2 is combined with #Drug1.",
|
||||
"51": "#Drug1 may increase the hypercalcemic activities of #Drug2.",
|
||||
"52": "#Drug1 may decrease the analgesic activities of #Drug2.",
|
||||
"53": "#Drug1 may increase the antiplatelet activities of #Drug2.",
|
||||
"54": "#Drug1 may increase the bradycardic activities of #Drug2.",
|
||||
"55": "#Drug1 may increase the hyponatremic activities of #Drug2.",
|
||||
"56": "The risk or severity of hypotension can be increased when #Drug1 is combined with #Drug2.",
|
||||
"57": "#Drug1 may increase the nephrotoxic activities of #Drug2.",
|
||||
"58": "#Drug1 may decrease the cardiotoxic activities of #Drug2.",
|
||||
"59": "#Drug1 may increase the ulcerogenic activities of #Drug2.",
|
||||
"60": "#Drug1 may increase the hypotensive activities of #Drug2.",
|
||||
"61": "#Drug1 may decrease the stimulatory activities of #Drug2.",
|
||||
"62": "The bioavailability of #Drug2 can be increased when combined with #Drug1.",
|
||||
"63": "#Drug1 may increase the myelosuppressive activities of #Drug2.",
|
||||
"64": "#Drug1 may increase the serotonergic activities of #Drug2.",
|
||||
"65": "#Drug1 may increase the excretion rate of #Drug2 which could result in a lower serum level and potentially a reduction in efficacy.",
|
||||
"66": "The risk or severity of bleeding can be increased when #Drug1 is combined with #Drug2.",
|
||||
"67": "#Drug1 can cause a decrease in the absorption of #Drug2 resulting in a reduced serum concentration and potentially a decrease in efficacy.",
|
||||
"68": "#Drug1 may increase the hyperkalemic activities of #Drug2.",
|
||||
"69": "#Drug1 may increase the analgesic activities of #Drug2.",
|
||||
"70": "The therapeutic efficacy of #Drug2 can be decreased when used in combination with #Drug1.",
|
||||
"71": "#Drug1 may increase the hypertensive activities of #Drug2.",
|
||||
"72": "#Drug1 may decrease the excretion rate of #Drug2 which could result in a higher serum level.",
|
||||
"73": "The serum concentration of #Drug2 can be increased when it is combined with #Drug1.",
|
||||
"74": "#Drug1 may increase the fluid retaining activities of #Drug2.",
|
||||
"75": "The serum concentration of #Drug2 can be decreased when it is combined with #Drug1.",
|
||||
"76": "#Drug1 may decrease the sedative activities of #Drug2.",
|
||||
"77": "The serum concentration of the active metabolites of #Drug2 can be increased when #Drug2 is used in combination with #Drug1.",
|
||||
"78": "#Drug1 may increase the hyperglycemic activities of #Drug2.",
|
||||
"79": "#Drug1 may increase the central nervous system depressant (CNS depressant) and hypertensive activities of #Drug2.",
|
||||
"80": "#Drug1 may increase the hepatotoxic activities of #Drug2.",
|
||||
"81": "#Drug1 may increase the thrombogenic activities of #Drug2.",
|
||||
"82": "#Drug1 may increase the arrhythmogenic activities of #Drug2.",
|
||||
"83": "#Drug1 may increase the hypokalemic activities of #Drug2.",
|
||||
"84": "#Drug1 may increase the vasopressor activities of #Drug2.",
|
||||
"85": "#Drug1 may increase the tachycardic activities of #Drug2.",
|
||||
"86": "The risk of a hypersensitivity reaction to #Drug2 is increased when it is combined with #Drug1."
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
RunPod Serverless Handler for DDI Model Training
|
||||
|
||||
Supports both BERT-style classification and LLM fine-tuning with LoRA.
|
||||
Default: Gemma 3 12B with QLoRA for DDI severity classification.
|
||||
Uses 176K real DrugBank DDI samples with drug names.
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
@@ -11,146 +11,56 @@ from typing import Dict, Any, List, Optional
|
||||
|
||||
|
||||
# DDI severity labels
|
||||
DDI_LABELS = {
|
||||
0: "no_interaction",
|
||||
DDI_SEVERITY = {
|
||||
1: "minor",
|
||||
2: "moderate",
|
||||
3: "major",
|
||||
4: "contraindicated"
|
||||
}
|
||||
|
||||
LABEL_DESCRIPTIONS = {
|
||||
0: "No clinically significant interaction",
|
||||
1: "Minor interaction - minimal clinical significance",
|
||||
2: "Moderate interaction - may require monitoring or dose adjustment",
|
||||
3: "Major interaction - avoid combination if possible, high risk",
|
||||
4: "Contraindicated - do not use together, life-threatening risk"
|
||||
}
|
||||
|
||||
def load_drugbank_data(max_samples: int = None, severity_filter: List[int] = None) -> List[Dict]:
|
||||
"""Load real DrugBank DDI data from bundled file."""
|
||||
data_path = os.environ.get('DDI_DATA_PATH', '/app/data/drugbank_ddi_complete.jsonl')
|
||||
|
||||
if not os.path.exists(data_path):
|
||||
print(f"WARNING: Data file not found at {data_path}, using curated fallback")
|
||||
return get_curated_fallback()
|
||||
|
||||
data = []
|
||||
with open(data_path) as f:
|
||||
for line in f:
|
||||
item = json.loads(line)
|
||||
if severity_filter and item['severity'] not in severity_filter:
|
||||
continue
|
||||
data.append(item)
|
||||
if max_samples and len(data) >= max_samples:
|
||||
break
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def get_ddi_training_data(max_samples: int = 5000) -> List[Dict[str, Any]]:
|
||||
"""Generate DDI training data formatted for instruction tuning."""
|
||||
import random
|
||||
random.seed(42)
|
||||
|
||||
# Real drug interaction patterns based on clinical data
|
||||
ddi_patterns = [
|
||||
# Contraindicated (4)
|
||||
{"drugs": ["fluoxetine", "tramadol"], "type": "serotonin syndrome risk", "label": 4},
|
||||
{"drugs": ["fluoxetine", "phenelzine"], "type": "serotonin syndrome risk", "label": 4},
|
||||
{"drugs": ["simvastatin", "itraconazole"], "type": "rhabdomyolysis risk", "label": 4},
|
||||
{"drugs": ["methotrexate", "trimethoprim"], "type": "severe bone marrow suppression", "label": 4},
|
||||
{"drugs": ["warfarin", "miconazole"], "type": "severe bleeding risk", "label": 4},
|
||||
{"drugs": ["cisapride", "erythromycin"], "type": "QT prolongation cardiac arrest", "label": 4},
|
||||
{"drugs": ["pimozide", "clarithromycin"], "type": "QT prolongation risk", "label": 4},
|
||||
{"drugs": ["ergotamine", "ritonavir"], "type": "ergot toxicity risk", "label": 4},
|
||||
{"drugs": ["sildenafil", "nitroglycerin"], "type": "severe hypotension", "label": 4},
|
||||
{"drugs": ["linezolid", "sertraline"], "type": "serotonin syndrome", "label": 4},
|
||||
{"drugs": ["maoi", "meperidine"], "type": "hypertensive crisis", "label": 4},
|
||||
{"drugs": ["metronidazole", "disulfiram"], "type": "psychosis risk", "label": 4},
|
||||
|
||||
# Major (3)
|
||||
{"drugs": ["warfarin", "aspirin"], "type": "increased bleeding risk", "label": 3},
|
||||
{"drugs": ["digoxin", "amiodarone"], "type": "digoxin toxicity elevated", "label": 3},
|
||||
{"drugs": ["lithium", "ibuprofen"], "type": "lithium toxicity risk", "label": 3},
|
||||
{"drugs": ["metformin", "iodinated contrast"], "type": "lactic acidosis risk", "label": 3},
|
||||
{"drugs": ["potassium chloride", "lisinopril"], "type": "hyperkalemia risk", "label": 3},
|
||||
{"drugs": ["oxycodone", "alprazolam"], "type": "respiratory depression", "label": 3},
|
||||
{"drugs": ["theophylline", "ciprofloxacin"], "type": "theophylline toxicity", "label": 3},
|
||||
{"drugs": ["phenytoin", "fluconazole"], "type": "phenytoin toxicity", "label": 3},
|
||||
{"drugs": ["carbamazepine", "verapamil"], "type": "carbamazepine toxicity", "label": 3},
|
||||
{"drugs": ["cyclosporine", "ketoconazole"], "type": "nephrotoxicity risk", "label": 3},
|
||||
{"drugs": ["methotrexate", "ibuprofen"], "type": "methotrexate toxicity", "label": 3},
|
||||
{"drugs": ["quinidine", "digoxin"], "type": "digoxin toxicity", "label": 3},
|
||||
{"drugs": ["clopidogrel", "omeprazole"], "type": "reduced antiplatelet effect", "label": 3},
|
||||
{"drugs": ["warfarin", "rifampin"], "type": "reduced anticoagulation", "label": 3},
|
||||
{"drugs": ["dabigatran", "rifampin"], "type": "reduced anticoagulant effect", "label": 3},
|
||||
|
||||
# Moderate (2)
|
||||
{"drugs": ["simvastatin", "amlodipine"], "type": "increased statin exposure", "label": 2},
|
||||
{"drugs": ["metformin", "cimetidine"], "type": "increased metformin levels", "label": 2},
|
||||
{"drugs": ["levothyroxine", "calcium carbonate"], "type": "reduced thyroid absorption", "label": 2},
|
||||
{"drugs": ["gabapentin", "aluminum hydroxide"], "type": "reduced gabapentin absorption", "label": 2},
|
||||
{"drugs": ["furosemide", "gentamicin"], "type": "ototoxicity risk", "label": 2},
|
||||
{"drugs": ["prednisone", "naproxen"], "type": "GI bleeding risk", "label": 2},
|
||||
{"drugs": ["metoprolol", "verapamil"], "type": "bradycardia risk", "label": 2},
|
||||
{"drugs": ["sertraline", "tramadol"], "type": "seizure threshold lowered", "label": 2},
|
||||
{"drugs": ["losartan", "potassium supplements"], "type": "hyperkalemia risk", "label": 2},
|
||||
{"drugs": ["alprazolam", "ketoconazole"], "type": "increased sedation", "label": 2},
|
||||
{"drugs": ["atorvastatin", "grapefruit juice"], "type": "increased statin levels", "label": 2},
|
||||
{"drugs": ["ciprofloxacin", "ferrous sulfate"], "type": "reduced antibiotic absorption", "label": 2},
|
||||
{"drugs": ["warfarin", "acetaminophen"], "type": "slight INR increase", "label": 2},
|
||||
{"drugs": ["insulin", "propranolol"], "type": "masked hypoglycemia", "label": 2},
|
||||
{"drugs": ["digoxin", "spironolactone"], "type": "increased digoxin levels", "label": 2},
|
||||
|
||||
# Minor (1)
|
||||
{"drugs": ["aspirin", "ibuprofen"], "type": "reduced cardioprotection", "label": 1},
|
||||
{"drugs": ["metformin", "vitamin B12"], "type": "reduced B12 absorption long-term", "label": 1},
|
||||
{"drugs": ["amoxicillin", "ethinyl estradiol"], "type": "theoretical reduced efficacy", "label": 1},
|
||||
{"drugs": ["omeprazole", "vitamin B12"], "type": "reduced absorption", "label": 1},
|
||||
{"drugs": ["caffeine", "ciprofloxacin"], "type": "increased caffeine effect", "label": 1},
|
||||
{"drugs": ["calcium carbonate", "ferrous sulfate"], "type": "timing interaction", "label": 1},
|
||||
{"drugs": ["atorvastatin", "niacin"], "type": "monitoring recommended", "label": 1},
|
||||
{"drugs": ["lisinopril", "aspirin"], "type": "possible reduced effect", "label": 1},
|
||||
{"drugs": ["hydrochlorothiazide", "calcium"], "type": "hypercalcemia monitoring", "label": 1},
|
||||
{"drugs": ["metoprolol", "clonidine"], "type": "withdrawal monitoring", "label": 1},
|
||||
|
||||
# No interaction (0)
|
||||
{"drugs": ["amlodipine", "atorvastatin"], "type": "safe combination", "label": 0},
|
||||
{"drugs": ["metformin", "lisinopril"], "type": "complementary therapy", "label": 0},
|
||||
{"drugs": ["omeprazole", "levothyroxine"], "type": "can be used together with spacing", "label": 0},
|
||||
{"drugs": ["aspirin", "atorvastatin"], "type": "standard combination", "label": 0},
|
||||
{"drugs": ["metoprolol", "lisinopril"], "type": "common combination", "label": 0},
|
||||
{"drugs": ["gabapentin", "acetaminophen"], "type": "no interaction", "label": 0},
|
||||
{"drugs": ["sertraline", "omeprazole"], "type": "generally safe", "label": 0},
|
||||
{"drugs": ["metformin", "glipizide"], "type": "complementary", "label": 0},
|
||||
{"drugs": ["hydrochlorothiazide", "lisinopril"], "type": "synergistic", "label": 0},
|
||||
{"drugs": ["pantoprazole", "amlodipine"], "type": "no known interaction", "label": 0},
|
||||
def get_curated_fallback() -> List[Dict]:
|
||||
"""Fallback curated data if main file not available."""
|
||||
patterns = [
|
||||
{"drug1": "fluoxetine", "drug2": "tramadol", "interaction_text": "fluoxetine may increase the risk of serotonin syndrome when combined with tramadol", "severity": 4},
|
||||
{"drug1": "warfarin", "drug2": "aspirin", "interaction_text": "warfarin may increase the risk of bleeding when combined with aspirin", "severity": 3},
|
||||
{"drug1": "simvastatin", "drug2": "amlodipine", "interaction_text": "The serum concentration of simvastatin can be increased when combined with amlodipine", "severity": 2},
|
||||
{"drug1": "metformin", "drug2": "lisinopril", "interaction_text": "metformin and lisinopril have no significant interaction", "severity": 1},
|
||||
]
|
||||
|
||||
training_data = []
|
||||
|
||||
for pattern in ddi_patterns:
|
||||
drug1, drug2 = pattern["drugs"]
|
||||
interaction_type = pattern["type"]
|
||||
label = pattern["label"]
|
||||
label_name = DDI_LABELS[label]
|
||||
label_desc = LABEL_DESCRIPTIONS[label]
|
||||
|
||||
# Create instruction-tuning format
|
||||
prompts = [
|
||||
f"Analyze the drug-drug interaction between {drug1} and {drug2}.",
|
||||
f"What is the severity of combining {drug1} with {drug2}?",
|
||||
f"A patient is taking {drug1}. They need to start {drug2}. Assess the interaction risk.",
|
||||
f"Evaluate the interaction: {drug1} + {drug2}",
|
||||
f"Drug interaction check: {drug1} and {drug2}",
|
||||
]
|
||||
|
||||
for prompt in prompts:
|
||||
response = f"Severity: {label_name.upper()}\nInteraction: {interaction_type}\nRecommendation: {label_desc}"
|
||||
|
||||
training_data.append({
|
||||
"instruction": prompt,
|
||||
"response": response,
|
||||
"label": label,
|
||||
"label_name": label_name
|
||||
})
|
||||
|
||||
# Shuffle and replicate to reach target size
|
||||
random.shuffle(training_data)
|
||||
|
||||
while len(training_data) < max_samples:
|
||||
training_data.extend(training_data[:min(len(training_data), max_samples - len(training_data))])
|
||||
|
||||
return training_data[:max_samples]
|
||||
return patterns * 50 # 200 samples
|
||||
|
||||
|
||||
def format_for_gemma(example: Dict) -> str:
|
||||
"""Format example for Gemma instruction tuning."""
|
||||
def format_for_gemma(item: Dict) -> str:
|
||||
"""Format DDI item for Gemma instruction tuning."""
|
||||
severity_name = DDI_SEVERITY.get(item['severity'], 'unknown')
|
||||
|
||||
return f"""<start_of_turn>user
|
||||
{example['instruction']}<end_of_turn>
|
||||
Analyze the drug interaction between {item['drug1']} and {item['drug2']}.<end_of_turn>
|
||||
<start_of_turn>model
|
||||
{example['response']}<end_of_turn>"""
|
||||
Interaction: {item['interaction_text']}
|
||||
Severity: {severity_name.upper()}
|
||||
Recommendation: {"Avoid this combination" if item['severity'] >= 3 else "Monitor patient" if item['severity'] == 2 else "Generally safe"}<end_of_turn>"""
|
||||
|
||||
|
||||
def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@@ -170,32 +80,38 @@ def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
# Parameters
|
||||
model_name = job_input.get('model_name', 'google/gemma-3-12b-it')
|
||||
max_samples = job_input.get('max_samples', 2000)
|
||||
max_samples = job_input.get('max_samples', 10000)
|
||||
epochs = job_input.get('epochs', 1)
|
||||
learning_rate = job_input.get('learning_rate', 2e-4)
|
||||
batch_size = job_input.get('batch_size', 4)
|
||||
batch_size = job_input.get('batch_size', 2)
|
||||
lora_r = job_input.get('lora_r', 16)
|
||||
lora_alpha = job_input.get('lora_alpha', 32)
|
||||
max_seq_length = job_input.get('max_seq_length', 512)
|
||||
severity_filter = job_input.get('severity_filter', None) # e.g., [3, 4] for major/contraindicated only
|
||||
|
||||
work_dir = tempfile.mkdtemp()
|
||||
|
||||
try:
|
||||
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
|
||||
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
|
||||
print(f"Model: {model_name}")
|
||||
print(f"LoRA r={lora_r}, alpha={lora_alpha}")
|
||||
print(f"Samples: {max_samples}, Epochs: {epochs}")
|
||||
|
||||
# Load training data
|
||||
print("Loading DDI training data...")
|
||||
training_data = get_ddi_training_data(max_samples=max_samples)
|
||||
print(f"Loading DrugBank DDI data (max {max_samples})...")
|
||||
raw_data = load_drugbank_data(max_samples=max_samples, severity_filter=severity_filter)
|
||||
|
||||
# Format for Gemma
|
||||
formatted_data = [{"text": format_for_gemma(ex)} for ex in training_data]
|
||||
formatted_data = [{"text": format_for_gemma(item)} for item in raw_data]
|
||||
dataset = Dataset.from_list(formatted_data)
|
||||
|
||||
print(f"Dataset size: {len(dataset)}")
|
||||
|
||||
# Severity distribution
|
||||
from collections import Counter
|
||||
sev_dist = Counter(item['severity'] for item in raw_data)
|
||||
print(f"Severity distribution: {dict(sev_dist)}")
|
||||
|
||||
# QLoRA config - 4-bit quantization
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
@@ -232,18 +148,20 @@ def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
)
|
||||
|
||||
model = get_peft_model(model, lora_config)
|
||||
model.print_trainable_parameters()
|
||||
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
print(f"Trainable params: {trainable_params:,} / {total_params:,} ({100*trainable_params/total_params:.2f}%)")
|
||||
|
||||
# Training arguments
|
||||
training_args = TrainingArguments(
|
||||
output_dir=work_dir,
|
||||
num_train_epochs=epochs,
|
||||
per_device_train_batch_size=batch_size,
|
||||
gradient_accumulation_steps=4,
|
||||
gradient_accumulation_steps=8,
|
||||
learning_rate=learning_rate,
|
||||
weight_decay=0.01,
|
||||
warmup_ratio=0.1,
|
||||
logging_steps=10,
|
||||
logging_steps=25,
|
||||
save_strategy="no",
|
||||
bf16=True,
|
||||
gradient_checkpointing=True,
|
||||
@@ -271,12 +189,14 @@ def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
'train_loss': float(train_result.training_loss),
|
||||
'epochs': epochs,
|
||||
'model_name': model_name,
|
||||
'samples': len(training_data),
|
||||
'samples': len(raw_data),
|
||||
'lora_r': lora_r,
|
||||
'lora_alpha': lora_alpha,
|
||||
'trainable_params': trainable_params,
|
||||
'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
|
||||
'trainable_params': sum(p.numel() for p in model.parameters() if p.requires_grad),
|
||||
'quantization': '4-bit QLoRA',
|
||||
'vram_gb': torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0,
|
||||
'data_source': 'drugbank_176k',
|
||||
'severity_dist': dict(sev_dist),
|
||||
}
|
||||
|
||||
print(f"Training complete! Loss: {metrics['train_loss']:.4f}")
|
||||
@@ -284,7 +204,7 @@ def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {
|
||||
'status': 'success',
|
||||
'metrics': metrics,
|
||||
'message': f'Gemma 3 12B fine-tuned with QLoRA on DDI data'
|
||||
'message': f'Gemma 3 fine-tuned on {len(raw_data):,} real DrugBank DDI samples'
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -296,13 +216,12 @@ def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
}
|
||||
finally:
|
||||
shutil.rmtree(work_dir, ignore_errors=True)
|
||||
# Clear GPU memory
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Train BERT-style classifier (original approach)."""
|
||||
"""Train BERT-style classifier for DDI severity prediction."""
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
@@ -316,7 +235,7 @@ def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
import shutil
|
||||
|
||||
model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
|
||||
max_samples = job_input.get('max_samples', 5000)
|
||||
max_samples = job_input.get('max_samples', 50000)
|
||||
epochs = job_input.get('epochs', 3)
|
||||
learning_rate = job_input.get('learning_rate', 2e-5)
|
||||
batch_size = job_input.get('batch_size', 16)
|
||||
@@ -328,9 +247,17 @@ def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
|
||||
print(f"Model: {model_name}")
|
||||
|
||||
# Get data in BERT format
|
||||
raw_data = get_ddi_training_data(max_samples=max_samples)
|
||||
training_data = [{"text": d["instruction"], "label": d["label"]} for d in raw_data]
|
||||
# Load data
|
||||
raw_data = load_drugbank_data(max_samples=max_samples)
|
||||
|
||||
# Create text + label format
|
||||
# Shift severity to 0-indexed (1-4 -> 0-3)
|
||||
training_data = [{
|
||||
"text": f"{d['drug1']} and {d['drug2']}: {d['interaction_text']}",
|
||||
"label": d['severity'] - 1 # 0-indexed
|
||||
} for d in raw_data]
|
||||
|
||||
print(f"Loaded {len(training_data)} samples")
|
||||
|
||||
# Split
|
||||
if eval_split > 0:
|
||||
@@ -344,10 +271,10 @@ def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
train_dataset = Dataset.from_list(train_data)
|
||||
eval_dataset = Dataset.from_list(eval_data) if eval_data else None
|
||||
|
||||
# Load model
|
||||
# Load model (4 classes: minor, moderate, major, contraindicated)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
model_name, num_labels=5
|
||||
model_name, num_labels=4
|
||||
)
|
||||
|
||||
def tokenize(examples):
|
||||
@@ -392,6 +319,7 @@ def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
'model_name': model_name,
|
||||
'samples': len(train_data),
|
||||
'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
|
||||
'data_source': 'drugbank_176k',
|
||||
}
|
||||
|
||||
if eval_dataset:
|
||||
@@ -401,7 +329,7 @@ def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
|
||||
})
|
||||
|
||||
return {'status': 'success', 'metrics': metrics, 'message': 'BERT classifier trained'}
|
||||
return {'status': 'success', 'metrics': metrics, 'message': 'BERT classifier trained on DrugBank data'}
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
@@ -414,12 +342,11 @@ def handler(job):
|
||||
"""RunPod serverless handler."""
|
||||
job_input = job.get('input', {})
|
||||
|
||||
# Choose training mode
|
||||
model_name = job_input.get('model_name', 'google/gemma-3-12b-it')
|
||||
use_lora = job_input.get('use_lora', True)
|
||||
|
||||
# Auto-detect: use LoRA for large models
|
||||
if 'gemma' in model_name.lower() or 'llama' in model_name.lower() or 'mistral' in model_name.lower():
|
||||
if any(x in model_name.lower() for x in ['gemma', 'llama', 'mistral', 'qwen']):
|
||||
use_lora = True
|
||||
elif 'bert' in model_name.lower():
|
||||
use_lora = False
|
||||
|
||||
Reference in New Issue
Block a user