mirror of
https://github.com/ghndrx/kubeflow-pipelines.git
synced 2026-02-10 06:45:13 +00:00
feat: add ADE, Triage, and Symptom-Disease training pipelines
New tasks supported: - task=ade: Adverse Drug Event classification (ADE Corpus V2, 30K samples) - task=triage: Medical Triage classification (urgency levels) - task=symptom_disease: Symptom-to-Disease prediction (40+ diseases) All use HuggingFace datasets, Bio_ClinicalBERT, and S3 model storage.
This commit is contained in:
376
pipelines/healthcare_training.py
Normal file
376
pipelines/healthcare_training.py
Normal file
@@ -0,0 +1,376 @@
|
||||
"""
|
||||
Healthcare ML Training Pipelines
|
||||
|
||||
Multi-task training pipelines for:
|
||||
- Adverse Drug Event (ADE) Classification
|
||||
- Medical Triage Classification
|
||||
- Symptom-to-Disease Prediction
|
||||
- Drug-Drug Interaction (DDI) Classification
|
||||
|
||||
All use RunPod serverless GPU infrastructure.
|
||||
"""
|
||||
from kfp import dsl
|
||||
from kfp import compiler
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ADE (Adverse Drug Event) Classification Pipeline
|
||||
# ============================================================================
|
||||
@dsl.component(
|
||||
base_image="python:3.11-slim",
|
||||
packages_to_install=["requests"]
|
||||
)
|
||||
def train_ade_model(
|
||||
runpod_api_key: str,
|
||||
runpod_endpoint: str,
|
||||
model_name: str,
|
||||
max_samples: int,
|
||||
epochs: int,
|
||||
batch_size: int,
|
||||
s3_bucket: str,
|
||||
aws_access_key_id: str,
|
||||
aws_secret_access_key: str,
|
||||
aws_session_token: str,
|
||||
) -> dict:
|
||||
"""Train ADE classifier on RunPod serverless GPU."""
|
||||
import requests
|
||||
import time
|
||||
|
||||
response = requests.post(
|
||||
f"https://api.runpod.ai/v2/{runpod_endpoint}/run",
|
||||
headers={"Authorization": f"Bearer {runpod_api_key}"},
|
||||
json={
|
||||
"input": {
|
||||
"task": "ade",
|
||||
"model_name": model_name,
|
||||
"max_samples": max_samples,
|
||||
"epochs": epochs,
|
||||
"batch_size": batch_size,
|
||||
"eval_split": 0.1,
|
||||
"s3_bucket": s3_bucket,
|
||||
"s3_prefix": "ade-models/bert",
|
||||
"aws_access_key_id": aws_access_key_id,
|
||||
"aws_secret_access_key": aws_secret_access_key,
|
||||
"aws_session_token": aws_session_token,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
job_id = response.json()["id"]
|
||||
print(f"RunPod job submitted: {job_id}")
|
||||
|
||||
# Poll for completion
|
||||
while True:
|
||||
status = requests.get(
|
||||
f"https://api.runpod.ai/v2/{runpod_endpoint}/status/{job_id}",
|
||||
headers={"Authorization": f"Bearer {runpod_api_key}"}
|
||||
).json()
|
||||
|
||||
if status["status"] == "COMPLETED":
|
||||
return status["output"]
|
||||
elif status["status"] == "FAILED":
|
||||
raise Exception(f"Training failed: {status}")
|
||||
|
||||
time.sleep(10)
|
||||
|
||||
|
||||
@dsl.pipeline(name="ade-classification-pipeline")
|
||||
def ade_classification_pipeline(
|
||||
runpod_api_key: str,
|
||||
runpod_endpoint: str = "k57do7afav01es",
|
||||
model_name: str = "emilyalsentzer/Bio_ClinicalBERT",
|
||||
max_samples: int = 10000,
|
||||
epochs: int = 3,
|
||||
batch_size: int = 16,
|
||||
s3_bucket: str = "",
|
||||
aws_access_key_id: str = "",
|
||||
aws_secret_access_key: str = "",
|
||||
aws_session_token: str = "",
|
||||
):
|
||||
"""
|
||||
Adverse Drug Event Classification Pipeline
|
||||
|
||||
Trains Bio_ClinicalBERT on ADE Corpus V2 (30K samples)
|
||||
Binary classification: ADE present / No ADE
|
||||
"""
|
||||
train_task = train_ade_model(
|
||||
runpod_api_key=runpod_api_key,
|
||||
runpod_endpoint=runpod_endpoint,
|
||||
model_name=model_name,
|
||||
max_samples=max_samples,
|
||||
epochs=epochs,
|
||||
batch_size=batch_size,
|
||||
s3_bucket=s3_bucket,
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_session_token=aws_session_token,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Medical Triage Classification Pipeline
|
||||
# ============================================================================
|
||||
@dsl.component(
|
||||
base_image="python:3.11-slim",
|
||||
packages_to_install=["requests"]
|
||||
)
|
||||
def train_triage_model(
|
||||
runpod_api_key: str,
|
||||
runpod_endpoint: str,
|
||||
model_name: str,
|
||||
max_samples: int,
|
||||
epochs: int,
|
||||
batch_size: int,
|
||||
s3_bucket: str,
|
||||
aws_access_key_id: str,
|
||||
aws_secret_access_key: str,
|
||||
aws_session_token: str,
|
||||
) -> dict:
|
||||
"""Train Medical Triage classifier on RunPod."""
|
||||
import requests
|
||||
import time
|
||||
|
||||
response = requests.post(
|
||||
f"https://api.runpod.ai/v2/{runpod_endpoint}/run",
|
||||
headers={"Authorization": f"Bearer {runpod_api_key}"},
|
||||
json={
|
||||
"input": {
|
||||
"task": "triage",
|
||||
"model_name": model_name,
|
||||
"max_samples": max_samples,
|
||||
"epochs": epochs,
|
||||
"batch_size": batch_size,
|
||||
"eval_split": 0.1,
|
||||
"s3_bucket": s3_bucket,
|
||||
"s3_prefix": "triage-models/bert",
|
||||
"aws_access_key_id": aws_access_key_id,
|
||||
"aws_secret_access_key": aws_secret_access_key,
|
||||
"aws_session_token": aws_session_token,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
job_id = response.json()["id"]
|
||||
print(f"RunPod job submitted: {job_id}")
|
||||
|
||||
while True:
|
||||
status = requests.get(
|
||||
f"https://api.runpod.ai/v2/{runpod_endpoint}/status/{job_id}",
|
||||
headers={"Authorization": f"Bearer {runpod_api_key}"}
|
||||
).json()
|
||||
|
||||
if status["status"] == "COMPLETED":
|
||||
return status["output"]
|
||||
elif status["status"] == "FAILED":
|
||||
raise Exception(f"Training failed: {status}")
|
||||
|
||||
time.sleep(10)
|
||||
|
||||
|
||||
@dsl.pipeline(name="triage-classification-pipeline")
|
||||
def triage_classification_pipeline(
|
||||
runpod_api_key: str,
|
||||
runpod_endpoint: str = "k57do7afav01es",
|
||||
model_name: str = "emilyalsentzer/Bio_ClinicalBERT",
|
||||
max_samples: int = 5000,
|
||||
epochs: int = 3,
|
||||
batch_size: int = 8,
|
||||
s3_bucket: str = "",
|
||||
aws_access_key_id: str = "",
|
||||
aws_secret_access_key: str = "",
|
||||
aws_session_token: str = "",
|
||||
):
|
||||
"""
|
||||
Medical Triage Classification Pipeline
|
||||
|
||||
Trains classifier for ER triage urgency levels.
|
||||
Multi-class: Emergency, Urgent, Standard, etc.
|
||||
"""
|
||||
train_task = train_triage_model(
|
||||
runpod_api_key=runpod_api_key,
|
||||
runpod_endpoint=runpod_endpoint,
|
||||
model_name=model_name,
|
||||
max_samples=max_samples,
|
||||
epochs=epochs,
|
||||
batch_size=batch_size,
|
||||
s3_bucket=s3_bucket,
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_session_token=aws_session_token,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Symptom-to-Disease Classification Pipeline
|
||||
# ============================================================================
|
||||
@dsl.component(
|
||||
base_image="python:3.11-slim",
|
||||
packages_to_install=["requests"]
|
||||
)
|
||||
def train_symptom_disease_model(
|
||||
runpod_api_key: str,
|
||||
runpod_endpoint: str,
|
||||
model_name: str,
|
||||
max_samples: int,
|
||||
epochs: int,
|
||||
batch_size: int,
|
||||
s3_bucket: str,
|
||||
aws_access_key_id: str,
|
||||
aws_secret_access_key: str,
|
||||
aws_session_token: str,
|
||||
) -> dict:
|
||||
"""Train Symptom-to-Disease classifier on RunPod."""
|
||||
import requests
|
||||
import time
|
||||
|
||||
response = requests.post(
|
||||
f"https://api.runpod.ai/v2/{runpod_endpoint}/run",
|
||||
headers={"Authorization": f"Bearer {runpod_api_key}"},
|
||||
json={
|
||||
"input": {
|
||||
"task": "symptom_disease",
|
||||
"model_name": model_name,
|
||||
"max_samples": max_samples,
|
||||
"epochs": epochs,
|
||||
"batch_size": batch_size,
|
||||
"eval_split": 0.1,
|
||||
"s3_bucket": s3_bucket,
|
||||
"s3_prefix": "symptom-disease-models/bert",
|
||||
"aws_access_key_id": aws_access_key_id,
|
||||
"aws_secret_access_key": aws_secret_access_key,
|
||||
"aws_session_token": aws_session_token,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
job_id = response.json()["id"]
|
||||
print(f"RunPod job submitted: {job_id}")
|
||||
|
||||
while True:
|
||||
status = requests.get(
|
||||
f"https://api.runpod.ai/v2/{runpod_endpoint}/status/{job_id}",
|
||||
headers={"Authorization": f"Bearer {runpod_api_key}"}
|
||||
).json()
|
||||
|
||||
if status["status"] == "COMPLETED":
|
||||
return status["output"]
|
||||
elif status["status"] == "FAILED":
|
||||
raise Exception(f"Training failed: {status}")
|
||||
|
||||
time.sleep(10)
|
||||
|
||||
|
||||
@dsl.pipeline(name="symptom-disease-classification-pipeline")
|
||||
def symptom_disease_pipeline(
|
||||
runpod_api_key: str,
|
||||
runpod_endpoint: str = "k57do7afav01es",
|
||||
model_name: str = "emilyalsentzer/Bio_ClinicalBERT",
|
||||
max_samples: int = 5000,
|
||||
epochs: int = 3,
|
||||
batch_size: int = 16,
|
||||
s3_bucket: str = "",
|
||||
aws_access_key_id: str = "",
|
||||
aws_secret_access_key: str = "",
|
||||
aws_session_token: str = "",
|
||||
):
|
||||
"""
|
||||
Symptom-to-Disease Classification Pipeline
|
||||
|
||||
Predicts disease from symptom descriptions.
|
||||
Multi-class: 40+ disease categories
|
||||
"""
|
||||
train_task = train_symptom_disease_model(
|
||||
runpod_api_key=runpod_api_key,
|
||||
runpod_endpoint=runpod_endpoint,
|
||||
model_name=model_name,
|
||||
max_samples=max_samples,
|
||||
epochs=epochs,
|
||||
batch_size=batch_size,
|
||||
s3_bucket=s3_bucket,
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_session_token=aws_session_token,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Full Healthcare Training Pipeline (All Tasks)
|
||||
# ============================================================================
|
||||
@dsl.pipeline(name="healthcare-multi-task-pipeline")
|
||||
def healthcare_multi_task_pipeline(
|
||||
runpod_api_key: str,
|
||||
runpod_endpoint: str = "k57do7afav01es",
|
||||
model_name: str = "emilyalsentzer/Bio_ClinicalBERT",
|
||||
s3_bucket: str = "",
|
||||
aws_access_key_id: str = "",
|
||||
aws_secret_access_key: str = "",
|
||||
aws_session_token: str = "",
|
||||
):
|
||||
"""
|
||||
Train all healthcare models in parallel.
|
||||
|
||||
Outputs:
|
||||
- ADE classifier (s3://bucket/ade-models/...)
|
||||
- Triage classifier (s3://bucket/triage-models/...)
|
||||
- Symptom-Disease classifier (s3://bucket/symptom-disease-models/...)
|
||||
"""
|
||||
# Run all training tasks in parallel
|
||||
ade_task = train_ade_model(
|
||||
runpod_api_key=runpod_api_key,
|
||||
runpod_endpoint=runpod_endpoint,
|
||||
model_name=model_name,
|
||||
max_samples=10000,
|
||||
epochs=3,
|
||||
batch_size=16,
|
||||
s3_bucket=s3_bucket,
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_session_token=aws_session_token,
|
||||
)
|
||||
|
||||
triage_task = train_triage_model(
|
||||
runpod_api_key=runpod_api_key,
|
||||
runpod_endpoint=runpod_endpoint,
|
||||
model_name=model_name,
|
||||
max_samples=5000,
|
||||
epochs=3,
|
||||
batch_size=8,
|
||||
s3_bucket=s3_bucket,
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_session_token=aws_session_token,
|
||||
)
|
||||
|
||||
symptom_task = train_symptom_disease_model(
|
||||
runpod_api_key=runpod_api_key,
|
||||
runpod_endpoint=runpod_endpoint,
|
||||
model_name=model_name,
|
||||
max_samples=5000,
|
||||
epochs=3,
|
||||
batch_size=16,
|
||||
s3_bucket=s3_bucket,
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_session_token=aws_session_token,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Compile pipelines
|
||||
compiler.Compiler().compile(
|
||||
ade_classification_pipeline,
|
||||
"ade_classification_pipeline.yaml"
|
||||
)
|
||||
compiler.Compiler().compile(
|
||||
triage_classification_pipeline,
|
||||
"triage_classification_pipeline.yaml"
|
||||
)
|
||||
compiler.Compiler().compile(
|
||||
symptom_disease_pipeline,
|
||||
"symptom_disease_pipeline.yaml"
|
||||
)
|
||||
compiler.Compiler().compile(
|
||||
healthcare_multi_task_pipeline,
|
||||
"healthcare_multi_task_pipeline.yaml"
|
||||
)
|
||||
print("All pipelines compiled!")
|
||||
Reference in New Issue
Block a user