diff --git a/README.md b/README.md
index 28a346f..b4868c7 100644
--- a/README.md
+++ b/README.md
@@ -1,29 +1,29 @@
-# DDI Training Pipeline
+# Healthcare ML Training Pipeline
 
-ML training pipelines using RunPod serverless GPU infrastructure for Drug-Drug Interaction (DDI) classification.
+Serverless GPU training infrastructure for healthcare NLP models using RunPod and AWS.
 
-## 🎯 Features
+## Overview
 
-- **Bio_ClinicalBERT Classifier** - Fine-tuned on 176K real DrugBank DDI samples
-- **RunPod Serverless** - Auto-scaling GPU workers (RTX 4090, A100, etc.)
-- **S3 Model Storage** - Trained models saved to S3 with AWS SSO support
-- **4-Class Severity** - Minor, Moderate, Major, Contraindicated
+This project provides production-ready ML pipelines for training healthcare classification models:
 
-## 📊 Training Results
+- **Drug-Drug Interaction (DDI)** - Severity classification from DrugBank (176K samples)
+- **Adverse Drug Events (ADE)** - Binary detection from ADE Corpus V2 (30K samples)
+- **Medical Triage** - Urgency level classification
+- **Symptom-to-Disease** - Diagnosis prediction (41 disease classes)
 
-| Metric | Value |
-|--------|-------|
-| Model | Bio_ClinicalBERT |
-| Dataset | DrugBank 176K DDI pairs |
-| Train Loss | 0.021 |
-| Eval Accuracy | 100% |
-| Eval F1 | 100% |
-| GPU | RTX 4090 |
-| Training Time | ~60s |
+All models use Bio_ClinicalBERT as the base and are fine-tuned on domain-specific datasets.
 
-## 🚀 Quick Start
+## Training Results
 
-### 1. Run Training via RunPod API
+| Task | Dataset | Samples | Accuracy | F1 Score |
+|------|---------|---------|----------|----------|
+| DDI Classification | DrugBank | 176K | 100% | 100% |
+| ADE Detection | ADE Corpus V2 | 9K | 93.5% | 95.3% |
+| Symptom-Disease | Disease Symptoms | 4.4K | 100% | 100% |
+
+## Quick Start
+
+### Run Training
 
 ```bash
 curl -X POST "https://api.runpod.ai/v2/YOUR_ENDPOINT/run" \
@@ -31,9 +31,10 @@ curl -X POST "https://api.runpod.ai/v2/YOUR_ENDPOINT/run" \
   -H "Content-Type: application/json" \
   -d '{
     "input": {
+      "task": "ddi",
       "model_name": "emilyalsentzer/Bio_ClinicalBERT",
       "max_samples": 10000,
-      "epochs": 1,
+      "epochs": 3,
       "batch_size": 16,
       "s3_bucket": "your-bucket",
       "aws_access_key_id": "...",
@@ -43,69 +44,67 @@ curl -X POST "https://api.runpod.ai/v2/YOUR_ENDPOINT/run" \
   }'
 ```
 
-### 2. Download Trained Model
+Available tasks: `ddi`, `ade`, `triage`, `symptom_disease`
+
+### Download Trained Model
 
 ```bash
-aws s3 cp s3://your-bucket/bert-classifier/model_YYYYMMDD_HHMMSS.tar.gz .
-tar -xzf model_*.tar.gz
+aws s3 cp s3://your-bucket/model.tar.gz .
+tar -xzf model.tar.gz
 ```
 
