mirror of
https://github.com/ghndrx/kubeflow-pipelines.git
synced 2026-02-10 06:45:13 +00:00
refactor: environment variable configuration for all pipeline settings
- Add config.py with dataclass-based configuration from env vars - Remove hardcoded RunPod endpoint and credentials - Consolidate duplicate training components into single reusable function - Add .env.example with all configurable options - Update README with environment variable documentation - Add Kubernetes secrets example for production deployments - Add timeout and error handling improvements BREAKING: Pipeline parameters now use env vars by default. Set RUNPOD_API_KEY, RUNPOD_ENDPOINT, S3_BUCKET, and AWS creds.
This commit is contained in:
37
.env.example
Normal file
37
.env.example
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
# =============================================================================
|
||||||
|
# Healthcare ML Pipeline Configuration
|
||||||
|
# =============================================================================
|
||||||
|
# Copy this file to .env and fill in your values.
|
||||||
|
# DO NOT commit .env to version control!
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# RunPod Configuration (Required)
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
RUNPOD_API_KEY=your_runpod_api_key_here
|
||||||
|
RUNPOD_ENDPOINT=your_endpoint_id_here
|
||||||
|
RUNPOD_API_BASE=https://api.runpod.ai/v2
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# AWS Configuration (Required for model storage)
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
AWS_ACCESS_KEY_ID=
|
||||||
|
AWS_SECRET_ACCESS_KEY=
|
||||||
|
AWS_SESSION_TOKEN= # Optional - for assumed role sessions
|
||||||
|
AWS_REGION=us-east-1
|
||||||
|
S3_BUCKET=your-model-bucket
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Model Training Defaults (Optional - sensible defaults provided)
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
BASE_MODEL=emilyalsentzer/Bio_ClinicalBERT
|
||||||
|
MAX_SAMPLES=10000
|
||||||
|
EPOCHS=3
|
||||||
|
BATCH_SIZE=16
|
||||||
|
EVAL_SPLIT=0.1
|
||||||
|
LEARNING_RATE=2e-5
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Pipeline Runtime Settings (Optional)
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
POLL_INTERVAL_SECONDS=10 # How often to check training status
|
||||||
|
TRAINING_TIMEOUT_SECONDS=3600 # Max training time (1 hour default)
|
||||||
39
README.md
39
README.md
@@ -74,6 +74,45 @@ tar -xzf model.tar.gz
|
|||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
|
All configuration is via environment variables. Copy `.env.example` to `.env` and fill in your values:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp .env.example .env
|
||||||
|
# Edit .env with your credentials
|
||||||
|
```
|
||||||
|
|
||||||
|
### Environment Variables
|
||||||
|
|
||||||
|
| Variable | Required | Default | Description |
|
||||||
|
|----------|----------|---------|-------------|
|
||||||
|
| `RUNPOD_API_KEY` | Yes | - | RunPod API key |
|
||||||
|
| `RUNPOD_ENDPOINT` | Yes | - | RunPod serverless endpoint ID |
|
||||||
|
| `AWS_ACCESS_KEY_ID` | Yes | - | AWS credentials for S3 |
|
||||||
|
| `AWS_SECRET_ACCESS_KEY` | Yes | - | AWS credentials for S3 |
|
||||||
|
| `AWS_SESSION_TOKEN` | No | - | For assumed role sessions |
|
||||||
|
| `AWS_REGION` | No | us-east-1 | AWS region |
|
||||||
|
| `S3_BUCKET` | Yes | - | Bucket for model artifacts |
|
||||||
|
| `BASE_MODEL` | No | Bio_ClinicalBERT | HuggingFace model ID |
|
||||||
|
| `MAX_SAMPLES` | No | 10000 | Training samples |
|
||||||
|
| `EPOCHS` | No | 3 | Training epochs |
|
||||||
|
| `BATCH_SIZE` | No | 16 | Batch size |
|
||||||
|
|
||||||
|
### Kubernetes Secrets (Recommended)
|
||||||
|
|
||||||
|
For production, use Kubernetes secrets instead of environment variables:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
apiVersion: v1
|
||||||
|
kind: Secret
|
||||||
|
metadata:
|
||||||
|
name: ml-pipeline-secrets
|
||||||
|
type: Opaque
|
||||||
|
stringData:
|
||||||
|
RUNPOD_API_KEY: "your-key"
|
||||||
|
AWS_ACCESS_KEY_ID: "your-key"
|
||||||
|
AWS_SECRET_ACCESS_KEY: "your-secret"
|
||||||
|
```
|
||||||
|
|
||||||
### Supported Models
|
### Supported Models
|
||||||
|
|
||||||
| Model | Type | Use Case |
|
| Model | Type | Use Case |
|
||||||
|
|||||||
93
pipelines/config.py
Normal file
93
pipelines/config.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
"""
|
||||||
|
Pipeline Configuration
|
||||||
|
|
||||||
|
All configuration loaded from environment variables with sensible defaults.
|
||||||
|
Secrets should be provided via Kubernetes secrets, not hardcoded.
|
||||||
|
"""
|
||||||
|
import os
from dataclasses import dataclass, field
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RunPodConfig:
    """RunPod API configuration.

    Each field is read from the environment via ``default_factory`` so the
    lookup happens at *instantiation* time, not at import time. The original
    ``os.getenv(...)`` class-level defaults were frozen when the module was
    first imported, which made ``PipelineConfig.from_env()`` ignore any env
    vars set after import (e.g. by a test harness or late dotenv load).
    """

    # RunPod API key; required for all API calls. Empty string when unset.
    api_key: str = field(default_factory=lambda: os.getenv("RUNPOD_API_KEY", ""))
    # Serverless endpoint ID that training jobs are submitted to.
    endpoint: str = field(default_factory=lambda: os.getenv("RUNPOD_ENDPOINT", ""))
    # Base URL of the RunPod REST API.
    api_base: str = field(
        default_factory=lambda: os.getenv("RUNPOD_API_BASE", "https://api.runpod.ai/v2")
    )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AWSConfig:
    """AWS credentials and settings.

    Fields use ``default_factory`` so environment variables are read when an
    instance is created, not once at import time (the original class-level
    ``os.getenv`` defaults captured a stale snapshot of the environment).
    """

    # Static credentials; empty strings when unset (handler may fall back
    # to the ambient credential chain in that case — confirm in the handler).
    access_key_id: str = field(default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID", ""))
    secret_access_key: str = field(default_factory=lambda: os.getenv("AWS_SECRET_ACCESS_KEY", ""))
    # Optional: only needed for assumed-role (STS) sessions.
    session_token: str = field(default_factory=lambda: os.getenv("AWS_SESSION_TOKEN", ""))
    region: str = field(default_factory=lambda: os.getenv("AWS_REGION", "us-east-1"))
    # Destination bucket/prefix for trained model artifacts.
    s3_bucket: str = field(default_factory=lambda: os.getenv("S3_BUCKET", ""))
    s3_prefix: str = field(default_factory=lambda: os.getenv("S3_PREFIX", "models"))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ModelConfig:
    """Model training defaults, overridable via environment variables.

    ``default_factory`` defers the env lookup (and the int/float parsing) to
    instantiation time; the original class-level ``os.getenv`` defaults were
    evaluated once at import and never refreshed.

    Note: malformed numeric env values (e.g. ``EPOCHS=abc``) raise
    ``ValueError`` at instantiation — intentionally loud rather than silent.
    """

    base_model: str = field(
        default_factory=lambda: os.getenv("BASE_MODEL", "emilyalsentzer/Bio_ClinicalBERT")
    )
    max_samples: int = field(default_factory=lambda: int(os.getenv("MAX_SAMPLES", "10000")))
    epochs: int = field(default_factory=lambda: int(os.getenv("EPOCHS", "3")))
    batch_size: int = field(default_factory=lambda: int(os.getenv("BATCH_SIZE", "16")))
    # Fraction of data held out for evaluation.
    eval_split: float = field(default_factory=lambda: float(os.getenv("EVAL_SPLIT", "0.1")))
    learning_rate: float = field(default_factory=lambda: float(os.getenv("LEARNING_RATE", "2e-5")))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class PipelineConfig:
    """Combined pipeline configuration.

    Aggregates the RunPod, AWS and model sub-configs plus runtime knobs.
    Runtime fields use ``default_factory`` so environment variables are read
    when the config object is built (the original class-level ``os.getenv``
    defaults were frozen at import time, defeating ``from_env()``).
    """

    runpod: RunPodConfig
    aws: AWSConfig
    model: ModelConfig

    # Pipeline runtime settings (seconds).
    poll_interval: int = field(
        default_factory=lambda: int(os.getenv("POLL_INTERVAL_SECONDS", "10"))
    )
    timeout: int = field(
        default_factory=lambda: int(os.getenv("TRAINING_TIMEOUT_SECONDS", "3600"))
    )

    @classmethod
    def from_env(cls) -> "PipelineConfig":
        """Build a fully-populated config from the current environment.

        Each sub-config reads its own env vars at construction time, so
        calling this again after the environment changes yields fresh values.
        """
        return cls(
            runpod=RunPodConfig(),
            aws=AWSConfig(),
            model=ModelConfig(),
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Task-specific defaults (can override base config)
|
||||||
|
# Task-specific defaults (can override base config). Keys are the task names
# understood by the RunPod training handler; each entry tunes sample count,
# batch size and the S3 prefix the trained model is written under.
TASK_DEFAULTS = {
    "ddi": {"max_samples": 10000, "batch_size": 16, "s3_prefix": "ddi-models"},
    "ade": {"max_samples": 10000, "batch_size": 16, "s3_prefix": "ade-models"},
    # Triage uses a smaller batch — presumably for longer input sequences;
    # confirm against the training handler before changing.
    "triage": {"max_samples": 5000, "batch_size": 8, "s3_prefix": "triage-models"},
    "symptom_disease": {
        "max_samples": 5000,
        "batch_size": 16,
        "s3_prefix": "symptom-disease-models",
    },
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_task_config(task: str, overrides: Optional[dict] = None) -> dict:
    """Return the per-task defaults merged with caller overrides.

    Unknown tasks yield an empty base dict. A fresh dict is returned each
    call so callers can mutate the result without touching TASK_DEFAULTS.
    """
    merged = dict(TASK_DEFAULTS.get(task, {}))
    if overrides:
        merged.update(overrides)
    return merged
|
||||||
@@ -8,369 +8,464 @@ Multi-task training pipelines for:
|
|||||||
- Drug-Drug Interaction (DDI) Classification
|
- Drug-Drug Interaction (DDI) Classification
|
||||||
|
|
||||||
All use RunPod serverless GPU infrastructure.
|
All use RunPod serverless GPU infrastructure.
|
||||||
|
Configuration via environment variables - see config.py for details.
|
||||||
|
|
||||||
|
Environment Variables:
|
||||||
|
RUNPOD_API_KEY - RunPod API key (required)
|
||||||
|
RUNPOD_ENDPOINT - RunPod serverless endpoint ID (required)
|
||||||
|
AWS_ACCESS_KEY_ID - AWS credentials for S3 upload
|
||||||
|
AWS_SECRET_ACCESS_KEY
|
||||||
|
AWS_SESSION_TOKEN - Optional session token for assumed roles
|
||||||
|
AWS_REGION - Default: us-east-1
|
||||||
|
S3_BUCKET - Bucket for model artifacts (required)
|
||||||
|
BASE_MODEL - HuggingFace model ID (default: Bio_ClinicalBERT)
|
||||||
|
MAX_SAMPLES - Training samples (default: 10000)
|
||||||
|
EPOCHS - Training epochs (default: 3)
|
||||||
|
BATCH_SIZE - Batch size (default: 16)
|
||||||
"""
|
"""
|
||||||
|
import os
|
||||||
from kfp import dsl
|
from kfp import dsl
|
||||||
from kfp import compiler
|
from kfp import compiler
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# =============================================================================
|
||||||
# ADE (Adverse Drug Event) Classification Pipeline
|
# Reusable Training Component
|
||||||
# ============================================================================
|
# =============================================================================
|
||||||
@dsl.component(
    base_image="python:3.11-slim",
    packages_to_install=["requests"]
)
def train_healthcare_model(
    task: str,
    runpod_api_key: str,
    runpod_endpoint: str,
    model_name: str,
    max_samples: int,
    epochs: int,
    batch_size: int,
    eval_split: float,
    s3_bucket: str,
    s3_prefix: str,
    aws_access_key_id: str,
    aws_secret_access_key: str,
    aws_session_token: str,
    aws_region: str,
    poll_interval: int,
    timeout: int,
) -> dict:
    """
    Generic healthcare model training component.

    Submits training job to RunPod serverless GPU and polls for completion.
    Trained model is uploaded to S3 by the RunPod handler.

    Args:
        task: Training task (ddi, ade, triage, symptom_disease)
        runpod_api_key: RunPod API key
        runpod_endpoint: RunPod serverless endpoint ID
        model_name: HuggingFace model ID
        max_samples: Maximum training samples
        epochs: Training epochs
        batch_size: Training batch size
        eval_split: Validation split ratio
        s3_bucket: S3 bucket for model output
        s3_prefix: S3 key prefix for this task
        aws_*: AWS credentials for S3 access
        poll_interval: Seconds between status checks
        timeout: Maximum training time in seconds

    Returns:
        Training output including metrics and S3 URI

    Raises:
        TimeoutError: if training exceeds `timeout` seconds.
        RuntimeError: if RunPod reports the job FAILED.
    """
    # All imports MUST be local: KFP serializes only this function body into
    # the component container, so module-level imports are unavailable here.
    # (Fix: `os` was used below but only imported at module level, which
    # would NameError at runtime inside the component.)
    import os
    import requests
    import time

    api_base = os.getenv("RUNPOD_API_BASE", "https://api.runpod.ai/v2")

    # Submit training job
    response = requests.post(
        f"{api_base}/{runpod_endpoint}/run",
        headers={"Authorization": f"Bearer {runpod_api_key}"},
        json={
            "input": {
                "task": task,
                "model_name": model_name,
                "max_samples": max_samples,
                "epochs": epochs,
                "batch_size": batch_size,
                "eval_split": eval_split,
                "s3_bucket": s3_bucket,
                "s3_prefix": s3_prefix,
                "aws_access_key_id": aws_access_key_id,
                "aws_secret_access_key": aws_secret_access_key,
                "aws_session_token": aws_session_token,
                "aws_region": aws_region,
            }
        },
        timeout=30,
    )
    response.raise_for_status()

    job_id = response.json()["id"]
    print(f"[{task}] RunPod job submitted: {job_id}")

    # Poll for completion, giving up once the wall-clock budget is spent.
    start_time = time.time()
    while True:
        elapsed = time.time() - start_time
        if elapsed > timeout:
            raise TimeoutError(f"Training exceeded timeout of {timeout}s")

        status_resp = requests.get(
            f"{api_base}/{runpod_endpoint}/status/{job_id}",
            headers={"Authorization": f"Bearer {runpod_api_key}"},
            timeout=30,
        )
        status_resp.raise_for_status()
        status = status_resp.json()

        if status["status"] == "COMPLETED":
            print(f"[{task}] Training completed in {elapsed:.0f}s")
            return status.get("output", {})
        elif status["status"] == "FAILED":
            error = status.get("error", "Unknown error")
            raise RuntimeError(f"[{task}] Training failed: {error}")
        elif status["status"] in ["IN_QUEUE", "IN_PROGRESS"]:
            print(f"[{task}] Status: {status['status']} ({elapsed:.0f}s elapsed)")

        time.sleep(poll_interval)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Pipeline Definitions
|
||||||
|
# =============================================================================
|
||||||
|
def _get_env(name: str, default: str = "") -> str:
|
||||||
|
"""Helper to get env var with default."""
|
||||||
|
return os.getenv(name, default)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_env_int(name: str, default: int) -> int:
|
||||||
|
"""Helper to get int env var with default."""
|
||||||
|
return int(os.getenv(name, str(default)))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_env_float(name: str, default: float) -> float:
|
||||||
|
"""Helper to get float env var with default."""
|
||||||
|
return float(os.getenv(name, str(default)))
|
||||||
|
|
||||||
|
|
||||||
@dsl.pipeline(name="ade-classification-pipeline")
def ade_classification_pipeline(
    # RunPod config - from env or override
    runpod_api_key: str = "",
    runpod_endpoint: str = "",
    # Model config
    model_name: str = "",
    max_samples: int = 10000,
    epochs: int = 3,
    batch_size: int = 16,
    eval_split: float = 0.1,
    # AWS config
    s3_bucket: str = "",
    aws_access_key_id: str = "",
    aws_secret_access_key: str = "",
    aws_session_token: str = "",
    aws_region: str = "us-east-1",
    # Runtime config
    poll_interval: int = 10,
    timeout: int = 3600,
):
    """
    Adverse Drug Event Classification Pipeline.

    Trains Bio_ClinicalBERT on ADE Corpus V2 (30K samples).
    Binary classification: ADE present / No ADE.

    All parameters can be provided via environment variables:
    - RUNPOD_API_KEY, RUNPOD_ENDPOINT
    - BASE_MODEL, MAX_SAMPLES, EPOCHS, BATCH_SIZE
    - S3_BUCKET, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, etc.
    """
    # NOTE(review): the `param or _get_env(...)` fallbacks are evaluated at
    # pipeline *compile* time on the compiling machine — verify this is the
    # intended place to resolve secrets for KFP parameter channels.
    resolved = {
        "runpod_api_key": runpod_api_key or _get_env("RUNPOD_API_KEY"),
        "runpod_endpoint": runpod_endpoint or _get_env("RUNPOD_ENDPOINT"),
        "model_name": model_name or _get_env("BASE_MODEL", "emilyalsentzer/Bio_ClinicalBERT"),
        "max_samples": max_samples,
        "epochs": epochs,
        "batch_size": batch_size,
        "eval_split": eval_split,
        "s3_bucket": s3_bucket or _get_env("S3_BUCKET"),
        "aws_access_key_id": aws_access_key_id or _get_env("AWS_ACCESS_KEY_ID"),
        "aws_secret_access_key": aws_secret_access_key or _get_env("AWS_SECRET_ACCESS_KEY"),
        "aws_session_token": aws_session_token or _get_env("AWS_SESSION_TOKEN"),
        "aws_region": aws_region,
        "poll_interval": poll_interval,
        "timeout": timeout,
    }
    train_healthcare_model(task="ade", s3_prefix="ade-models/bert", **resolved)
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Medical Triage Classification Pipeline
|
|
||||||
# ============================================================================
|
|
||||||
@dsl.component(
|
|
||||||
base_image="python:3.11-slim",
|
|
||||||
packages_to_install=["requests"]
|
|
||||||
)
|
|
||||||
def train_triage_model(
|
|
||||||
runpod_api_key: str,
|
|
||||||
runpod_endpoint: str,
|
|
||||||
model_name: str,
|
|
||||||
max_samples: int,
|
|
||||||
epochs: int,
|
|
||||||
batch_size: int,
|
|
||||||
s3_bucket: str,
|
|
||||||
aws_access_key_id: str,
|
|
||||||
aws_secret_access_key: str,
|
|
||||||
aws_session_token: str,
|
|
||||||
) -> dict:
|
|
||||||
"""Train Medical Triage classifier on RunPod."""
|
|
||||||
import requests
|
|
||||||
import time
|
|
||||||
|
|
||||||
response = requests.post(
|
|
||||||
f"https://api.runpod.ai/v2/{runpod_endpoint}/run",
|
|
||||||
headers={"Authorization": f"Bearer {runpod_api_key}"},
|
|
||||||
json={
|
|
||||||
"input": {
|
|
||||||
"task": "triage",
|
|
||||||
"model_name": model_name,
|
|
||||||
"max_samples": max_samples,
|
|
||||||
"epochs": epochs,
|
|
||||||
"batch_size": batch_size,
|
|
||||||
"eval_split": 0.1,
|
|
||||||
"s3_bucket": s3_bucket,
|
|
||||||
"s3_prefix": "triage-models/bert",
|
|
||||||
"aws_access_key_id": aws_access_key_id,
|
|
||||||
"aws_secret_access_key": aws_secret_access_key,
|
|
||||||
"aws_session_token": aws_session_token,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
job_id = response.json()["id"]
|
|
||||||
print(f"RunPod job submitted: {job_id}")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
status = requests.get(
|
|
||||||
f"https://api.runpod.ai/v2/{runpod_endpoint}/status/{job_id}",
|
|
||||||
headers={"Authorization": f"Bearer {runpod_api_key}"}
|
|
||||||
).json()
|
|
||||||
|
|
||||||
if status["status"] == "COMPLETED":
|
|
||||||
return status["output"]
|
|
||||||
elif status["status"] == "FAILED":
|
|
||||||
raise Exception(f"Training failed: {status}")
|
|
||||||
|
|
||||||
time.sleep(10)
|
|
||||||
|
|
||||||
|
|
||||||
@dsl.pipeline(name="triage-classification-pipeline")
def triage_classification_pipeline(
    runpod_api_key: str = "",
    runpod_endpoint: str = "",
    model_name: str = "",
    max_samples: int = 5000,
    epochs: int = 3,
    batch_size: int = 8,
    eval_split: float = 0.1,
    s3_bucket: str = "",
    aws_access_key_id: str = "",
    aws_secret_access_key: str = "",
    aws_session_token: str = "",
    aws_region: str = "us-east-1",
    poll_interval: int = 10,
    timeout: int = 3600,
):
    """
    Medical Triage Classification Pipeline.

    Trains classifier for ER triage urgency levels.
    Multi-class: Emergency, Urgent, Standard, Non-urgent.

    Parameters left empty fall back to environment variables
    (RUNPOD_API_KEY, RUNPOD_ENDPOINT, BASE_MODEL, S3_BUCKET, AWS_*).
    """
    # Resolve empty params from the environment at compile time.
    resolved = {
        "runpod_api_key": runpod_api_key or _get_env("RUNPOD_API_KEY"),
        "runpod_endpoint": runpod_endpoint or _get_env("RUNPOD_ENDPOINT"),
        "model_name": model_name or _get_env("BASE_MODEL", "emilyalsentzer/Bio_ClinicalBERT"),
        "max_samples": max_samples,
        "epochs": epochs,
        "batch_size": batch_size,
        "eval_split": eval_split,
        "s3_bucket": s3_bucket or _get_env("S3_BUCKET"),
        "aws_access_key_id": aws_access_key_id or _get_env("AWS_ACCESS_KEY_ID"),
        "aws_secret_access_key": aws_secret_access_key or _get_env("AWS_SECRET_ACCESS_KEY"),
        "aws_session_token": aws_session_token or _get_env("AWS_SESSION_TOKEN"),
        "aws_region": aws_region,
        "poll_interval": poll_interval,
        "timeout": timeout,
    }
    train_healthcare_model(task="triage", s3_prefix="triage-models/bert", **resolved)
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Symptom-to-Disease Classification Pipeline
|
|
||||||
# ============================================================================
|
|
||||||
@dsl.component(
|
|
||||||
base_image="python:3.11-slim",
|
|
||||||
packages_to_install=["requests"]
|
|
||||||
)
|
|
||||||
def train_symptom_disease_model(
|
|
||||||
runpod_api_key: str,
|
|
||||||
runpod_endpoint: str,
|
|
||||||
model_name: str,
|
|
||||||
max_samples: int,
|
|
||||||
epochs: int,
|
|
||||||
batch_size: int,
|
|
||||||
s3_bucket: str,
|
|
||||||
aws_access_key_id: str,
|
|
||||||
aws_secret_access_key: str,
|
|
||||||
aws_session_token: str,
|
|
||||||
) -> dict:
|
|
||||||
"""Train Symptom-to-Disease classifier on RunPod."""
|
|
||||||
import requests
|
|
||||||
import time
|
|
||||||
|
|
||||||
response = requests.post(
|
|
||||||
f"https://api.runpod.ai/v2/{runpod_endpoint}/run",
|
|
||||||
headers={"Authorization": f"Bearer {runpod_api_key}"},
|
|
||||||
json={
|
|
||||||
"input": {
|
|
||||||
"task": "symptom_disease",
|
|
||||||
"model_name": model_name,
|
|
||||||
"max_samples": max_samples,
|
|
||||||
"epochs": epochs,
|
|
||||||
"batch_size": batch_size,
|
|
||||||
"eval_split": 0.1,
|
|
||||||
"s3_bucket": s3_bucket,
|
|
||||||
"s3_prefix": "symptom-disease-models/bert",
|
|
||||||
"aws_access_key_id": aws_access_key_id,
|
|
||||||
"aws_secret_access_key": aws_secret_access_key,
|
|
||||||
"aws_session_token": aws_session_token,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
job_id = response.json()["id"]
|
|
||||||
print(f"RunPod job submitted: {job_id}")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
status = requests.get(
|
|
||||||
f"https://api.runpod.ai/v2/{runpod_endpoint}/status/{job_id}",
|
|
||||||
headers={"Authorization": f"Bearer {runpod_api_key}"}
|
|
||||||
).json()
|
|
||||||
|
|
||||||
if status["status"] == "COMPLETED":
|
|
||||||
return status["output"]
|
|
||||||
elif status["status"] == "FAILED":
|
|
||||||
raise Exception(f"Training failed: {status}")
|
|
||||||
|
|
||||||
time.sleep(10)
|
|
||||||
|
|
||||||
|
|
||||||
@dsl.pipeline(name="symptom-disease-classification-pipeline")
def symptom_disease_pipeline(
    runpod_api_key: str = "",
    runpod_endpoint: str = "",
    model_name: str = "",
    max_samples: int = 5000,
    epochs: int = 3,
    batch_size: int = 16,
    eval_split: float = 0.1,
    s3_bucket: str = "",
    aws_access_key_id: str = "",
    aws_secret_access_key: str = "",
    aws_session_token: str = "",
    aws_region: str = "us-east-1",
    poll_interval: int = 10,
    timeout: int = 3600,
):
    """
    Symptom-to-Disease Classification Pipeline.

    Predicts disease from symptom descriptions.
    Multi-class: 41 disease categories.

    Parameters left empty fall back to environment variables
    (RUNPOD_API_KEY, RUNPOD_ENDPOINT, BASE_MODEL, S3_BUCKET, AWS_*).
    """
    # Resolve empty params from the environment at compile time.
    resolved = {
        "runpod_api_key": runpod_api_key or _get_env("RUNPOD_API_KEY"),
        "runpod_endpoint": runpod_endpoint or _get_env("RUNPOD_ENDPOINT"),
        "model_name": model_name or _get_env("BASE_MODEL", "emilyalsentzer/Bio_ClinicalBERT"),
        "max_samples": max_samples,
        "epochs": epochs,
        "batch_size": batch_size,
        "eval_split": eval_split,
        "s3_bucket": s3_bucket or _get_env("S3_BUCKET"),
        "aws_access_key_id": aws_access_key_id or _get_env("AWS_ACCESS_KEY_ID"),
        "aws_secret_access_key": aws_secret_access_key or _get_env("AWS_SECRET_ACCESS_KEY"),
        "aws_session_token": aws_session_token or _get_env("AWS_SESSION_TOKEN"),
        "aws_region": aws_region,
        "poll_interval": poll_interval,
        "timeout": timeout,
    }
    train_healthcare_model(
        task="symptom_disease",
        s3_prefix="symptom-disease-models/bert",
        **resolved,
    )
|
||||||
|
|
||||||
|
|
||||||
@dsl.pipeline(name="ddi-classification-pipeline")
def ddi_classification_pipeline(
    runpod_api_key: str = "",
    runpod_endpoint: str = "",
    model_name: str = "",
    max_samples: int = 10000,
    epochs: int = 3,
    batch_size: int = 16,
    eval_split: float = 0.1,
    s3_bucket: str = "",
    aws_access_key_id: str = "",
    aws_secret_access_key: str = "",
    aws_session_token: str = "",
    aws_region: str = "us-east-1",
    poll_interval: int = 10,
    timeout: int = 3600,
):
    """
    Drug-Drug Interaction Classification Pipeline.

    Trains on 176K DrugBank DDI samples.
    Multi-class severity: Minor, Moderate, Major, Contraindicated.

    Parameters left empty fall back to environment variables
    (RUNPOD_API_KEY, RUNPOD_ENDPOINT, BASE_MODEL, S3_BUCKET, AWS_*).
    """
    # Resolve empty params from the environment at compile time.
    resolved = {
        "runpod_api_key": runpod_api_key or _get_env("RUNPOD_API_KEY"),
        "runpod_endpoint": runpod_endpoint or _get_env("RUNPOD_ENDPOINT"),
        "model_name": model_name or _get_env("BASE_MODEL", "emilyalsentzer/Bio_ClinicalBERT"),
        "max_samples": max_samples,
        "epochs": epochs,
        "batch_size": batch_size,
        "eval_split": eval_split,
        "s3_bucket": s3_bucket or _get_env("S3_BUCKET"),
        "aws_access_key_id": aws_access_key_id or _get_env("AWS_ACCESS_KEY_ID"),
        "aws_secret_access_key": aws_secret_access_key or _get_env("AWS_SECRET_ACCESS_KEY"),
        "aws_session_token": aws_session_token or _get_env("AWS_SESSION_TOKEN"),
        "aws_region": aws_region,
        "poll_interval": poll_interval,
        "timeout": timeout,
    }
    train_healthcare_model(task="ddi", s3_prefix="ddi-models/bert", **resolved)
||||||
|
|
||||||
|
|
||||||
|
@dsl.pipeline(name="healthcare-multi-task-pipeline")
def healthcare_multi_task_pipeline(
    runpod_api_key: str = "",
    runpod_endpoint: str = "",
    model_name: str = "",
    s3_bucket: str = "",
    aws_access_key_id: str = "",
    aws_secret_access_key: str = "",
    aws_session_token: str = "",
    aws_region: str = "us-east-1",
    poll_interval: int = 10,
    timeout: int = 3600,
):
    """
    Train all healthcare models in parallel.

    Outputs:
    - DDI classifier (s3://bucket/ddi-models/...)
    - ADE classifier (s3://bucket/ade-models/...)
    - Triage classifier (s3://bucket/triage-models/...)
    - Symptom-Disease classifier (s3://bucket/symptom-disease-models/...)

    Empty-string parameters fall back to environment variables via
    ``_get_env`` (RUNPOD_API_KEY, RUNPOD_ENDPOINT, BASE_MODEL, S3_BUCKET,
    AWS_*).
    """
    # Resolve env-var fallbacks once so all four tasks share identical config.
    _runpod_key = runpod_api_key or _get_env("RUNPOD_API_KEY")
    _runpod_endpoint = runpod_endpoint or _get_env("RUNPOD_ENDPOINT")
    _model = model_name or _get_env("BASE_MODEL", "emilyalsentzer/Bio_ClinicalBERT")
    _bucket = s3_bucket or _get_env("S3_BUCKET")
    _aws_key = aws_access_key_id or _get_env("AWS_ACCESS_KEY_ID")
    _aws_secret = aws_secret_access_key or _get_env("AWS_SECRET_ACCESS_KEY")
    _aws_token = aws_session_token or _get_env("AWS_SESSION_TOKEN")

    # Run all training tasks in parallel (no dependencies between them).
    # Task handles are kept so downstream steps could chain with .after(...).
    ddi_task = train_healthcare_model(
        task="ddi",
        runpod_api_key=_runpod_key,
        runpod_endpoint=_runpod_endpoint,
        model_name=_model,
        max_samples=10000,
        epochs=3,
        batch_size=16,
        eval_split=0.1,
        s3_bucket=_bucket,
        s3_prefix="ddi-models/bert",
        aws_access_key_id=_aws_key,
        aws_secret_access_key=_aws_secret,
        aws_session_token=_aws_token,
        aws_region=aws_region,
        poll_interval=poll_interval,
        timeout=timeout,
    )

    ade_task = train_healthcare_model(
        task="ade",
        runpod_api_key=_runpod_key,
        runpod_endpoint=_runpod_endpoint,
        model_name=_model,
        max_samples=10000,
        epochs=3,
        batch_size=16,
        eval_split=0.1,
        s3_bucket=_bucket,
        s3_prefix="ade-models/bert",
        aws_access_key_id=_aws_key,
        aws_secret_access_key=_aws_secret,
        aws_session_token=_aws_token,
        aws_region=aws_region,
        poll_interval=poll_interval,
        timeout=timeout,
    )

    # Triage uses a smaller batch size (8) — presumably longer sequences;
    # TODO(review): confirm against the component's resource profile.
    triage_task = train_healthcare_model(
        task="triage",
        runpod_api_key=_runpod_key,
        runpod_endpoint=_runpod_endpoint,
        model_name=_model,
        max_samples=5000,
        epochs=3,
        batch_size=8,
        eval_split=0.1,
        s3_bucket=_bucket,
        s3_prefix="triage-models/bert",
        aws_access_key_id=_aws_key,
        aws_secret_access_key=_aws_secret,
        aws_session_token=_aws_token,
        aws_region=aws_region,
        poll_interval=poll_interval,
        timeout=timeout,
    )

    symptom_task = train_healthcare_model(
        task="symptom_disease",
        runpod_api_key=_runpod_key,
        runpod_endpoint=_runpod_endpoint,
        model_name=_model,
        max_samples=5000,
        epochs=3,
        batch_size=16,
        eval_split=0.1,
        s3_bucket=_bucket,
        s3_prefix="symptom-disease-models/bert",
        aws_access_key_id=_aws_key,
        aws_secret_access_key=_aws_secret,
        aws_session_token=_aws_token,
        aws_region=aws_region,
        poll_interval=poll_interval,
        timeout=timeout,
    )

# =============================================================================
# Compile Pipelines
# =============================================================================
if __name__ == "__main__":
    import argparse
    import os

    parser = argparse.ArgumentParser(description="Compile Kubeflow pipelines")
    parser.add_argument(
        "--output-dir", default=".", help="Output directory for compiled YAML"
    )
    args = parser.parse_args()

    # Robustness: make sure the target directory exists before compiling.
    os.makedirs(args.output_dir, exist_ok=True)

    pipelines = [
        (ade_classification_pipeline, "ade_classification_pipeline.yaml"),
        (triage_classification_pipeline, "triage_classification_pipeline.yaml"),
        (symptom_disease_pipeline, "symptom_disease_pipeline.yaml"),
        (ddi_classification_pipeline, "ddi_classification_pipeline.yaml"),
        (healthcare_multi_task_pipeline, "healthcare_multi_task_pipeline.yaml"),
    ]

    for pipeline_func, filename in pipelines:
        output_path = os.path.join(args.output_dir, filename)
        compiler.Compiler().compile(pipeline_func, output_path)
        print(f"Compiled: {output_path}")

    print(f"\n✓ All {len(pipelines)} pipelines compiled!")