kubeflow-pipelines/pipelines/healthcare_training.py
Greg Hendrickson 5f554ea769 refactor: environment variable configuration for all pipeline settings
- Add config.py with dataclass-based configuration from env vars
- Remove hardcoded RunPod endpoint and credentials
- Consolidate duplicate training components into single reusable function
- Add .env.example with all configurable options
- Update README with environment variable documentation
- Add Kubernetes secrets example for production deployments
- Add timeout and error handling improvements

BREAKING: Pipeline parameters now use env vars by default.
Set RUNPOD_API_KEY, RUNPOD_ENDPOINT, S3_BUCKET, and AWS creds.
2026-02-03 20:47:27 +00:00


"""
Healthcare ML Training Pipelines
Multi-task training pipelines for:
- Adverse Drug Event (ADE) Classification
- Medical Triage Classification
- Symptom-to-Disease Prediction
- Drug-Drug Interaction (DDI) Classification
All use RunPod serverless GPU infrastructure.
Configuration via environment variables - see config.py for details.
Environment Variables:
RUNPOD_API_KEY - RunPod API key (required)
RUNPOD_ENDPOINT - RunPod serverless endpoint ID (required)
AWS_ACCESS_KEY_ID - AWS credentials for S3 upload
AWS_SECRET_ACCESS_KEY
AWS_SESSION_TOKEN - Optional session token for assumed roles
AWS_REGION - Default: us-east-1
S3_BUCKET - Bucket for model artifacts (required)
BASE_MODEL - HuggingFace model ID (default: Bio_ClinicalBERT)
MAX_SAMPLES - Training samples (default: 10000)
EPOCHS - Training epochs (default: 3)
BATCH_SIZE - Batch size (default: 16)
"""

import os

from kfp import dsl
from kfp import compiler


# =============================================================================
# Reusable Training Component
# =============================================================================
@dsl.component(
    base_image="python:3.11-slim",
    packages_to_install=["requests"],
)
def train_healthcare_model(
    task: str,
    runpod_api_key: str,
    runpod_endpoint: str,
    model_name: str,
    max_samples: int,
    epochs: int,
    batch_size: int,
    eval_split: float,
    s3_bucket: str,
    s3_prefix: str,
    aws_access_key_id: str,
    aws_secret_access_key: str,
    aws_session_token: str,
    aws_region: str,
    poll_interval: int,
    timeout: int,
) -> dict:
"""
Generic healthcare model training component.
Submits training job to RunPod serverless GPU and polls for completion.
Trained model is uploaded to S3 by the RunPod handler.
Args:
task: Training task (ddi, ade, triage, symptom_disease)
runpod_api_key: RunPod API key
runpod_endpoint: RunPod serverless endpoint ID
model_name: HuggingFace model ID
max_samples: Maximum training samples
epochs: Training epochs
batch_size: Training batch size
eval_split: Validation split ratio
s3_bucket: S3 bucket for model output
s3_prefix: S3 key prefix for this task
aws_*: AWS credentials for S3 access
poll_interval: Seconds between status checks
timeout: Maximum training time in seconds
Returns:
Training output including metrics and S3 URI
"""
    # All imports must live inside the function body: KFP lightweight
    # components execute this function in an isolated container where
    # module-level imports (including `os`) are not available.
    import os
    import time

    import requests

    api_base = os.getenv("RUNPOD_API_BASE", "https://api.runpod.ai/v2")

    # Submit training job
    response = requests.post(
        f"{api_base}/{runpod_endpoint}/run",
        headers={"Authorization": f"Bearer {runpod_api_key}"},
        json={
            "input": {
                "task": task,
                "model_name": model_name,
                "max_samples": max_samples,
                "epochs": epochs,
                "batch_size": batch_size,
                "eval_split": eval_split,
                "s3_bucket": s3_bucket,
                "s3_prefix": s3_prefix,
                "aws_access_key_id": aws_access_key_id,
                "aws_secret_access_key": aws_secret_access_key,
                "aws_session_token": aws_session_token,
                "aws_region": aws_region,
            }
        },
        timeout=30,
    )
    response.raise_for_status()
    job_id = response.json()["id"]
    print(f"[{task}] RunPod job submitted: {job_id}")

    # Poll for completion
    start_time = time.time()
    while True:
        elapsed = time.time() - start_time
        if elapsed > timeout:
            raise TimeoutError(f"Training exceeded timeout of {timeout}s")

        status_resp = requests.get(
            f"{api_base}/{runpod_endpoint}/status/{job_id}",
            headers={"Authorization": f"Bearer {runpod_api_key}"},
            timeout=30,
        )
        status_resp.raise_for_status()
        status = status_resp.json()

        if status["status"] == "COMPLETED":
            print(f"[{task}] Training completed in {elapsed:.0f}s")
            return status.get("output", {})
        elif status["status"] == "FAILED":
            error = status.get("error", "Unknown error")
            raise RuntimeError(f"[{task}] Training failed: {error}")
        elif status["status"] in ["IN_QUEUE", "IN_PROGRESS"]:
            print(f"[{task}] Status: {status['status']} ({elapsed:.0f}s elapsed)")

        time.sleep(poll_interval)
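
# Illustrative sketch of the status payloads the poll loop above consumes,
# inferred only from the fields it reads (not an exhaustive RunPod schema):
#
#   {"id": "<job-id>", "status": "IN_PROGRESS"}
#   {"id": "<job-id>", "status": "COMPLETED", "output": {...}}   # returned as-is
#   {"id": "<job-id>", "status": "FAILED", "error": "<message>"}
#
# The keys inside "output" (e.g. metrics, S3 URI) are set by the RunPod
# handler and are not defined in this file.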


# =============================================================================
# Pipeline Definitions
# =============================================================================
def _get_env(name: str, default: str = "") -> str:
    """Helper to get env var with default."""
    return os.getenv(name, default)


def _get_env_int(name: str, default: int) -> int:
    """Helper to get int env var with default."""
    return int(os.getenv(name, str(default)))


def _get_env_float(name: str, default: float) -> float:
    """Helper to get float env var with default."""
    return float(os.getenv(name, str(default)))
@dsl.pipeline(name="ade-classification-pipeline")
def ade_classification_pipeline(
# RunPod config - from env or override
runpod_api_key: str = "",
runpod_endpoint: str = "",
# Model config
model_name: str = "",
max_samples: int = 10000,
epochs: int = 3,
batch_size: int = 16,
eval_split: float = 0.1,
# AWS config
s3_bucket: str = "",
aws_access_key_id: str = "",
aws_secret_access_key: str = "",
aws_session_token: str = "",
aws_region: str = "us-east-1",
# Runtime config
poll_interval: int = 10,
timeout: int = 3600,
):
"""
Adverse Drug Event Classification Pipeline
Trains Bio_ClinicalBERT on ADE Corpus V2 (30K samples).
Binary classification: ADE present / No ADE.
All parameters can be provided via environment variables:
- RUNPOD_API_KEY, RUNPOD_ENDPOINT
- BASE_MODEL, MAX_SAMPLES, EPOCHS, BATCH_SIZE
- S3_BUCKET, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, etc.
"""
train_healthcare_model(
task="ade",
runpod_api_key=runpod_api_key or _get_env("RUNPOD_API_KEY"),
runpod_endpoint=runpod_endpoint or _get_env("RUNPOD_ENDPOINT"),
model_name=model_name or _get_env("BASE_MODEL", "emilyalsentzer/Bio_ClinicalBERT"),
max_samples=max_samples,
epochs=epochs,
batch_size=batch_size,
eval_split=eval_split,
s3_bucket=s3_bucket or _get_env("S3_BUCKET"),
s3_prefix="ade-models/bert",
aws_access_key_id=aws_access_key_id or _get_env("AWS_ACCESS_KEY_ID"),
aws_secret_access_key=aws_secret_access_key or _get_env("AWS_SECRET_ACCESS_KEY"),
aws_session_token=aws_session_token or _get_env("AWS_SESSION_TOKEN"),
aws_region=aws_region,
poll_interval=poll_interval,
timeout=timeout,
)
@dsl.pipeline(name="triage-classification-pipeline")
def triage_classification_pipeline(
runpod_api_key: str = "",
runpod_endpoint: str = "",
model_name: str = "",
max_samples: int = 5000,
epochs: int = 3,
batch_size: int = 8,
eval_split: float = 0.1,
s3_bucket: str = "",
aws_access_key_id: str = "",
aws_secret_access_key: str = "",
aws_session_token: str = "",
aws_region: str = "us-east-1",
poll_interval: int = 10,
timeout: int = 3600,
):
"""
Medical Triage Classification Pipeline
Trains classifier for ER triage urgency levels.
Multi-class: Emergency, Urgent, Standard, Non-urgent.
"""
train_healthcare_model(
task="triage",
runpod_api_key=runpod_api_key or _get_env("RUNPOD_API_KEY"),
runpod_endpoint=runpod_endpoint or _get_env("RUNPOD_ENDPOINT"),
model_name=model_name or _get_env("BASE_MODEL", "emilyalsentzer/Bio_ClinicalBERT"),
max_samples=max_samples,
epochs=epochs,
batch_size=batch_size,
eval_split=eval_split,
s3_bucket=s3_bucket or _get_env("S3_BUCKET"),
s3_prefix="triage-models/bert",
aws_access_key_id=aws_access_key_id or _get_env("AWS_ACCESS_KEY_ID"),
aws_secret_access_key=aws_secret_access_key or _get_env("AWS_SECRET_ACCESS_KEY"),
aws_session_token=aws_session_token or _get_env("AWS_SESSION_TOKEN"),
aws_region=aws_region,
poll_interval=poll_interval,
timeout=timeout,
)
@dsl.pipeline(name="symptom-disease-classification-pipeline")
def symptom_disease_pipeline(
runpod_api_key: str = "",
runpod_endpoint: str = "",
model_name: str = "",
max_samples: int = 5000,
epochs: int = 3,
batch_size: int = 16,
eval_split: float = 0.1,
s3_bucket: str = "",
aws_access_key_id: str = "",
aws_secret_access_key: str = "",
aws_session_token: str = "",
aws_region: str = "us-east-1",
poll_interval: int = 10,
timeout: int = 3600,
):
"""
Symptom-to-Disease Classification Pipeline
Predicts disease from symptom descriptions.
Multi-class: 41 disease categories.
"""
train_healthcare_model(
task="symptom_disease",
runpod_api_key=runpod_api_key or _get_env("RUNPOD_API_KEY"),
runpod_endpoint=runpod_endpoint or _get_env("RUNPOD_ENDPOINT"),
model_name=model_name or _get_env("BASE_MODEL", "emilyalsentzer/Bio_ClinicalBERT"),
max_samples=max_samples,
epochs=epochs,
batch_size=batch_size,
eval_split=eval_split,
s3_bucket=s3_bucket or _get_env("S3_BUCKET"),
s3_prefix="symptom-disease-models/bert",
aws_access_key_id=aws_access_key_id or _get_env("AWS_ACCESS_KEY_ID"),
aws_secret_access_key=aws_secret_access_key or _get_env("AWS_SECRET_ACCESS_KEY"),
aws_session_token=aws_session_token or _get_env("AWS_SESSION_TOKEN"),
aws_region=aws_region,
poll_interval=poll_interval,
timeout=timeout,
)
@dsl.pipeline(name="ddi-classification-pipeline")
def ddi_classification_pipeline(
runpod_api_key: str = "",
runpod_endpoint: str = "",
model_name: str = "",
max_samples: int = 10000,
epochs: int = 3,
batch_size: int = 16,
eval_split: float = 0.1,
s3_bucket: str = "",
aws_access_key_id: str = "",
aws_secret_access_key: str = "",
aws_session_token: str = "",
aws_region: str = "us-east-1",
poll_interval: int = 10,
timeout: int = 3600,
):
"""
Drug-Drug Interaction Classification Pipeline
Trains on 176K DrugBank DDI samples.
Multi-class severity: Minor, Moderate, Major, Contraindicated.
"""
train_healthcare_model(
task="ddi",
runpod_api_key=runpod_api_key or _get_env("RUNPOD_API_KEY"),
runpod_endpoint=runpod_endpoint or _get_env("RUNPOD_ENDPOINT"),
model_name=model_name or _get_env("BASE_MODEL", "emilyalsentzer/Bio_ClinicalBERT"),
max_samples=max_samples,
epochs=epochs,
batch_size=batch_size,
eval_split=eval_split,
s3_bucket=s3_bucket or _get_env("S3_BUCKET"),
s3_prefix="ddi-models/bert",
aws_access_key_id=aws_access_key_id or _get_env("AWS_ACCESS_KEY_ID"),
aws_secret_access_key=aws_secret_access_key or _get_env("AWS_SECRET_ACCESS_KEY"),
aws_session_token=aws_session_token or _get_env("AWS_SESSION_TOKEN"),
aws_region=aws_region,
poll_interval=poll_interval,
timeout=timeout,
)
@dsl.pipeline(name="healthcare-multi-task-pipeline")
def healthcare_multi_task_pipeline(
runpod_api_key: str = "",
runpod_endpoint: str = "",
model_name: str = "",
s3_bucket: str = "",
aws_access_key_id: str = "",
aws_secret_access_key: str = "",
aws_session_token: str = "",
aws_region: str = "us-east-1",
poll_interval: int = 10,
timeout: int = 3600,
):
"""
Train all healthcare models in parallel.
Outputs:
- DDI classifier (s3://bucket/ddi-models/...)
- ADE classifier (s3://bucket/ade-models/...)
- Triage classifier (s3://bucket/triage-models/...)
- Symptom-Disease classifier (s3://bucket/symptom-disease-models/...)
"""
# Resolve env vars once
_runpod_key = runpod_api_key or _get_env("RUNPOD_API_KEY")
_runpod_endpoint = runpod_endpoint or _get_env("RUNPOD_ENDPOINT")
_model = model_name or _get_env("BASE_MODEL", "emilyalsentzer/Bio_ClinicalBERT")
_bucket = s3_bucket or _get_env("S3_BUCKET")
_aws_key = aws_access_key_id or _get_env("AWS_ACCESS_KEY_ID")
_aws_secret = aws_secret_access_key or _get_env("AWS_SECRET_ACCESS_KEY")
_aws_token = aws_session_token or _get_env("AWS_SESSION_TOKEN")
# Run all training tasks in parallel (no dependencies between them)
ddi_task = train_healthcare_model(
task="ddi",
runpod_api_key=_runpod_key,
runpod_endpoint=_runpod_endpoint,
model_name=_model,
max_samples=10000,
epochs=3,
batch_size=16,
eval_split=0.1,
s3_bucket=_bucket,
s3_prefix="ddi-models/bert",
aws_access_key_id=_aws_key,
aws_secret_access_key=_aws_secret,
aws_session_token=_aws_token,
aws_region=aws_region,
poll_interval=poll_interval,
timeout=timeout,
)
ade_task = train_healthcare_model(
task="ade",
runpod_api_key=_runpod_key,
runpod_endpoint=_runpod_endpoint,
model_name=_model,
max_samples=10000,
epochs=3,
batch_size=16,
eval_split=0.1,
s3_bucket=_bucket,
s3_prefix="ade-models/bert",
aws_access_key_id=_aws_key,
aws_secret_access_key=_aws_secret,
aws_session_token=_aws_token,
aws_region=aws_region,
poll_interval=poll_interval,
timeout=timeout,
)
triage_task = train_healthcare_model(
task="triage",
runpod_api_key=_runpod_key,
runpod_endpoint=_runpod_endpoint,
model_name=_model,
max_samples=5000,
epochs=3,
batch_size=8,
eval_split=0.1,
s3_bucket=_bucket,
s3_prefix="triage-models/bert",
aws_access_key_id=_aws_key,
aws_secret_access_key=_aws_secret,
aws_session_token=_aws_token,
aws_region=aws_region,
poll_interval=poll_interval,
timeout=timeout,
)
symptom_task = train_healthcare_model(
task="symptom_disease",
runpod_api_key=_runpod_key,
runpod_endpoint=_runpod_endpoint,
model_name=_model,
max_samples=5000,
epochs=3,
batch_size=16,
eval_split=0.1,
s3_bucket=_bucket,
s3_prefix="symptom-disease-models/bert",
aws_access_key_id=_aws_key,
aws_secret_access_key=_aws_secret,
aws_session_token=_aws_token,
aws_region=aws_region,
poll_interval=poll_interval,
timeout=timeout,
)


# =============================================================================
# Compile Pipelines
# =============================================================================
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Compile Kubeflow pipelines")
    parser.add_argument("--output-dir", default=".", help="Output directory for compiled YAML")
    args = parser.parse_args()

    pipelines = [
        (ade_classification_pipeline, "ade_classification_pipeline.yaml"),
        (triage_classification_pipeline, "triage_classification_pipeline.yaml"),
        (symptom_disease_pipeline, "symptom_disease_pipeline.yaml"),
        (ddi_classification_pipeline, "ddi_classification_pipeline.yaml"),
        (healthcare_multi_task_pipeline, "healthcare_multi_task_pipeline.yaml"),
    ]

    for pipeline_func, filename in pipelines:
        output_path = os.path.join(args.output_dir, filename)
        compiler.Compiler().compile(pipeline_func, output_path)
        print(f"Compiled: {output_path}")

    print(f"\n✓ All {len(pipelines)} pipelines compiled!")