refactor: environment variable configuration for all pipeline settings

- Add config.py with dataclass-based configuration from env vars
- Remove hardcoded RunPod endpoint and credentials
- Consolidate the duplicated per-task training components into one reusable component
- Add .env.example with all configurable options
- Update README with environment variable documentation
- Add Kubernetes secrets example for production deployments
- Improve timeout and error handling during job polling

BREAKING: Pipeline parameters now fall back to environment variables by default.
Set RUNPOD_API_KEY, RUNPOD_ENDPOINT, S3_BUCKET, and the AWS credentials.
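For example, a fail-fast check along these lines (an illustrative sketch, not code from this commit) makes the new requirement concrete:

    import os

    REQUIRED = ("RUNPOD_API_KEY", "RUNPOD_ENDPOINT", "S3_BUCKET",
                "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")
    missing = [name for name in REQUIRED if not os.getenv(name)]
    if missing:
        raise SystemExit(f"Missing required environment variables: {', '.join(missing)}")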
Date:   2026-02-03 20:47:27 +00:00
Parent: 419918460d
Commit: 5f554ea769

4 changed files with 490 additions and 226 deletions

pipelines/config.py (new file, 93 lines)

@@ -0,0 +1,93 @@
"""
Pipeline Configuration
All configuration loaded from environment variables with sensible defaults.
Secrets should be provided via Kubernetes secrets, not hardcoded.
"""
import os
from dataclasses import dataclass
from typing import Optional
@dataclass
class RunPodConfig:
"""RunPod API configuration."""
api_key: str = os.getenv("RUNPOD_API_KEY", "")
endpoint: str = os.getenv("RUNPOD_ENDPOINT", "")
api_base: str = os.getenv("RUNPOD_API_BASE", "https://api.runpod.ai/v2")
@dataclass
class AWSConfig:
"""AWS credentials and settings."""
access_key_id: str = os.getenv("AWS_ACCESS_KEY_ID", "")
secret_access_key: str = os.getenv("AWS_SECRET_ACCESS_KEY", "")
session_token: str = os.getenv("AWS_SESSION_TOKEN", "")
region: str = os.getenv("AWS_REGION", "us-east-1")
s3_bucket: str = os.getenv("S3_BUCKET", "")
s3_prefix: str = os.getenv("S3_PREFIX", "models")
@dataclass
class ModelConfig:
"""Model training defaults."""
base_model: str = os.getenv("BASE_MODEL", "emilyalsentzer/Bio_ClinicalBERT")
max_samples: int = int(os.getenv("MAX_SAMPLES", "10000"))
epochs: int = int(os.getenv("EPOCHS", "3"))
batch_size: int = int(os.getenv("BATCH_SIZE", "16"))
eval_split: float = float(os.getenv("EVAL_SPLIT", "0.1"))
learning_rate: float = float(os.getenv("LEARNING_RATE", "2e-5"))
@dataclass
class PipelineConfig:
"""Combined pipeline configuration."""
runpod: RunPodConfig
aws: AWSConfig
model: ModelConfig
# Pipeline settings
poll_interval: int = int(os.getenv("POLL_INTERVAL_SECONDS", "10"))
timeout: int = int(os.getenv("TRAINING_TIMEOUT_SECONDS", "3600"))
@classmethod
def from_env(cls) -> "PipelineConfig":
"""Load configuration from environment variables."""
return cls(
runpod=RunPodConfig(),
aws=AWSConfig(),
model=ModelConfig(),
)
# Task-specific defaults (can override base config)
TASK_DEFAULTS = {
"ddi": {
"max_samples": 10000,
"batch_size": 16,
"s3_prefix": "ddi-models",
},
"ade": {
"max_samples": 10000,
"batch_size": 16,
"s3_prefix": "ade-models",
},
"triage": {
"max_samples": 5000,
"batch_size": 8,
"s3_prefix": "triage-models",
},
"symptom_disease": {
"max_samples": 5000,
"batch_size": 16,
"s3_prefix": "symptom-disease-models",
},
}
def get_task_config(task: str, overrides: Optional[dict] = None) -> dict:
"""Get task-specific configuration with optional overrides."""
config = TASK_DEFAULTS.get(task, {}).copy()
if overrides:
config.update(overrides)
return config
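Usage sketch (assumes the module is importable as pipelines.config; note that the os.getenv defaults above are evaluated once, at import time, so the environment must be populated before this module is first imported):

    from pipelines.config import PipelineConfig, get_task_config

    cfg = PipelineConfig.from_env()
    print(cfg.model.base_model)  # "emilyalsentzer/Bio_ClinicalBERT" unless BASE_MODEL is set
    print(get_task_config("triage", {"epochs": 5}))
    # -> {'max_samples': 5000, 'batch_size': 8, 's3_prefix': 'triage-models', 'epochs': 5}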


@@ -8,369 +8,464 @@ Multi-task training pipelines for:
 - Drug-Drug Interaction (DDI) Classification
 All use RunPod serverless GPU infrastructure.
+Configuration via environment variables - see config.py for details.
+Environment Variables:
+    RUNPOD_API_KEY        - RunPod API key (required)
+    RUNPOD_ENDPOINT       - RunPod serverless endpoint ID (required)
+    AWS_ACCESS_KEY_ID     - AWS credentials for S3 upload
+    AWS_SECRET_ACCESS_KEY
+    AWS_SESSION_TOKEN     - Optional session token for assumed roles
+    AWS_REGION            - Default: us-east-1
+    S3_BUCKET             - Bucket for model artifacts (required)
+    BASE_MODEL            - HuggingFace model ID (default: Bio_ClinicalBERT)
+    MAX_SAMPLES           - Training samples (default: 10000)
+    EPOCHS                - Training epochs (default: 3)
+    BATCH_SIZE            - Batch size (default: 16)
 """
+import os
 from kfp import dsl
 from kfp import compiler
 from typing import Optional
-# ============================================================================
-# ADE (Adverse Drug Event) Classification Pipeline
-# ============================================================================
+# =============================================================================
+# Reusable Training Component
+# =============================================================================
 @dsl.component(
     base_image="python:3.11-slim",
     packages_to_install=["requests"]
 )
-def train_ade_model(
+def train_healthcare_model(
+    task: str,
     runpod_api_key: str,
     runpod_endpoint: str,
     model_name: str,
     max_samples: int,
     epochs: int,
     batch_size: int,
+    eval_split: float,
     s3_bucket: str,
+    s3_prefix: str,
     aws_access_key_id: str,
     aws_secret_access_key: str,
     aws_session_token: str,
+    aws_region: str,
+    poll_interval: int,
+    timeout: int,
 ) -> dict:
-    """Train ADE classifier on RunPod serverless GPU."""
+    """
+    Generic healthcare model training component.
+    Submits training job to RunPod serverless GPU and polls for completion.
+    Trained model is uploaded to S3 by the RunPod handler.
+    Args:
+        task: Training task (ddi, ade, triage, symptom_disease)
+        runpod_api_key: RunPod API key
+        runpod_endpoint: RunPod serverless endpoint ID
+        model_name: HuggingFace model ID
+        max_samples: Maximum training samples
+        epochs: Training epochs
+        batch_size: Training batch size
+        eval_split: Validation split ratio
+        s3_bucket: S3 bucket for model output
+        s3_prefix: S3 key prefix for this task
+        aws_*: AWS credentials for S3 access
+        poll_interval: Seconds between status checks
+        timeout: Maximum training time in seconds
+    Returns:
+        Training output including metrics and S3 URI
+    """
     import requests
     import time
+    import os  # fix: os.getenv is called below, but os was not imported inside the component
+    api_base = os.getenv("RUNPOD_API_BASE", "https://api.runpod.ai/v2")
+    # Submit training job
     response = requests.post(
-        f"https://api.runpod.ai/v2/{runpod_endpoint}/run",
+        f"{api_base}/{runpod_endpoint}/run",
         headers={"Authorization": f"Bearer {runpod_api_key}"},
         json={
             "input": {
-                "task": "ade",
+                "task": task,
                 "model_name": model_name,
                 "max_samples": max_samples,
                 "epochs": epochs,
                 "batch_size": batch_size,
-                "eval_split": 0.1,
+                "eval_split": eval_split,
                 "s3_bucket": s3_bucket,
-                "s3_prefix": "ade-models/bert",
+                "s3_prefix": s3_prefix,
                 "aws_access_key_id": aws_access_key_id,
                 "aws_secret_access_key": aws_secret_access_key,
                 "aws_session_token": aws_session_token,
+                "aws_region": aws_region,
             }
-        }
+        },
+        timeout=30,
     )
+    response.raise_for_status()
     job_id = response.json()["id"]
-    print(f"RunPod job submitted: {job_id}")
+    print(f"[{task}] RunPod job submitted: {job_id}")
+    # Poll for completion
+    start_time = time.time()
     while True:
-        status = requests.get(
-            f"https://api.runpod.ai/v2/{runpod_endpoint}/status/{job_id}",
-            headers={"Authorization": f"Bearer {runpod_api_key}"}
-        ).json()
+        elapsed = time.time() - start_time
+        if elapsed > timeout:
+            raise TimeoutError(f"Training exceeded timeout of {timeout}s")
+        status_resp = requests.get(
+            f"{api_base}/{runpod_endpoint}/status/{job_id}",
+            headers={"Authorization": f"Bearer {runpod_api_key}"},
+            timeout=30,
+        )
+        status_resp.raise_for_status()
+        status = status_resp.json()
         if status["status"] == "COMPLETED":
-            return status["output"]
+            print(f"[{task}] Training completed in {elapsed:.0f}s")
+            return status.get("output", {})
         elif status["status"] == "FAILED":
-            raise Exception(f"Training failed: {status}")
+            error = status.get("error", "Unknown error")
+            raise RuntimeError(f"[{task}] Training failed: {error}")
+        elif status["status"] in ["IN_QUEUE", "IN_PROGRESS"]:
+            print(f"[{task}] Status: {status['status']} ({elapsed:.0f}s elapsed)")
-        time.sleep(10)
+        time.sleep(poll_interval)
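    # Example status payloads the polling loop above handles (shapes inferred from
    # this component's code; the values are illustrative, not real RunPod responses):
    #   {"id": "abc123", "status": "IN_QUEUE"}
    #   {"id": "abc123", "status": "COMPLETED", "output": {"metrics": {...}, "s3_uri": "s3://..."}}
    #   {"id": "abc123", "status": "FAILED", "error": "..."}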
+# =============================================================================
+# Pipeline Definitions
+# =============================================================================
+def _get_env(name: str, default: str = "") -> str:
+    """Helper to get env var with default."""
+    return os.getenv(name, default)
+def _get_env_int(name: str, default: int) -> int:
+    """Helper to get int env var with default."""
+    return int(os.getenv(name, str(default)))
+def _get_env_float(name: str, default: float) -> float:
+    """Helper to get float env var with default."""
+    return float(os.getenv(name, str(default)))
 @dsl.pipeline(name="ade-classification-pipeline")
 def ade_classification_pipeline(
-    runpod_api_key: str,
-    runpod_endpoint: str = "k57do7afav01es",
-    model_name: str = "emilyalsentzer/Bio_ClinicalBERT",
+    # RunPod config - from env or override
+    runpod_api_key: str = "",
+    runpod_endpoint: str = "",
+    # Model config
+    model_name: str = "",
     max_samples: int = 10000,
     epochs: int = 3,
     batch_size: int = 16,
+    eval_split: float = 0.1,
+    # AWS config
     s3_bucket: str = "",
     aws_access_key_id: str = "",
     aws_secret_access_key: str = "",
     aws_session_token: str = "",
+    aws_region: str = "us-east-1",
+    # Runtime config
+    poll_interval: int = 10,
+    timeout: int = 3600,
 ):
     """
     Adverse Drug Event Classification Pipeline
-    Trains Bio_ClinicalBERT on ADE Corpus V2 (30K samples)
-    Binary classification: ADE present / No ADE
+    Trains Bio_ClinicalBERT on ADE Corpus V2 (30K samples).
+    Binary classification: ADE present / No ADE.
+    All parameters can be provided via environment variables:
+    - RUNPOD_API_KEY, RUNPOD_ENDPOINT
+    - BASE_MODEL, MAX_SAMPLES, EPOCHS, BATCH_SIZE
+    - S3_BUCKET, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, etc.
     """
-    train_task = train_ade_model(
-        runpod_api_key=runpod_api_key,
-        runpod_endpoint=runpod_endpoint,
-        model_name=model_name,
+    train_healthcare_model(
+        task="ade",
+        runpod_api_key=runpod_api_key or _get_env("RUNPOD_API_KEY"),
+        runpod_endpoint=runpod_endpoint or _get_env("RUNPOD_ENDPOINT"),
+        model_name=model_name or _get_env("BASE_MODEL", "emilyalsentzer/Bio_ClinicalBERT"),
         max_samples=max_samples,
         epochs=epochs,
         batch_size=batch_size,
-        s3_bucket=s3_bucket,
-        aws_access_key_id=aws_access_key_id,
-        aws_secret_access_key=aws_secret_access_key,
-        aws_session_token=aws_session_token,
+        eval_split=eval_split,
+        s3_bucket=s3_bucket or _get_env("S3_BUCKET"),
+        s3_prefix="ade-models/bert",
+        aws_access_key_id=aws_access_key_id or _get_env("AWS_ACCESS_KEY_ID"),
+        aws_secret_access_key=aws_secret_access_key or _get_env("AWS_SECRET_ACCESS_KEY"),
+        aws_session_token=aws_session_token or _get_env("AWS_SESSION_TOKEN"),
+        aws_region=aws_region,
+        poll_interval=poll_interval,
+        timeout=timeout,
     )
-# ============================================================================
-# Medical Triage Classification Pipeline
-# ============================================================================
-@dsl.component(
-    base_image="python:3.11-slim",
-    packages_to_install=["requests"]
-)
-def train_triage_model(
-    runpod_api_key: str,
-    runpod_endpoint: str,
-    model_name: str,
-    max_samples: int,
-    epochs: int,
-    batch_size: int,
-    s3_bucket: str,
-    aws_access_key_id: str,
-    aws_secret_access_key: str,
-    aws_session_token: str,
-) -> dict:
-    """Train Medical Triage classifier on RunPod."""
-    import requests
-    import time
-    response = requests.post(
-        f"https://api.runpod.ai/v2/{runpod_endpoint}/run",
-        headers={"Authorization": f"Bearer {runpod_api_key}"},
-        json={
-            "input": {
-                "task": "triage",
-                "model_name": model_name,
-                "max_samples": max_samples,
-                "epochs": epochs,
-                "batch_size": batch_size,
-                "eval_split": 0.1,
-                "s3_bucket": s3_bucket,
-                "s3_prefix": "triage-models/bert",
-                "aws_access_key_id": aws_access_key_id,
-                "aws_secret_access_key": aws_secret_access_key,
-                "aws_session_token": aws_session_token,
-            }
-        }
-    )
-    job_id = response.json()["id"]
-    print(f"RunPod job submitted: {job_id}")
-    while True:
-        status = requests.get(
-            f"https://api.runpod.ai/v2/{runpod_endpoint}/status/{job_id}",
-            headers={"Authorization": f"Bearer {runpod_api_key}"}
-        ).json()
-        if status["status"] == "COMPLETED":
-            return status["output"]
-        elif status["status"] == "FAILED":
-            raise Exception(f"Training failed: {status}")
-        time.sleep(10)
 @dsl.pipeline(name="triage-classification-pipeline")
 def triage_classification_pipeline(
-    runpod_api_key: str,
-    runpod_endpoint: str = "k57do7afav01es",
-    model_name: str = "emilyalsentzer/Bio_ClinicalBERT",
+    runpod_api_key: str = "",
+    runpod_endpoint: str = "",
+    model_name: str = "",
     max_samples: int = 5000,
     epochs: int = 3,
     batch_size: int = 8,
+    eval_split: float = 0.1,
     s3_bucket: str = "",
     aws_access_key_id: str = "",
     aws_secret_access_key: str = "",
     aws_session_token: str = "",
+    aws_region: str = "us-east-1",
+    poll_interval: int = 10,
+    timeout: int = 3600,
 ):
     """
     Medical Triage Classification Pipeline
     Trains classifier for ER triage urgency levels.
-    Multi-class: Emergency, Urgent, Standard, etc.
+    Multi-class: Emergency, Urgent, Standard, Non-urgent.
     """
-    train_task = train_triage_model(
-        runpod_api_key=runpod_api_key,
-        runpod_endpoint=runpod_endpoint,
-        model_name=model_name,
+    train_healthcare_model(
+        task="triage",
+        runpod_api_key=runpod_api_key or _get_env("RUNPOD_API_KEY"),
+        runpod_endpoint=runpod_endpoint or _get_env("RUNPOD_ENDPOINT"),
+        model_name=model_name or _get_env("BASE_MODEL", "emilyalsentzer/Bio_ClinicalBERT"),
         max_samples=max_samples,
         epochs=epochs,
         batch_size=batch_size,
-        s3_bucket=s3_bucket,
-        aws_access_key_id=aws_access_key_id,
-        aws_secret_access_key=aws_secret_access_key,
-        aws_session_token=aws_session_token,
+        eval_split=eval_split,
+        s3_bucket=s3_bucket or _get_env("S3_BUCKET"),
+        s3_prefix="triage-models/bert",
+        aws_access_key_id=aws_access_key_id or _get_env("AWS_ACCESS_KEY_ID"),
+        aws_secret_access_key=aws_secret_access_key or _get_env("AWS_SECRET_ACCESS_KEY"),
+        aws_session_token=aws_session_token or _get_env("AWS_SESSION_TOKEN"),
+        aws_region=aws_region,
+        poll_interval=poll_interval,
+        timeout=timeout,
     )
-# ============================================================================
-# Symptom-to-Disease Classification Pipeline
-# ============================================================================
-@dsl.component(
-    base_image="python:3.11-slim",
-    packages_to_install=["requests"]
-)
-def train_symptom_disease_model(
-    runpod_api_key: str,
-    runpod_endpoint: str,
-    model_name: str,
-    max_samples: int,
-    epochs: int,
-    batch_size: int,
-    s3_bucket: str,
-    aws_access_key_id: str,
-    aws_secret_access_key: str,
-    aws_session_token: str,
-) -> dict:
-    """Train Symptom-to-Disease classifier on RunPod."""
-    import requests
-    import time
-    response = requests.post(
-        f"https://api.runpod.ai/v2/{runpod_endpoint}/run",
-        headers={"Authorization": f"Bearer {runpod_api_key}"},
-        json={
-            "input": {
-                "task": "symptom_disease",
-                "model_name": model_name,
-                "max_samples": max_samples,
-                "epochs": epochs,
-                "batch_size": batch_size,
-                "eval_split": 0.1,
-                "s3_bucket": s3_bucket,
-                "s3_prefix": "symptom-disease-models/bert",
-                "aws_access_key_id": aws_access_key_id,
-                "aws_secret_access_key": aws_secret_access_key,
-                "aws_session_token": aws_session_token,
-            }
-        }
-    )
-    job_id = response.json()["id"]
-    print(f"RunPod job submitted: {job_id}")
-    while True:
-        status = requests.get(
-            f"https://api.runpod.ai/v2/{runpod_endpoint}/status/{job_id}",
-            headers={"Authorization": f"Bearer {runpod_api_key}"}
-        ).json()
-        if status["status"] == "COMPLETED":
-            return status["output"]
-        elif status["status"] == "FAILED":
-            raise Exception(f"Training failed: {status}")
-        time.sleep(10)
 @dsl.pipeline(name="symptom-disease-classification-pipeline")
 def symptom_disease_pipeline(
-    runpod_api_key: str,
-    runpod_endpoint: str = "k57do7afav01es",
-    model_name: str = "emilyalsentzer/Bio_ClinicalBERT",
+    runpod_api_key: str = "",
+    runpod_endpoint: str = "",
+    model_name: str = "",
     max_samples: int = 5000,
     epochs: int = 3,
     batch_size: int = 16,
+    eval_split: float = 0.1,
     s3_bucket: str = "",
     aws_access_key_id: str = "",
     aws_secret_access_key: str = "",
     aws_session_token: str = "",
+    aws_region: str = "us-east-1",
+    poll_interval: int = 10,
+    timeout: int = 3600,
 ):
     """
     Symptom-to-Disease Classification Pipeline
     Predicts disease from symptom descriptions.
-    Multi-class: 40+ disease categories
+    Multi-class: 41 disease categories.
     """
-    train_task = train_symptom_disease_model(
-        runpod_api_key=runpod_api_key,
-        runpod_endpoint=runpod_endpoint,
-        model_name=model_name,
+    train_healthcare_model(
+        task="symptom_disease",
+        runpod_api_key=runpod_api_key or _get_env("RUNPOD_API_KEY"),
+        runpod_endpoint=runpod_endpoint or _get_env("RUNPOD_ENDPOINT"),
+        model_name=model_name or _get_env("BASE_MODEL", "emilyalsentzer/Bio_ClinicalBERT"),
         max_samples=max_samples,
         epochs=epochs,
         batch_size=batch_size,
-        s3_bucket=s3_bucket,
-        aws_access_key_id=aws_access_key_id,
-        aws_secret_access_key=aws_secret_access_key,
-        aws_session_token=aws_session_token,
+        eval_split=eval_split,
+        s3_bucket=s3_bucket or _get_env("S3_BUCKET"),
+        s3_prefix="symptom-disease-models/bert",
+        aws_access_key_id=aws_access_key_id or _get_env("AWS_ACCESS_KEY_ID"),
+        aws_secret_access_key=aws_secret_access_key or _get_env("AWS_SECRET_ACCESS_KEY"),
+        aws_session_token=aws_session_token or _get_env("AWS_SESSION_TOKEN"),
+        aws_region=aws_region,
+        poll_interval=poll_interval,
+        timeout=timeout,
    )
-# ============================================================================
-# Full Healthcare Training Pipeline (All Tasks)
-# ============================================================================
-@dsl.pipeline(name="healthcare-multi-task-pipeline")
-def healthcare_multi_task_pipeline(
-    runpod_api_key: str,
-    runpod_endpoint: str = "k57do7afav01es",
-    model_name: str = "emilyalsentzer/Bio_ClinicalBERT",
+@dsl.pipeline(name="ddi-classification-pipeline")
+def ddi_classification_pipeline(
+    runpod_api_key: str = "",
+    runpod_endpoint: str = "",
+    model_name: str = "",
+    max_samples: int = 10000,
+    epochs: int = 3,
+    batch_size: int = 16,
+    eval_split: float = 0.1,
+    s3_bucket: str = "",
+    aws_access_key_id: str = "",
+    aws_secret_access_key: str = "",
+    aws_session_token: str = "",
+    aws_region: str = "us-east-1",
+    poll_interval: int = 10,
+    timeout: int = 3600,
+):
+    """
+    Drug-Drug Interaction Classification Pipeline
+    Trains on 176K DrugBank DDI samples.
+    Multi-class severity: Minor, Moderate, Major, Contraindicated.
+    """
+    train_healthcare_model(
+        task="ddi",
+        runpod_api_key=runpod_api_key or _get_env("RUNPOD_API_KEY"),
+        runpod_endpoint=runpod_endpoint or _get_env("RUNPOD_ENDPOINT"),
+        model_name=model_name or _get_env("BASE_MODEL", "emilyalsentzer/Bio_ClinicalBERT"),
+        max_samples=max_samples,
+        epochs=epochs,
+        batch_size=batch_size,
+        eval_split=eval_split,
+        s3_bucket=s3_bucket or _get_env("S3_BUCKET"),
+        s3_prefix="ddi-models/bert",
+        aws_access_key_id=aws_access_key_id or _get_env("AWS_ACCESS_KEY_ID"),
+        aws_secret_access_key=aws_secret_access_key or _get_env("AWS_SECRET_ACCESS_KEY"),
+        aws_session_token=aws_session_token or _get_env("AWS_SESSION_TOKEN"),
+        aws_region=aws_region,
+        poll_interval=poll_interval,
+        timeout=timeout,
+    )
+@dsl.pipeline(name="healthcare-multi-task-pipeline")
+def healthcare_multi_task_pipeline(
+    runpod_api_key: str = "",
+    runpod_endpoint: str = "",
+    model_name: str = "",
     s3_bucket: str = "",
     aws_access_key_id: str = "",
     aws_secret_access_key: str = "",
     aws_session_token: str = "",
+    aws_region: str = "us-east-1",
+    poll_interval: int = 10,
+    timeout: int = 3600,
 ):
     """
     Train all healthcare models in parallel.
     Outputs:
+    - DDI classifier (s3://bucket/ddi-models/...)
     - ADE classifier (s3://bucket/ade-models/...)
     - Triage classifier (s3://bucket/triage-models/...)
     - Symptom-Disease classifier (s3://bucket/symptom-disease-models/...)
     """
-    # Run all training tasks in parallel
-    ade_task = train_ade_model(
-        runpod_api_key=runpod_api_key,
-        runpod_endpoint=runpod_endpoint,
-        model_name=model_name,
+    # Resolve env vars once
+    _runpod_key = runpod_api_key or _get_env("RUNPOD_API_KEY")
+    _runpod_endpoint = runpod_endpoint or _get_env("RUNPOD_ENDPOINT")
+    _model = model_name or _get_env("BASE_MODEL", "emilyalsentzer/Bio_ClinicalBERT")
+    _bucket = s3_bucket or _get_env("S3_BUCKET")
+    _aws_key = aws_access_key_id or _get_env("AWS_ACCESS_KEY_ID")
+    _aws_secret = aws_secret_access_key or _get_env("AWS_SECRET_ACCESS_KEY")
+    _aws_token = aws_session_token or _get_env("AWS_SESSION_TOKEN")
+    # Run all training tasks in parallel (no dependencies between them)
+    ddi_task = train_healthcare_model(
+        task="ddi",
+        runpod_api_key=_runpod_key,
+        runpod_endpoint=_runpod_endpoint,
+        model_name=_model,
         max_samples=10000,
         epochs=3,
         batch_size=16,
-        s3_bucket=s3_bucket,
-        aws_access_key_id=aws_access_key_id,
-        aws_secret_access_key=aws_secret_access_key,
-        aws_session_token=aws_session_token,
+        eval_split=0.1,
+        s3_bucket=_bucket,
+        s3_prefix="ddi-models/bert",
+        aws_access_key_id=_aws_key,
+        aws_secret_access_key=_aws_secret,
+        aws_session_token=_aws_token,
+        aws_region=aws_region,
+        poll_interval=poll_interval,
+        timeout=timeout,
     )
-    triage_task = train_triage_model(
-        runpod_api_key=runpod_api_key,
-        runpod_endpoint=runpod_endpoint,
-        model_name=model_name,
+    ade_task = train_healthcare_model(
+        task="ade",
+        runpod_api_key=_runpod_key,
+        runpod_endpoint=_runpod_endpoint,
+        model_name=_model,
+        max_samples=10000,
+        epochs=3,
+        batch_size=16,
+        eval_split=0.1,
+        s3_bucket=_bucket,
+        s3_prefix="ade-models/bert",
+        aws_access_key_id=_aws_key,
+        aws_secret_access_key=_aws_secret,
+        aws_session_token=_aws_token,
+        aws_region=aws_region,
+        poll_interval=poll_interval,
+        timeout=timeout,
+    )
+    triage_task = train_healthcare_model(
+        task="triage",
+        runpod_api_key=_runpod_key,
+        runpod_endpoint=_runpod_endpoint,
+        model_name=_model,
         max_samples=5000,
         epochs=3,
         batch_size=8,
-        s3_bucket=s3_bucket,
-        aws_access_key_id=aws_access_key_id,
-        aws_secret_access_key=aws_secret_access_key,
-        aws_session_token=aws_session_token,
+        eval_split=0.1,
+        s3_bucket=_bucket,
+        s3_prefix="triage-models/bert",
+        aws_access_key_id=_aws_key,
+        aws_secret_access_key=_aws_secret,
+        aws_session_token=_aws_token,
+        aws_region=aws_region,
+        poll_interval=poll_interval,
+        timeout=timeout,
     )
-    symptom_task = train_symptom_disease_model(
-        runpod_api_key=runpod_api_key,
-        runpod_endpoint=runpod_endpoint,
-        model_name=model_name,
+    symptom_task = train_healthcare_model(
+        task="symptom_disease",
+        runpod_api_key=_runpod_key,
+        runpod_endpoint=_runpod_endpoint,
+        model_name=_model,
         max_samples=5000,
         epochs=3,
         batch_size=16,
-        s3_bucket=s3_bucket,
-        aws_access_key_id=aws_access_key_id,
-        aws_secret_access_key=aws_secret_access_key,
-        aws_session_token=aws_session_token,
+        eval_split=0.1,
+        s3_bucket=_bucket,
+        s3_prefix="symptom-disease-models/bert",
+        aws_access_key_id=_aws_key,
+        aws_secret_access_key=_aws_secret,
+        aws_session_token=_aws_token,
+        aws_region=aws_region,
+        poll_interval=poll_interval,
+        timeout=timeout,
     )
+# =============================================================================
+# Compile Pipelines
+# =============================================================================
 if __name__ == "__main__":
-    # Compile pipelines
-    compiler.Compiler().compile(
-        ade_classification_pipeline,
-        "ade_classification_pipeline.yaml"
-    )
-    compiler.Compiler().compile(
-        triage_classification_pipeline,
-        "triage_classification_pipeline.yaml"
-    )
-    compiler.Compiler().compile(
-        symptom_disease_pipeline,
-        "symptom_disease_pipeline.yaml"
-    )
-    compiler.Compiler().compile(
-        healthcare_multi_task_pipeline,
-        "healthcare_multi_task_pipeline.yaml"
-    )
-    print("All pipelines compiled!")
+    import argparse
+    parser = argparse.ArgumentParser(description="Compile Kubeflow pipelines")
+    parser.add_argument("--output-dir", default=".", help="Output directory for compiled YAML")
+    args = parser.parse_args()
+    pipelines = [
+        (ade_classification_pipeline, "ade_classification_pipeline.yaml"),
+        (triage_classification_pipeline, "triage_classification_pipeline.yaml"),
+        (symptom_disease_pipeline, "symptom_disease_pipeline.yaml"),
+        (ddi_classification_pipeline, "ddi_classification_pipeline.yaml"),
+        (healthcare_multi_task_pipeline, "healthcare_multi_task_pipeline.yaml"),
+    ]
+    for pipeline_func, filename in pipelines:
+        output_path = os.path.join(args.output_dir, filename)
+        compiler.Compiler().compile(pipeline_func, output_path)
+        print(f"Compiled: {output_path}")
+    print(f"\n✓ All {len(pipelines)} pipelines compiled!")
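For reference, a compile-and-submit flow with the new env-var defaults might look like this (the module path pipelines.healthcare_pipelines, the KFP_ENDPOINT variable, and the host argument are illustrative assumptions, not part of this commit):

    import os
    from kfp.client import Client

    # Hypothetical import path; adjust to wherever this module lives in the repo.
    from pipelines.healthcare_pipelines import ade_classification_pipeline

    client = Client(host=os.environ["KFP_ENDPOINT"])  # assumed env var for the KFP host
    run = client.create_run_from_pipeline_func(
        ade_classification_pipeline,
        arguments={},  # empty: rely on the RUNPOD_*/S3_*/AWS_* fallbacks described above
    )
    print("Run ID:", run.run_id)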