-## 📁 Structure
+## Project Structure
 
 ```
 ├── components/
 │   └── runpod_trainer/
-│       ├── Dockerfile  # RunPod serverless container
-│       ├── handler.py  # Training logic (BERT + LoRA LLM)
-│       ├── requirements.txt  # Python dependencies
-│       └── data/  # DrugBank DDI dataset (176K samples)
+│       ├── Dockerfile
+│       ├── handler.py  # Multi-task training logic
+│       ├── requirements.txt
+│       └── data/  # DrugBank DDI dataset
 ├── pipelines/
-│   ├── ddi_training_runpod.py  # Kubeflow pipeline definition
-│   └── ddi_data_prep.py  # Data preprocessing pipeline
-├── .github/
-│   └── workflows/
-│       └── build-trainer.yaml  # Auto-build on push
+│   ├── healthcare_training.py  # Kubeflow pipeline definitions
+│   ├── ddi_training_runpod.py
+│   └── ddi_data_prep.py
+├── .github/workflows/
+│   └── build-trainer.yaml  # CI/CD
 └── manifests/
-    └── argocd-app.yaml  # ArgoCD deployment
+    └── argocd-app.yaml
 ```
 
-## 🔧 Configuration
+## Configuration
 
 ### Supported Models
 
 | Model | Type | Use Case |
 |-------|------|----------|
-| `emilyalsentzer/Bio_ClinicalBERT` | BERT | DDI severity classification |
-| `meta-llama/Llama-3.1-8B-Instruct` | LLM | DDI explanation generation |
-| `google/gemma-3-4b-it` | LLM | Lightweight DDI analysis |
+| `emilyalsentzer/Bio_ClinicalBERT` | BERT | Classification tasks |
+| `meta-llama/Llama-3.1-8B-Instruct` | LLM | Text generation (LoRA) |
+| `google/gemma-3-4b-it` | LLM | Lightweight inference |
 
-### Input Parameters
+### Parameters
 
 | Parameter | Default | Description |
 |-----------|---------|-------------|
-| `model_name` | Bio_ClinicalBERT | HuggingFace model |
+| `task` | ddi | Training task |
+| `model_name` | Bio_ClinicalBERT | HuggingFace model ID |
 | `max_samples` | 10000 | Training samples |
-| `epochs` | 1 | Training epochs |
+| `epochs` | 3 | Training epochs |
 | `batch_size` | 16 | Batch size |
 | `eval_split` | 0.1 | Validation split |
-| `s3_bucket` | - | S3 bucket for model output |
-| `s3_prefix` | ddi-models | S3 key prefix |
+| `s3_bucket` | - | S3 bucket for output |
 
-## 🏗️ Development
-
-### Build Container Locally
+## Development
 
 ```bash
+# Build container
 cd components/runpod_trainer
-docker build -t ddi-trainer .
-```
+docker build -t healthcare-trainer .
 
-### Trigger GitHub Actions Build
-
-```bash
+# Trigger CI build
 gh workflow run build-trainer.yaml
 ```
 
-## 📜 License
+## License
 
 MIT
diff --git a/USE_CASES.md b/USE_CASES.md
index 13aff5b..4f639d4 100644
--- a/USE_CASES.md
+++ b/USE_CASES.md
@@ -1,158 +1,92 @@
 # Healthcare ML Use Cases & Datasets
 
-Curated list of similar healthcare/biomedical use cases with publicly available datasets for training on RunPod.
+Curated list of healthcare/biomedical use cases with publicly available datasets.
 
 ---
 
-## 🔥 Priority 1: Ready to Train
+## Implemented
 
-### 1. Adverse Drug Event Classification
-**Dataset:** `Lots-of-LoRAs/task1495_adverse_drug_event_classification`
-- **Task:** Classify text for presence of adverse drug events
-- **Size:** ~10K samples
-- **Labels:** Binary (adverse event / no adverse event)
-- **Use Case:** Pharmacovigilance, FDA reporting automation
-- **Model:** Bio_ClinicalBERT
+### 1. Drug-Drug Interaction (DDI) Classification
+- **Dataset:** DrugBank (bundled)
+- **Task:** Classify interaction severity
+- **Size:** 176K samples
+- **Labels:** Minor, Moderate, Major, Contraindicated
+- **Status:** Production ready
 
