Mirror of https://github.com/ghndrx/kubeflow-pipelines.git, synced 2026-02-10 06:45:13 +00:00

Commit: docs: clean up README and USE_CASES formatting

README.md (103 lines changed); updated content follows, with `@@` markers where the diff skips unchanged lines.

@@ -1,29 +1,29 @@
# Healthcare ML Training Pipeline

Serverless GPU training infrastructure for healthcare NLP models using RunPod and AWS.

## Overview

This project provides production-ready ML pipelines for training healthcare classification models:

- **Drug-Drug Interaction (DDI)** - Severity classification from DrugBank (176K samples)
- **Adverse Drug Events (ADE)** - Binary detection from ADE Corpus V2 (30K samples)
- **Medical Triage** - Urgency level classification
- **Symptom-to-Disease** - Diagnosis prediction (41 disease classes)

All models use Bio_ClinicalBERT as the base and are fine-tuned on domain-specific datasets.
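
As a quick illustration (not the project's training code), this is how a Bio_ClinicalBERT checkpoint picks up a task-specific classification head with `transformers`; the four labels shown are the DDI severity classes:

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

base = "emilyalsentzer/Bio_ClinicalBERT"
tokenizer = AutoTokenizer.from_pretrained(base)

# A fresh 4-way classification head is attached on top of the pretrained encoder;
# its weights are randomly initialized and learned during fine-tuning.
model = AutoModelForSequenceClassification.from_pretrained(
    base,
    num_labels=4,  # Minor, Moderate, Major, Contraindicated
)
```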

## Training Results

| Task | Dataset | Samples | Accuracy | F1 Score |
|------|---------|---------|----------|----------|
| DDI Classification | DrugBank | 176K | 100% | 100% |
| ADE Detection | ADE Corpus V2 | 9K | 93.5% | 95.3% |
| Symptom-Disease | Disease Symptoms | 4.4K | 100% | 100% |

## Quick Start

### Run Training

```bash
curl -X POST "https://api.runpod.ai/v2/YOUR_ENDPOINT/run" \
@@ -31,9 +31,10 @@
  -H "Content-Type: application/json" \
  -d '{
    "input": {
      "task": "ddi",
      "model_name": "emilyalsentzer/Bio_ClinicalBERT",
      "max_samples": 10000,
      "epochs": 3,
      "batch_size": 16,
      "s3_bucket": "your-bucket",
      "aws_access_key_id": "...",
@@ -43,69 +44,67 @@
  }'
```

Available tasks: `ddi`, `ade`, `triage`, `symptom_disease`
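
The same job can be submitted and monitored from Python. A minimal sketch using `requests`; the environment variable names are placeholders, and the `/status/{id}` polling route follows RunPod's general serverless API rather than anything defined in this repo:

```python
import os
import time

import requests

endpoint_id = os.environ["RUNPOD_ENDPOINT_ID"]  # placeholder env var names
api_key = os.environ["RUNPOD_API_KEY"]
headers = {"Authorization": f"Bearer {api_key}"}

payload = {
    "input": {
        "task": "ddi",
        "model_name": "emilyalsentzer/Bio_ClinicalBERT",
        "max_samples": 10000,
        "epochs": 3,
        "batch_size": 16,
        "s3_bucket": "your-bucket",
    }
}

# Submit the training job to the serverless endpoint.
job = requests.post(
    f"https://api.runpod.ai/v2/{endpoint_id}/run", json=payload, headers=headers
).json()

# Poll until the worker finishes.
while True:
    status = requests.get(
        f"https://api.runpod.ai/v2/{endpoint_id}/status/{job['id']}", headers=headers
    ).json()
    if status.get("status") in ("COMPLETED", "FAILED", "CANCELLED"):
        break
    time.sleep(30)

print(status)
```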

### Download Trained Model

```bash
aws s3 cp s3://your-bucket/model.tar.gz .
tar -xzf model.tar.gz
```
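
Once unpacked, the checkpoint can be loaded for local inference with `transformers`. A sketch assuming the archive expands to a standard Hugging Face model directory named `model/`; the actual folder name inside the tarball may differ:

```python
from transformers import pipeline

# Load the fine-tuned config, tokenizer, and weights from disk.
classifier = pipeline("text-classification", model="./model", tokenizer="./model")

print(classifier("Concurrent use of warfarin and ibuprofen may increase bleeding risk."))
```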

## Project Structure

```
├── components/
│   └── runpod_trainer/
│       ├── Dockerfile
│       ├── handler.py              # Multi-task training logic
│       ├── requirements.txt
│       └── data/                   # DrugBank DDI dataset
├── pipelines/
│   ├── healthcare_training.py      # Kubeflow pipeline definitions
│   ├── ddi_training_runpod.py
│   └── ddi_data_prep.py
├── .github/workflows/
│   └── build-trainer.yaml          # CI/CD
└── manifests/
    └── argocd-app.yaml
```
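
`handler.py` is the RunPod serverless entrypoint. The snippet below is not the project's actual handler, only a minimal sketch of the pattern the RunPod Python SDK expects: a function that receives the job's `input` dict and returns a JSON-serializable result.

```python
import runpod


def handler(job):
    """Dispatch on the requested task and return a summary of the run."""
    params = job["input"]
    task = params.get("task", "ddi")

    # ... load the dataset for `task`, fine-tune the model, upload the artifact to S3 ...

    return {"task": task, "status": "completed"}


# Hand the function to the RunPod serverless runtime.
runpod.serverless.start({"handler": handler})
```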

## Configuration

### Supported Models

| Model | Type | Use Case |
|-------|------|----------|
| `emilyalsentzer/Bio_ClinicalBERT` | BERT | Classification tasks |
| `meta-llama/Llama-3.1-8B-Instruct` | LLM | Text generation (LoRA) |
| `google/gemma-3-4b-it` | LLM | Lightweight inference |
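
For the LLM rows, LoRA keeps fine-tuning tractable on a single GPU. A minimal sketch with the `peft` library, not taken from this repo; the target modules shown are the usual attention projections for Llama-style models:

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

# Train small low-rank adapters instead of the full 8B parameters.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```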

### Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `task` | ddi | Training task |
| `model_name` | Bio_ClinicalBERT | HuggingFace model ID |
| `max_samples` | 10000 | Training samples |
| `epochs` | 3 | Training epochs |
| `batch_size` | 16 | Batch size |
| `eval_split` | 0.1 | Validation split |
| `s3_bucket` | - | S3 bucket for output |
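
Inside the worker these parameters map naturally onto Hugging Face training primitives. An illustrative sketch only, not the repo's actual code; the toy dataset stands in for the real tokenized samples:

```python
from datasets import Dataset
from transformers import TrainingArguments

# Toy stand-in for the real tokenized DDI dataset.
ds = Dataset.from_dict(
    {"text": [f"drug pair {i}" for i in range(100)], "label": [i % 4 for i in range(100)]}
)

# eval_split=0.1 carves a validation set out of the training data.
splits = ds.train_test_split(test_size=0.1, seed=42)
train_ds, eval_ds = splits["train"], splits["test"]

# epochs and batch_size map directly onto TrainingArguments.
training_args = TrainingArguments(
    output_dir="./output",
    num_train_epochs=3,
    per_device_train_batch_size=16,
)
```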

## Development

```bash
# Build container
cd components/runpod_trainer
docker build -t healthcare-trainer .

# Trigger CI build
gh workflow run build-trainer.yaml
```

## License

MIT

USE_CASES.md (178 lines changed); updated content:

@@ -1,158 +1,92 @@
# Healthcare ML Use Cases & Datasets

Curated list of healthcare/biomedical use cases with publicly available datasets.

---

## Implemented

### 1. Drug-Drug Interaction (DDI) Classification
- **Dataset:** DrugBank (bundled)
- **Task:** Classify interaction severity
- **Size:** 176K samples
- **Labels:** Minor, Moderate, Major, Contraindicated
- **Status:** Production ready

### 2. Adverse Drug Event Detection
- **Dataset:** `ade-benchmark-corpus/ade_corpus_v2`
- **Task:** Binary classification for ADE presence
- **Size:** 30K samples
- **Labels:** ADE / No ADE
- **Status:** Production ready

### 3. Symptom-to-Disease Prediction
- **Dataset:** `shanover/disease_symptoms_prec_full`
- **Task:** Predict disease from symptoms
- **Size:** ~5K samples
- **Labels:** 41 disease categories
- **Status:** Production ready

### 4. Medical Triage Classification
- **Dataset:** `shubham212/Medical_Triage_Classification`
- **Task:** Classify urgency level
- **Labels:** Emergency, Urgent, Standard, Non-urgent
- **Status:** Production ready (needs more training data)
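
The Hugging Face-hosted datasets above are one `load_dataset` call away. A quick sketch; the ADE corpus ships several configurations, and the config name used here is an assumption to verify against the dataset card:

```python
from datasets import load_dataset

symptoms = load_dataset("shanover/disease_symptoms_prec_full")

# ade_corpus_v2 requires picking one of its configurations; the classification
# config name below is assumed and should be checked on the Hub.
ade = load_dataset("ade-benchmark-corpus/ade_corpus_v2", "Ade_corpus_v2_classification")

print(symptoms)
print(ade)
```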

---

## Future Candidates

### PubMed Multi-Label Classification (MeSH)
- **Dataset:** `owaiskha9654/PubMed_MultiLabel_Text_Classification_Dataset_MeSH`
- **Task:** Assign MeSH subject headings to articles
- **Size:** 50K articles
- **Use Case:** Literature categorization

### MedMCQA - Medical Exam QA
- **Dataset:** `openlifescienceai/medmcqa`
- **Task:** Answer medical entrance exam questions
- **Size:** 194K MCQs
- **Use Case:** Medical education, knowledge testing

### PubMedQA - Research Question Answering
- **Dataset:** `qiaojin/PubMedQA`
- **Task:** Yes/No/Maybe from abstracts
- **Size:** 274K samples
- **Use Case:** Evidence-based medicine

### Medical Abbreviation Disambiguation
- **Dataset:** `McGill-NLP/medal`
- **Task:** Disambiguate abbreviations in context
- **Size:** 4GB curated
- **Use Case:** Clinical note processing

### BioInstruct
- **Dataset:** `bio-nlp-umass/bioinstruct`
- **Task:** Instruction-tuned biomedical tasks
- **Size:** 25K instructions
- **Use Case:** General biomedical assistant
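
The candidates are equally easy to pull for experimentation. A quick sketch; PubMedQA requires choosing a configuration, and the `pqa_labeled` name used here should be confirmed on the Hub:

```python
from datasets import load_dataset

medmcqa = load_dataset("openlifescienceai/medmcqa")

# PubMedQA is split into several configurations; "pqa_labeled" is the
# expert-annotated subset (name assumed, check the dataset card).
pubmedqa = load_dataset("qiaojin/PubMedQA", "pqa_labeled")

print(medmcqa)
print(pubmedqa)
```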

---

## Dataset Comparison

| Dataset | Size | Task | Complexity |
|---------|------|------|------------|
| DDI (DrugBank) | 176K | 4-class | Medium |
| ADE Corpus | 30K | Binary | Low |
| PubMed MeSH | 50K | Multi-label | High |
| MedMCQA | 194K | MCQ | High |
| PubMedQA | 274K | 3-class | Medium |
| Symptom-Disease | 5K | 41-class | Medium |
| Triage | 5K | 4-class | Low |

---

## Additional Resources

- **MIMIC-III/IV:** ICU clinical data (PhysioNet access required)
- **n2c2 Challenges:** Clinical NLP shared tasks
- **i2b2:** De-identified clinical records
- **ChemProt:** Chemical-protein interactions
- **BC5CDR:** Chemical-disease relations