mirror of
https://github.com/ghndrx/kubeflow-pipelines.git
synced 2026-02-10 06:45:13 +00:00
feat: Add 176K real DrugBank DDI samples with drug names
- Downloaded 191K DDI pairs from TDC DrugBank
- Fetched 1,634 drug names from PubChem API (96% hit rate)
- Created complete training dataset with:
  - Real drug names (not just IDs)
  - 86 interaction type descriptions
  - Severity labels (minor/moderate/major/contraindicated)
- Bundled 34MB data file in Docker image
- Handler loads real data instead of curated samples
This commit is contained in:
@@ -8,11 +8,13 @@ COPY requirements.txt /app/requirements.txt
|
|||||||
# Install dependencies
|
# Install dependencies
|
||||||
RUN pip install --no-cache-dir -r requirements.txt
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
# Copy handler
|
# Copy handler and data
|
||||||
COPY handler.py /app/handler.py
|
COPY handler.py /app/handler.py
|
||||||
|
COPY data/ /app/data/
|
||||||
|
|
||||||
# Set environment variables
|
# Set environment variables
|
||||||
ENV PYTHONUNBUFFERED=1
|
ENV PYTHONUNBUFFERED=1
|
||||||
ENV HF_HOME=/tmp/huggingface
|
ENV HF_HOME=/tmp/huggingface
|
||||||
|
ENV DDI_DATA_PATH=/app/data/drugbank_ddi_complete.jsonl
|
||||||
|
|
||||||
CMD ["python", "-u", "handler.py"]
|
CMD ["python", "-u", "handler.py"]
|
||||||
|
|||||||
176075
components/runpod_trainer/data/drugbank_ddi_complete.jsonl
Normal file
176075
components/runpod_trainer/data/drugbank_ddi_complete.jsonl
Normal file
File diff suppressed because it is too large
Load Diff
88
components/runpod_trainer/data/drugbank_label_map.json
Normal file
88
components/runpod_trainer/data/drugbank_label_map.json
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
{
|
||||||
|
"1": "#Drug1 may increase the photosensitizing activities of #Drug2.",
|
||||||
|
"2": "#Drug1 may increase the anticholinergic activities of #Drug2.",
|
||||||
|
"3": "The bioavailability of #Drug2 can be decreased when combined with #Drug1.",
|
||||||
|
"4": "The metabolism of #Drug2 can be increased when combined with #Drug1.",
|
||||||
|
"5": "#Drug1 may decrease the vasoconstricting activities of #Drug2.",
|
||||||
|
"6": "#Drug1 may increase the anticoagulant activities of #Drug2.",
|
||||||
|
"7": "#Drug1 may increase the ototoxic activities of #Drug2.",
|
||||||
|
"8": "The therapeutic efficacy of #Drug2 can be increased when used in combination with #Drug1.",
|
||||||
|
"9": "#Drug1 may increase the hypoglycemic activities of #Drug2.",
|
||||||
|
"10": "#Drug1 may increase the antihypertensive activities of #Drug2.",
|
||||||
|
"11": "The serum concentration of the active metabolites of #Drug2 can be reduced when #Drug2 is used in combination with #Drug1 resulting in a loss in efficacy.",
|
||||||
|
"12": "#Drug1 may decrease the anticoagulant activities of #Drug2.",
|
||||||
|
"13": "The absorption of #Drug2 can be decreased when combined with #Drug1.",
|
||||||
|
"14": "#Drug1 may decrease the bronchodilatory activities of #Drug2.",
|
||||||
|
"15": "#Drug1 may increase the cardiotoxic activities of #Drug2.",
|
||||||
|
"16": "#Drug1 may increase the central nervous system depressant (CNS depressant) activities of #Drug2.",
|
||||||
|
"17": "#Drug1 may decrease the neuromuscular blocking activities of #Drug2.",
|
||||||
|
"18": "#Drug1 can cause an increase in the absorption of #Drug2 resulting in an increased serum concentration and potentially a worsening of adverse effects.",
|
||||||
|
"19": "#Drug1 may increase the vasoconstricting activities of #Drug2.",
|
||||||
|
"20": "#Drug1 may increase the QTc-prolonging activities of #Drug2.",
|
||||||
|
"21": "#Drug1 may increase the neuromuscular blocking activities of #Drug2.",
|
||||||
|
"22": "#Drug1 may increase the adverse neuromuscular activities of #Drug2.",
|
||||||
|
"23": "#Drug1 may increase the stimulatory activities of #Drug2.",
|
||||||
|
"24": "#Drug1 may increase the hypocalcemic activities of #Drug2.",
|
||||||
|
"25": "#Drug1 may increase the atrioventricular blocking (AV block) activities of #Drug2.",
|
||||||
|
"26": "#Drug1 may decrease the antiplatelet activities of #Drug2.",
|
||||||
|
"27": "#Drug1 may increase the neuroexcitatory activities of #Drug2.",
|
||||||
|
"28": "#Drug1 may increase the dermatologic adverse activities of #Drug2.",
|
||||||
|
"29": "#Drug1 may decrease the diuretic activities of #Drug2.",
|
||||||
|
"30": "#Drug1 may increase the orthostatic hypotensive activities of #Drug2.",
|
||||||
|
"31": "The risk or severity of hypertension can be increased when #Drug2 is combined with #Drug1.",
|
||||||
|
"32": "#Drug1 may increase the sedative activities of #Drug2.",
|
||||||
|
"33": "The risk or severity of QTc prolongation can be increased when #Drug1 is combined with #Drug2.",
|
||||||
|
"34": "#Drug1 may increase the immunosuppressive activities of #Drug2.",
|
||||||
|
"35": "#Drug1 may increase the neurotoxic activities of #Drug2.",
|
||||||
|
"36": "#Drug1 may increase the antipsychotic activities of #Drug2.",
|
||||||
|
"37": "#Drug1 may decrease the antihypertensive activities of #Drug2.",
|
||||||
|
"38": "#Drug1 may increase the vasodilatory activities of #Drug2.",
|
||||||
|
"39": "#Drug1 may increase the constipating activities of #Drug2.",
|
||||||
|
"40": "#Drug1 may increase the respiratory depressant activities of #Drug2.",
|
||||||
|
"41": "#Drug1 may increase the hypotensive and central nervous system depressant (CNS depressant) activities of #Drug2.",
|
||||||
|
"42": "The risk or severity of hyperkalemia can be increased when #Drug1 is combined with #Drug2.",
|
||||||
|
"43": "The protein binding of #Drug2 can be decreased when combined with #Drug1.",
|
||||||
|
"44": "#Drug1 may increase the central neurotoxic activities of #Drug2.",
|
||||||
|
"45": "#Drug1 may decrease effectiveness of #Drug2 as a diagnostic agent.",
|
||||||
|
"46": "#Drug1 may increase the bronchoconstrictory activities of #Drug2.",
|
||||||
|
"47": "The metabolism of #Drug2 can be decreased when combined with #Drug1.",
|
||||||
|
"48": "#Drug1 may increase the myopathic rhabdomyolysis activities of #Drug2.",
|
||||||
|
"49": "The risk or severity of adverse effects can be increased when #Drug1 is combined with #Drug2.",
|
||||||
|
"50": "The risk or severity of heart failure can be increased when #Drug2 is combined with #Drug1.",
|
||||||
|
"51": "#Drug1 may increase the hypercalcemic activities of #Drug2.",
|
||||||
|
"52": "#Drug1 may decrease the analgesic activities of #Drug2.",
|
||||||
|
"53": "#Drug1 may increase the antiplatelet activities of #Drug2.",
|
||||||
|
"54": "#Drug1 may increase the bradycardic activities of #Drug2.",
|
||||||
|
"55": "#Drug1 may increase the hyponatremic activities of #Drug2.",
|
||||||
|
"56": "The risk or severity of hypotension can be increased when #Drug1 is combined with #Drug2.",
|
||||||
|
"57": "#Drug1 may increase the nephrotoxic activities of #Drug2.",
|
||||||
|
"58": "#Drug1 may decrease the cardiotoxic activities of #Drug2.",
|
||||||
|
"59": "#Drug1 may increase the ulcerogenic activities of #Drug2.",
|
||||||
|
"60": "#Drug1 may increase the hypotensive activities of #Drug2.",
|
||||||
|
"61": "#Drug1 may decrease the stimulatory activities of #Drug2.",
|
||||||
|
"62": "The bioavailability of #Drug2 can be increased when combined with #Drug1.",
|
||||||
|
"63": "#Drug1 may increase the myelosuppressive activities of #Drug2.",
|
||||||
|
"64": "#Drug1 may increase the serotonergic activities of #Drug2.",
|
||||||
|
"65": "#Drug1 may increase the excretion rate of #Drug2 which could result in a lower serum level and potentially a reduction in efficacy.",
|
||||||
|
"66": "The risk or severity of bleeding can be increased when #Drug1 is combined with #Drug2.",
|
||||||
|
"67": "#Drug1 can cause a decrease in the absorption of #Drug2 resulting in a reduced serum concentration and potentially a decrease in efficacy.",
|
||||||
|
"68": "#Drug1 may increase the hyperkalemic activities of #Drug2.",
|
||||||
|
"69": "#Drug1 may increase the analgesic activities of #Drug2.",
|
||||||
|
"70": "The therapeutic efficacy of #Drug2 can be decreased when used in combination with #Drug1.",
|
||||||
|
"71": "#Drug1 may increase the hypertensive activities of #Drug2.",
|
||||||
|
"72": "#Drug1 may decrease the excretion rate of #Drug2 which could result in a higher serum level.",
|
||||||
|
"73": "The serum concentration of #Drug2 can be increased when it is combined with #Drug1.",
|
||||||
|
"74": "#Drug1 may increase the fluid retaining activities of #Drug2.",
|
||||||
|
"75": "The serum concentration of #Drug2 can be decreased when it is combined with #Drug1.",
|
||||||
|
"76": "#Drug1 may decrease the sedative activities of #Drug2.",
|
||||||
|
"77": "The serum concentration of the active metabolites of #Drug2 can be increased when #Drug2 is used in combination with #Drug1.",
|
||||||
|
"78": "#Drug1 may increase the hyperglycemic activities of #Drug2.",
|
||||||
|
"79": "#Drug1 may increase the central nervous system depressant (CNS depressant) and hypertensive activities of #Drug2.",
|
||||||
|
"80": "#Drug1 may increase the hepatotoxic activities of #Drug2.",
|
||||||
|
"81": "#Drug1 may increase the thrombogenic activities of #Drug2.",
|
||||||
|
"82": "#Drug1 may increase the arrhythmogenic activities of #Drug2.",
|
||||||
|
"83": "#Drug1 may increase the hypokalemic activities of #Drug2.",
|
||||||
|
"84": "#Drug1 may increase the vasopressor activities of #Drug2.",
|
||||||
|
"85": "#Drug1 may increase the tachycardic activities of #Drug2.",
|
||||||
|
"86": "The risk of a hypersensitivity reaction to #Drug2 is increased when it is combined with #Drug1."
|
||||||
|
}
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
RunPod Serverless Handler for DDI Model Training
|
RunPod Serverless Handler for DDI Model Training
|
||||||
|
|
||||||
Supports both BERT-style classification and LLM fine-tuning with LoRA.
|
Supports both BERT-style classification and LLM fine-tuning with LoRA.
|
||||||
Default: Gemma 3 12B with QLoRA for DDI severity classification.
|
Uses 176K real DrugBank DDI samples with drug names.
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
@@ -11,146 +11,56 @@ from typing import Dict, Any, List, Optional
|
|||||||
|
|
||||||
|
|
||||||
# DDI severity labels
|
# DDI severity labels
|
||||||
DDI_LABELS = {
|
DDI_SEVERITY = {
|
||||||
0: "no_interaction",
|
|
||||||
1: "minor",
|
1: "minor",
|
||||||
2: "moderate",
|
2: "moderate",
|
||||||
3: "major",
|
3: "major",
|
||||||
4: "contraindicated"
|
4: "contraindicated"
|
||||||
}
|
}
|
||||||
|
|
||||||
LABEL_DESCRIPTIONS = {
|
|
||||||
0: "No clinically significant interaction",
|
def load_drugbank_data(max_samples: int = None, severity_filter: List[int] = None) -> List[Dict]:
|
||||||
1: "Minor interaction - minimal clinical significance",
|
"""Load real DrugBank DDI data from bundled file."""
|
||||||
2: "Moderate interaction - may require monitoring or dose adjustment",
|
data_path = os.environ.get('DDI_DATA_PATH', '/app/data/drugbank_ddi_complete.jsonl')
|
||||||
3: "Major interaction - avoid combination if possible, high risk",
|
|
||||||
4: "Contraindicated - do not use together, life-threatening risk"
|
if not os.path.exists(data_path):
|
||||||
}
|
print(f"WARNING: Data file not found at {data_path}, using curated fallback")
|
||||||
|
return get_curated_fallback()
|
||||||
|
|
||||||
|
data = []
|
||||||
|
with open(data_path) as f:
|
||||||
|
for line in f:
|
||||||
|
item = json.loads(line)
|
||||||
|
if severity_filter and item['severity'] not in severity_filter:
|
||||||
|
continue
|
||||||
|
data.append(item)
|
||||||
|
if max_samples and len(data) >= max_samples:
|
||||||
|
break
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def get_ddi_training_data(max_samples: int = 5000) -> List[Dict[str, Any]]:
|
def get_curated_fallback() -> List[Dict]:
|
||||||
"""Generate DDI training data formatted for instruction tuning."""
|
"""Fallback curated data if main file not available."""
|
||||||
import random
|
patterns = [
|
||||||
random.seed(42)
|
{"drug1": "fluoxetine", "drug2": "tramadol", "interaction_text": "fluoxetine may increase the risk of serotonin syndrome when combined with tramadol", "severity": 4},
|
||||||
|
{"drug1": "warfarin", "drug2": "aspirin", "interaction_text": "warfarin may increase the risk of bleeding when combined with aspirin", "severity": 3},
|
||||||
# Real drug interaction patterns based on clinical data
|
{"drug1": "simvastatin", "drug2": "amlodipine", "interaction_text": "The serum concentration of simvastatin can be increased when combined with amlodipine", "severity": 2},
|
||||||
ddi_patterns = [
|
{"drug1": "metformin", "drug2": "lisinopril", "interaction_text": "metformin and lisinopril have no significant interaction", "severity": 1},
|
||||||
# Contraindicated (4)
|
|
||||||
{"drugs": ["fluoxetine", "tramadol"], "type": "serotonin syndrome risk", "label": 4},
|
|
||||||
{"drugs": ["fluoxetine", "phenelzine"], "type": "serotonin syndrome risk", "label": 4},
|
|
||||||
{"drugs": ["simvastatin", "itraconazole"], "type": "rhabdomyolysis risk", "label": 4},
|
|
||||||
{"drugs": ["methotrexate", "trimethoprim"], "type": "severe bone marrow suppression", "label": 4},
|
|
||||||
{"drugs": ["warfarin", "miconazole"], "type": "severe bleeding risk", "label": 4},
|
|
||||||
{"drugs": ["cisapride", "erythromycin"], "type": "QT prolongation cardiac arrest", "label": 4},
|
|
||||||
{"drugs": ["pimozide", "clarithromycin"], "type": "QT prolongation risk", "label": 4},
|
|
||||||
{"drugs": ["ergotamine", "ritonavir"], "type": "ergot toxicity risk", "label": 4},
|
|
||||||
{"drugs": ["sildenafil", "nitroglycerin"], "type": "severe hypotension", "label": 4},
|
|
||||||
{"drugs": ["linezolid", "sertraline"], "type": "serotonin syndrome", "label": 4},
|
|
||||||
{"drugs": ["maoi", "meperidine"], "type": "hypertensive crisis", "label": 4},
|
|
||||||
{"drugs": ["metronidazole", "disulfiram"], "type": "psychosis risk", "label": 4},
|
|
||||||
|
|
||||||
# Major (3)
|
|
||||||
{"drugs": ["warfarin", "aspirin"], "type": "increased bleeding risk", "label": 3},
|
|
||||||
{"drugs": ["digoxin", "amiodarone"], "type": "digoxin toxicity elevated", "label": 3},
|
|
||||||
{"drugs": ["lithium", "ibuprofen"], "type": "lithium toxicity risk", "label": 3},
|
|
||||||
{"drugs": ["metformin", "iodinated contrast"], "type": "lactic acidosis risk", "label": 3},
|
|
||||||
{"drugs": ["potassium chloride", "lisinopril"], "type": "hyperkalemia risk", "label": 3},
|
|
||||||
{"drugs": ["oxycodone", "alprazolam"], "type": "respiratory depression", "label": 3},
|
|
||||||
{"drugs": ["theophylline", "ciprofloxacin"], "type": "theophylline toxicity", "label": 3},
|
|
||||||
{"drugs": ["phenytoin", "fluconazole"], "type": "phenytoin toxicity", "label": 3},
|
|
||||||
{"drugs": ["carbamazepine", "verapamil"], "type": "carbamazepine toxicity", "label": 3},
|
|
||||||
{"drugs": ["cyclosporine", "ketoconazole"], "type": "nephrotoxicity risk", "label": 3},
|
|
||||||
{"drugs": ["methotrexate", "ibuprofen"], "type": "methotrexate toxicity", "label": 3},
|
|
||||||
{"drugs": ["quinidine", "digoxin"], "type": "digoxin toxicity", "label": 3},
|
|
||||||
{"drugs": ["clopidogrel", "omeprazole"], "type": "reduced antiplatelet effect", "label": 3},
|
|
||||||
{"drugs": ["warfarin", "rifampin"], "type": "reduced anticoagulation", "label": 3},
|
|
||||||
{"drugs": ["dabigatran", "rifampin"], "type": "reduced anticoagulant effect", "label": 3},
|
|
||||||
|
|
||||||
# Moderate (2)
|
|
||||||
{"drugs": ["simvastatin", "amlodipine"], "type": "increased statin exposure", "label": 2},
|
|
||||||
{"drugs": ["metformin", "cimetidine"], "type": "increased metformin levels", "label": 2},
|
|
||||||
{"drugs": ["levothyroxine", "calcium carbonate"], "type": "reduced thyroid absorption", "label": 2},
|
|
||||||
{"drugs": ["gabapentin", "aluminum hydroxide"], "type": "reduced gabapentin absorption", "label": 2},
|
|
||||||
{"drugs": ["furosemide", "gentamicin"], "type": "ototoxicity risk", "label": 2},
|
|
||||||
{"drugs": ["prednisone", "naproxen"], "type": "GI bleeding risk", "label": 2},
|
|
||||||
{"drugs": ["metoprolol", "verapamil"], "type": "bradycardia risk", "label": 2},
|
|
||||||
{"drugs": ["sertraline", "tramadol"], "type": "seizure threshold lowered", "label": 2},
|
|
||||||
{"drugs": ["losartan", "potassium supplements"], "type": "hyperkalemia risk", "label": 2},
|
|
||||||
{"drugs": ["alprazolam", "ketoconazole"], "type": "increased sedation", "label": 2},
|
|
||||||
{"drugs": ["atorvastatin", "grapefruit juice"], "type": "increased statin levels", "label": 2},
|
|
||||||
{"drugs": ["ciprofloxacin", "ferrous sulfate"], "type": "reduced antibiotic absorption", "label": 2},
|
|
||||||
{"drugs": ["warfarin", "acetaminophen"], "type": "slight INR increase", "label": 2},
|
|
||||||
{"drugs": ["insulin", "propranolol"], "type": "masked hypoglycemia", "label": 2},
|
|
||||||
{"drugs": ["digoxin", "spironolactone"], "type": "increased digoxin levels", "label": 2},
|
|
||||||
|
|
||||||
# Minor (1)
|
|
||||||
{"drugs": ["aspirin", "ibuprofen"], "type": "reduced cardioprotection", "label": 1},
|
|
||||||
{"drugs": ["metformin", "vitamin B12"], "type": "reduced B12 absorption long-term", "label": 1},
|
|
||||||
{"drugs": ["amoxicillin", "ethinyl estradiol"], "type": "theoretical reduced efficacy", "label": 1},
|
|
||||||
{"drugs": ["omeprazole", "vitamin B12"], "type": "reduced absorption", "label": 1},
|
|
||||||
{"drugs": ["caffeine", "ciprofloxacin"], "type": "increased caffeine effect", "label": 1},
|
|
||||||
{"drugs": ["calcium carbonate", "ferrous sulfate"], "type": "timing interaction", "label": 1},
|
|
||||||
{"drugs": ["atorvastatin", "niacin"], "type": "monitoring recommended", "label": 1},
|
|
||||||
{"drugs": ["lisinopril", "aspirin"], "type": "possible reduced effect", "label": 1},
|
|
||||||
{"drugs": ["hydrochlorothiazide", "calcium"], "type": "hypercalcemia monitoring", "label": 1},
|
|
||||||
{"drugs": ["metoprolol", "clonidine"], "type": "withdrawal monitoring", "label": 1},
|
|
||||||
|
|
||||||
# No interaction (0)
|
|
||||||
{"drugs": ["amlodipine", "atorvastatin"], "type": "safe combination", "label": 0},
|
|
||||||
{"drugs": ["metformin", "lisinopril"], "type": "complementary therapy", "label": 0},
|
|
||||||
{"drugs": ["omeprazole", "levothyroxine"], "type": "can be used together with spacing", "label": 0},
|
|
||||||
{"drugs": ["aspirin", "atorvastatin"], "type": "standard combination", "label": 0},
|
|
||||||
{"drugs": ["metoprolol", "lisinopril"], "type": "common combination", "label": 0},
|
|
||||||
{"drugs": ["gabapentin", "acetaminophen"], "type": "no interaction", "label": 0},
|
|
||||||
{"drugs": ["sertraline", "omeprazole"], "type": "generally safe", "label": 0},
|
|
||||||
{"drugs": ["metformin", "glipizide"], "type": "complementary", "label": 0},
|
|
||||||
{"drugs": ["hydrochlorothiazide", "lisinopril"], "type": "synergistic", "label": 0},
|
|
||||||
{"drugs": ["pantoprazole", "amlodipine"], "type": "no known interaction", "label": 0},
|
|
||||||
]
|
]
|
||||||
|
return patterns * 50 # 200 samples
|
||||||
training_data = []
|
|
||||||
|
|
||||||
for pattern in ddi_patterns:
|
|
||||||
drug1, drug2 = pattern["drugs"]
|
|
||||||
interaction_type = pattern["type"]
|
|
||||||
label = pattern["label"]
|
|
||||||
label_name = DDI_LABELS[label]
|
|
||||||
label_desc = LABEL_DESCRIPTIONS[label]
|
|
||||||
|
|
||||||
# Create instruction-tuning format
|
|
||||||
prompts = [
|
|
||||||
f"Analyze the drug-drug interaction between {drug1} and {drug2}.",
|
|
||||||
f"What is the severity of combining {drug1} with {drug2}?",
|
|
||||||
f"A patient is taking {drug1}. They need to start {drug2}. Assess the interaction risk.",
|
|
||||||
f"Evaluate the interaction: {drug1} + {drug2}",
|
|
||||||
f"Drug interaction check: {drug1} and {drug2}",
|
|
||||||
]
|
|
||||||
|
|
||||||
for prompt in prompts:
|
|
||||||
response = f"Severity: {label_name.upper()}\nInteraction: {interaction_type}\nRecommendation: {label_desc}"
|
|
||||||
|
|
||||||
training_data.append({
|
|
||||||
"instruction": prompt,
|
|
||||||
"response": response,
|
|
||||||
"label": label,
|
|
||||||
"label_name": label_name
|
|
||||||
})
|
|
||||||
|
|
||||||
# Shuffle and replicate to reach target size
|
|
||||||
random.shuffle(training_data)
|
|
||||||
|
|
||||||
while len(training_data) < max_samples:
|
|
||||||
training_data.extend(training_data[:min(len(training_data), max_samples - len(training_data))])
|
|
||||||
|
|
||||||
return training_data[:max_samples]
|
|
||||||
|
|
||||||
|
|
||||||
def format_for_gemma(example: Dict) -> str:
|
def format_for_gemma(item: Dict) -> str:
|
||||||
"""Format example for Gemma instruction tuning."""
|
"""Format DDI item for Gemma instruction tuning."""
|
||||||
|
severity_name = DDI_SEVERITY.get(item['severity'], 'unknown')
|
||||||
|
|
||||||
return f"""<start_of_turn>user
|
return f"""<start_of_turn>user
|
||||||
{example['instruction']}<end_of_turn>
|
Analyze the drug interaction between {item['drug1']} and {item['drug2']}.<end_of_turn>
|
||||||
<start_of_turn>model
|
<start_of_turn>model
|
||||||
{example['response']}<end_of_turn>"""
|
Interaction: {item['interaction_text']}
|
||||||
|
Severity: {severity_name.upper()}
|
||||||
|
Recommendation: {"Avoid this combination" if item['severity'] >= 3 else "Monitor patient" if item['severity'] == 2 else "Generally safe"}<end_of_turn>"""
|
||||||
|
|
||||||
|
|
||||||
def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
@@ -170,32 +80,38 @@ def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
|
|
||||||
# Parameters
|
# Parameters
|
||||||
model_name = job_input.get('model_name', 'google/gemma-3-12b-it')
|
model_name = job_input.get('model_name', 'google/gemma-3-12b-it')
|
||||||
max_samples = job_input.get('max_samples', 2000)
|
max_samples = job_input.get('max_samples', 10000)
|
||||||
epochs = job_input.get('epochs', 1)
|
epochs = job_input.get('epochs', 1)
|
||||||
learning_rate = job_input.get('learning_rate', 2e-4)
|
learning_rate = job_input.get('learning_rate', 2e-4)
|
||||||
batch_size = job_input.get('batch_size', 4)
|
batch_size = job_input.get('batch_size', 2)
|
||||||
lora_r = job_input.get('lora_r', 16)
|
lora_r = job_input.get('lora_r', 16)
|
||||||
lora_alpha = job_input.get('lora_alpha', 32)
|
lora_alpha = job_input.get('lora_alpha', 32)
|
||||||
max_seq_length = job_input.get('max_seq_length', 512)
|
max_seq_length = job_input.get('max_seq_length', 512)
|
||||||
|
severity_filter = job_input.get('severity_filter', None) # e.g., [3, 4] for major/contraindicated only
|
||||||
|
|
||||||
work_dir = tempfile.mkdtemp()
|
work_dir = tempfile.mkdtemp()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
|
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
|
||||||
|
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
|
||||||
print(f"Model: {model_name}")
|
print(f"Model: {model_name}")
|
||||||
print(f"LoRA r={lora_r}, alpha={lora_alpha}")
|
print(f"LoRA r={lora_r}, alpha={lora_alpha}")
|
||||||
print(f"Samples: {max_samples}, Epochs: {epochs}")
|
|
||||||
|
|
||||||
# Load training data
|
# Load training data
|
||||||
print("Loading DDI training data...")
|
print(f"Loading DrugBank DDI data (max {max_samples})...")
|
||||||
training_data = get_ddi_training_data(max_samples=max_samples)
|
raw_data = load_drugbank_data(max_samples=max_samples, severity_filter=severity_filter)
|
||||||
|
|
||||||
# Format for Gemma
|
# Format for Gemma
|
||||||
formatted_data = [{"text": format_for_gemma(ex)} for ex in training_data]
|
formatted_data = [{"text": format_for_gemma(item)} for item in raw_data]
|
||||||
dataset = Dataset.from_list(formatted_data)
|
dataset = Dataset.from_list(formatted_data)
|
||||||
|
|
||||||
print(f"Dataset size: {len(dataset)}")
|
print(f"Dataset size: {len(dataset)}")
|
||||||
|
|
||||||
|
# Severity distribution
|
||||||
|
from collections import Counter
|
||||||
|
sev_dist = Counter(item['severity'] for item in raw_data)
|
||||||
|
print(f"Severity distribution: {dict(sev_dist)}")
|
||||||
|
|
||||||
# QLoRA config - 4-bit quantization
|
# QLoRA config - 4-bit quantization
|
||||||
bnb_config = BitsAndBytesConfig(
|
bnb_config = BitsAndBytesConfig(
|
||||||
load_in_4bit=True,
|
load_in_4bit=True,
|
||||||
@@ -232,18 +148,20 @@ def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
model = get_peft_model(model, lora_config)
|
model = get_peft_model(model, lora_config)
|
||||||
model.print_trainable_parameters()
|
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
total_params = sum(p.numel() for p in model.parameters())
|
||||||
|
print(f"Trainable params: {trainable_params:,} / {total_params:,} ({100*trainable_params/total_params:.2f}%)")
|
||||||
|
|
||||||
# Training arguments
|
# Training arguments
|
||||||
training_args = TrainingArguments(
|
training_args = TrainingArguments(
|
||||||
output_dir=work_dir,
|
output_dir=work_dir,
|
||||||
num_train_epochs=epochs,
|
num_train_epochs=epochs,
|
||||||
per_device_train_batch_size=batch_size,
|
per_device_train_batch_size=batch_size,
|
||||||
gradient_accumulation_steps=4,
|
gradient_accumulation_steps=8,
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
weight_decay=0.01,
|
weight_decay=0.01,
|
||||||
warmup_ratio=0.1,
|
warmup_ratio=0.1,
|
||||||
logging_steps=10,
|
logging_steps=25,
|
||||||
save_strategy="no",
|
save_strategy="no",
|
||||||
bf16=True,
|
bf16=True,
|
||||||
gradient_checkpointing=True,
|
gradient_checkpointing=True,
|
||||||
@@ -271,12 +189,14 @@ def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
'train_loss': float(train_result.training_loss),
|
'train_loss': float(train_result.training_loss),
|
||||||
'epochs': epochs,
|
'epochs': epochs,
|
||||||
'model_name': model_name,
|
'model_name': model_name,
|
||||||
'samples': len(training_data),
|
'samples': len(raw_data),
|
||||||
'lora_r': lora_r,
|
'lora_r': lora_r,
|
||||||
'lora_alpha': lora_alpha,
|
'lora_alpha': lora_alpha,
|
||||||
|
'trainable_params': trainable_params,
|
||||||
'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
|
'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
|
||||||
'trainable_params': sum(p.numel() for p in model.parameters() if p.requires_grad),
|
'vram_gb': torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0,
|
||||||
'quantization': '4-bit QLoRA',
|
'data_source': 'drugbank_176k',
|
||||||
|
'severity_dist': dict(sev_dist),
|
||||||
}
|
}
|
||||||
|
|
||||||
print(f"Training complete! Loss: {metrics['train_loss']:.4f}")
|
print(f"Training complete! Loss: {metrics['train_loss']:.4f}")
|
||||||
@@ -284,7 +204,7 @@ def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
return {
|
return {
|
||||||
'status': 'success',
|
'status': 'success',
|
||||||
'metrics': metrics,
|
'metrics': metrics,
|
||||||
'message': f'Gemma 3 12B fine-tuned with QLoRA on DDI data'
|
'message': f'Gemma 3 fine-tuned on {len(raw_data):,} real DrugBank DDI samples'
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -296,13 +216,12 @@ def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
}
|
}
|
||||||
finally:
|
finally:
|
||||||
shutil.rmtree(work_dir, ignore_errors=True)
|
shutil.rmtree(work_dir, ignore_errors=True)
|
||||||
# Clear GPU memory
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Train BERT-style classifier (original approach)."""
|
"""Train BERT-style classifier for DDI severity prediction."""
|
||||||
import torch
|
import torch
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
@@ -316,7 +235,7 @@ def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
|
model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
|
||||||
max_samples = job_input.get('max_samples', 5000)
|
max_samples = job_input.get('max_samples', 50000)
|
||||||
epochs = job_input.get('epochs', 3)
|
epochs = job_input.get('epochs', 3)
|
||||||
learning_rate = job_input.get('learning_rate', 2e-5)
|
learning_rate = job_input.get('learning_rate', 2e-5)
|
||||||
batch_size = job_input.get('batch_size', 16)
|
batch_size = job_input.get('batch_size', 16)
|
||||||
@@ -328,9 +247,17 @@ def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
|
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
|
||||||
print(f"Model: {model_name}")
|
print(f"Model: {model_name}")
|
||||||
|
|
||||||
# Get data in BERT format
|
# Load data
|
||||||
raw_data = get_ddi_training_data(max_samples=max_samples)
|
raw_data = load_drugbank_data(max_samples=max_samples)
|
||||||
training_data = [{"text": d["instruction"], "label": d["label"]} for d in raw_data]
|
|
||||||
|
# Create text + label format
|
||||||
|
# Shift severity to 0-indexed (1-4 -> 0-3)
|
||||||
|
training_data = [{
|
||||||
|
"text": f"{d['drug1']} and {d['drug2']}: {d['interaction_text']}",
|
||||||
|
"label": d['severity'] - 1 # 0-indexed
|
||||||
|
} for d in raw_data]
|
||||||
|
|
||||||
|
print(f"Loaded {len(training_data)} samples")
|
||||||
|
|
||||||
# Split
|
# Split
|
||||||
if eval_split > 0:
|
if eval_split > 0:
|
||||||
@@ -344,10 +271,10 @@ def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
train_dataset = Dataset.from_list(train_data)
|
train_dataset = Dataset.from_list(train_data)
|
||||||
eval_dataset = Dataset.from_list(eval_data) if eval_data else None
|
eval_dataset = Dataset.from_list(eval_data) if eval_data else None
|
||||||
|
|
||||||
# Load model
|
# Load model (4 classes: minor, moderate, major, contraindicated)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
model = AutoModelForSequenceClassification.from_pretrained(
|
model = AutoModelForSequenceClassification.from_pretrained(
|
||||||
model_name, num_labels=5
|
model_name, num_labels=4
|
||||||
)
|
)
|
||||||
|
|
||||||
def tokenize(examples):
|
def tokenize(examples):
|
||||||
@@ -392,6 +319,7 @@ def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
'model_name': model_name,
|
'model_name': model_name,
|
||||||
'samples': len(train_data),
|
'samples': len(train_data),
|
||||||
'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
|
'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
|
||||||
|
'data_source': 'drugbank_176k',
|
||||||
}
|
}
|
||||||
|
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
@@ -401,7 +329,7 @@ def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
|
'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
|
||||||
})
|
})
|
||||||
|
|
||||||
return {'status': 'success', 'metrics': metrics, 'message': 'BERT classifier trained'}
|
return {'status': 'success', 'metrics': metrics, 'message': 'BERT classifier trained on DrugBank data'}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
@@ -414,12 +342,11 @@ def handler(job):
|
|||||||
"""RunPod serverless handler."""
|
"""RunPod serverless handler."""
|
||||||
job_input = job.get('input', {})
|
job_input = job.get('input', {})
|
||||||
|
|
||||||
# Choose training mode
|
|
||||||
model_name = job_input.get('model_name', 'google/gemma-3-12b-it')
|
model_name = job_input.get('model_name', 'google/gemma-3-12b-it')
|
||||||
use_lora = job_input.get('use_lora', True)
|
use_lora = job_input.get('use_lora', True)
|
||||||
|
|
||||||
# Auto-detect: use LoRA for large models
|
# Auto-detect: use LoRA for large models
|
||||||
if 'gemma' in model_name.lower() or 'llama' in model_name.lower() or 'mistral' in model_name.lower():
|
if any(x in model_name.lower() for x in ['gemma', 'llama', 'mistral', 'qwen']):
|
||||||
use_lora = True
|
use_lora = True
|
||||||
elif 'bert' in model_name.lower():
|
elif 'bert' in model_name.lower():
|
||||||
use_lora = False
|
use_lora = False
|
||||||
|
|||||||
Reference in New Issue
Block a user