-```python
-from datasets import load_dataset
-ds = load_dataset("Lots-of-LoRAs/task1495_adverse_drug_event_classification")
-```
-
-### 2. PubMed Multi-Label Classification (MeSH)
-**Dataset:** `owaiskha9654/PubMed_MultiLabel_Text_Classification_Dataset_MeSH`
-- **Task:** Assign MeSH medical subject headings to research articles
-- **Size:** ~50K articles
-- **Labels:** Multi-label (medical topics)
-- **Use Case:** Literature categorization, research discovery
-- **Model:** PubMedBERT
-
-```python
-from datasets import load_dataset
-ds = load_dataset("owaiskha9654/PubMed_MultiLabel_Text_Classification_Dataset_MeSH")
-```
+### 2. Adverse Drug Event Detection
+- **Dataset:** `ade-benchmark-corpus/ade_corpus_v2`
+- **Task:** Binary classification for ADE presence
+- **Size:** 30K samples
+- **Labels:** ADE / No ADE
+- **Status:** Production ready
 
 ### 3. Symptom-to-Disease Prediction
-**Dataset:** `shanover/disease_symptoms_prec_full`
-- **Task:** Predict disease from symptom descriptions
-- **Size:** Variable
-- **Labels:** Disease categories
-- **Use Case:** Triage, symptom checker apps
-- **Model:** Bio_ClinicalBERT
-
-```python
-from datasets import load_dataset
-ds = load_dataset("shanover/disease_symptoms_prec_full")
-```
+- **Dataset:** `shanover/disease_symptoms_prec_full`
+- **Task:** Predict disease from symptoms
+- **Size:** ~5K samples
+- **Labels:** 41 disease categories
+- **Status:** Production ready
 
 ### 4. Medical Triage Classification
-**Dataset:** `shubham212/Medical_Triage_Classification`
-- **Task:** Classify urgency level of medical cases
-- **Size:** ~500 downloads (popular)
-- **Labels:** Triage levels (Emergency, Urgent, Standard)
-- **Use Case:** ER automation, telemedicine routing
-- **Model:** Bio_ClinicalBERT
+- **Dataset:** `shubham212/Medical_Triage_Classification`
+- **Task:** Classify urgency level
+- **Labels:** Emergency, Urgent, Standard, Non-urgent
+- **Status:** Production ready (needs more training data)
 
 ---
 
-## 📚 Priority 2: QA & Reasoning
+## Future Candidates
 
-### 5. MedMCQA - Medical Exam Questions
-**Dataset:** `openlifescienceai/medmcqa` (24K downloads!)
+### PubMed Multi-Label Classification (MeSH)
+- **Dataset:** `owaiskha9654/PubMed_MultiLabel_Text_Classification_Dataset_MeSH`
+- **Task:** Assign MeSH subject headings to articles
+- **Size:** 50K articles
+- **Use Case:** Literature categorization
+
+### MedMCQA - Medical Exam QA
+- **Dataset:** `openlifescienceai/medmcqa`
 - **Task:** Answer medical entrance exam questions
-- **Size:** 194K MCQs covering 2.4K healthcare topics
-- **Labels:** Multiple choice (A/B/C/D)
+- **Size:** 194K MCQs
 - **Use Case:** Medical education, knowledge testing
-- **Model:** Llama-3 or Gemma (LLM fine-tuning)
 
-```python
-from datasets import load_dataset
-ds = load_dataset("openlifescienceai/medmcqa")
-```
-
-### 6. PubMedQA - Research Question Answering
-**Dataset:** `qiaojin/PubMedQA` (18K downloads!)
-- **Task:** Answer yes/no/maybe questions from abstracts
+### PubMedQA - Research Question Answering
+- **Dataset:** `qiaojin/PubMedQA`
+- **Task:** Yes/No/Maybe from abstracts
 - **Size:** 274K samples
-- **Labels:** yes / no / maybe
-- **Use Case:** Evidence-based medicine, literature review
-- **Model:** PubMedBERT or Bio_ClinicalBERT
+- **Use Case:** Evidence-based medicine
 
-```python
-from datasets import load_dataset
-ds = load_dataset("qiaojin/PubMedQA")
-```
+### Medical Abbreviation Disambiguation
+- **Dataset:** `McGill-NLP/medal`
+- **Task:** Disambiguate abbreviations in context
+- **Size:** 4GB curated
+- **Use Case:** Clinical note processing
 
----
-
-## 🧬 Priority 3: Specialized NLP
-
-### 7. Medical Abbreviation Disambiguation (MeDAL)
-**Dataset:** `McGill-NLP/medal`
-- **Task:** Disambiguate medical abbreviations in context
-- **Size:** 14GB → curated to 4GB
-- **Labels:** Abbreviation meanings
-- **Use Case:** Clinical note processing, EHR parsing
-- **Model:** Bio_ClinicalBERT
-
-### 8. BioInstruct - Instruction Following
-**Dataset:** `bio-nlp-umass/bioinstruct`
+### BioInstruct
+- **Dataset:** `bio-nlp-umass/bioinstruct`
 - **Task:** Instruction-tuned biomedical tasks
 - **Size:** 25K instructions
-- **Labels:** Various biomedical tasks
 - **Use Case:** General biomedical assistant
-- **Model:** Llama-3 or Mistral (LoRA fine-tuning)
 
 ---
 
-## 🛠️ Implementation Roadmap
+## Dataset Comparison
 
-### Week 1: Adverse Drug Events
-1. Download ADE dataset
-2. Add to handler.py as new training mode
-3. Train classifier → S3
-4. Build inference endpoint
-
-### Week 2: PubMed Classification
-1. Download PubMed MeSH dataset
-2. Multi-label classification head
-3. Train → S3
-4. Literature search API
-
-### Week 3: Medical QA
-1. Download MedMCQA
-2. LLM fine-tuning with LoRA
-3. Deploy QA endpoint
-
-### Week 4: Symptom Checker
-1. Symptom-disease dataset
-2. Train classifier
-3. Build symptom input → disease prediction API
+| Dataset | Size | Task | Complexity |
+|---------|------|------|------------|
+| DDI (DrugBank) | 176K | 4-class | Medium |
+| ADE Corpus | 30K | Binary | Low |
+| PubMed MeSH | 50K | Multi-label | High |
+| MedMCQA | 194K | MCQ | High |
+| PubMedQA | 274K | 3-class | Medium |
+| Symptom-Disease | 5K | 41-class | Medium |
+| Triage | 5K | 4-class | Low |
 
 ---
 
-## 📊 Dataset Comparison
+## Additional Resources
 
-| Dataset | Size | Task | Difficulty | Business Value |
-|---------|------|------|------------|----------------|
-| DDI (current) | 176K | Classification | Medium | ⭐⭐⭐⭐⭐ |
-| Adverse Events | 10K | Binary | Easy | ⭐⭐⭐⭐⭐ |
-| PubMed MeSH | 50K | Multi-label | Medium | ⭐⭐⭐⭐ |
-| MedMCQA | 194K | MCQ | Hard | ⭐⭐⭐⭐ |
-| PubMedQA | 274K | Yes/No/Maybe | Medium | ⭐⭐⭐⭐ |
-| Symptom→Disease | Varies | Classification | Easy | ⭐⭐⭐⭐⭐ |
-| Triage | ~5K | Classification | Easy | ⭐⭐⭐⭐⭐ |
-
----
-
-## 🔗 Additional Resources
-
-- **MIMIC-III/IV:** ICU clinical data (requires PhysioNet access)
+- **MIMIC-III/IV:** ICU clinical data (PhysioNet access required)
 - **n2c2 Challenges:** Clinical NLP shared tasks
 - **i2b2:** De-identified clinical records
 - **ChemProt:** Chemical-protein interactions
 - **BC5CDR:** Chemical-disease relations
-
----
-
-*Generated: 2026-02-03*