mirror of https://github.com/ghndrx/kubeflow-pipelines.git, synced 2026-02-10 14:55:11 +00:00
Add DDI training pipeline with RunPod serverless GPU support
components/runpod_trainer/Dockerfile (Normal file, 21 lines)
@@ -0,0 +1,21 @@
FROM runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel-ubuntu22.04

WORKDIR /app

# Install dependencies
RUN pip install --no-cache-dir \
    runpod \
    transformers \
    datasets \
    accelerate \
    boto3 \
    scikit-learn \
    scipy

# Copy handler
COPY handler.py /app/handler.py

# Set environment variables
ENV PYTHONUNBUFFERED=1

CMD ["python", "-u", "handler.py"]
components/runpod_trainer/handler.py (Normal file, 183 lines)
@@ -0,0 +1,183 @@
"""
RunPod Serverless Handler for DDI Model Training

This runs on RunPod GPU instances and trains the Bio_ClinicalBERT model
for drug-drug interaction detection.
"""
import os
import json
import runpod
from typing import Dict, Any


def download_from_minio(bucket: str, key: str, local_path: str):
    """Download file from MinIO."""
    import boto3

    s3 = boto3.client(
        's3',
        endpoint_url=os.environ['MINIO_ENDPOINT'],
        aws_access_key_id=os.environ['MINIO_ACCESS_KEY'],
        aws_secret_access_key=os.environ['MINIO_SECRET_KEY']
    )
    s3.download_file(bucket, key, local_path)


def upload_to_minio(local_path: str, bucket: str, key: str):
    """Upload file to MinIO."""
    import boto3

    s3 = boto3.client(
        's3',
        endpoint_url=os.environ['MINIO_ENDPOINT'],
        aws_access_key_id=os.environ['MINIO_ACCESS_KEY'],
        aws_secret_access_key=os.environ['MINIO_SECRET_KEY']
    )
    s3.upload_file(local_path, bucket, key)


def train_ddi_model(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """
    Train DDI detection model.

    Expected input:
    {
        "model_name": "emilyalsentzer/Bio_ClinicalBERT",
        "dataset_path": "datasets/ddi_train.json",
        "epochs": 3,
        "learning_rate": 2e-5,
        "batch_size": 16,
        "output_path": "models/ddi_model_v1"
    }
    """
    import torch
    from transformers import (
        AutoTokenizer,
        AutoModelForSequenceClassification,
        TrainingArguments,
        Trainer
    )
    from datasets import Dataset
    import tempfile
    import shutil

    # Extract parameters
    model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
    dataset_path = job_input.get('dataset_path', 'datasets/ddi_train.json')
    epochs = job_input.get('epochs', 3)
    learning_rate = job_input.get('learning_rate', 2e-5)
    batch_size = job_input.get('batch_size', 16)
    output_path = job_input.get('output_path', 'models/ddi_model')

    # The pipeline passes MinIO credentials in the job input; export them so
    # the boto3 helpers above can read them from the environment.
    for env_var, input_key in (('MINIO_ENDPOINT', 'minio_endpoint'),
                               ('MINIO_ACCESS_KEY', 'minio_access_key'),
                               ('MINIO_SECRET_KEY', 'minio_secret_key')):
        if job_input.get(input_key):
            os.environ[env_var] = job_input[input_key]

    # Create temp directory
    work_dir = tempfile.mkdtemp()
    data_file = os.path.join(work_dir, 'train.json')
    model_dir = os.path.join(work_dir, 'model')

    try:
        # Download training data from MinIO
        print(f"Downloading dataset from {dataset_path}...")
        download_from_minio('datasets', dataset_path, data_file)

        # Load dataset
        with open(data_file, 'r') as f:
            train_data = json.load(f)

        dataset = Dataset.from_list(train_data)

        # Load model and tokenizer
        print(f"Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=5  # DDI severity levels: none, minor, moderate, major, contraindicated
        )

        # Tokenize dataset
        def tokenize_function(examples):
            return tokenizer(
                examples['text'],
                padding='max_length',
                truncation=True,
                max_length=512
            )

        tokenized_dataset = dataset.map(tokenize_function, batched=True)

        # Training arguments. The dataset is a flat list of samples with no
        # validation split, so per-epoch evaluation and best-model reloading
        # are left at their disabled defaults.
        training_args = TrainingArguments(
            output_dir=model_dir,
            num_train_epochs=epochs,
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            warmup_steps=100,
            weight_decay=0.01,
            logging_dir=os.path.join(work_dir, 'logs'),
            logging_steps=10,
            save_strategy='epoch',
            fp16=torch.cuda.is_available(),
        )

        # Initialize trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
        )

        # Train
        print("Starting training...")
        train_result = trainer.train()

        # Save model
        print("Saving model...")
        trainer.save_model(model_dir)
        tokenizer.save_pretrained(model_dir)

        # Save training metrics
        metrics = {
            'train_loss': train_result.training_loss,
            'epochs': epochs,
            'model_name': model_name,
            'samples': len(dataset)
        }

        with open(os.path.join(model_dir, 'metrics.json'), 'w') as f:
            json.dump(metrics, f)

        # Upload model to MinIO
        print(f"Uploading model to {output_path}...")
        for root, dirs, files in os.walk(model_dir):
            for file in files:
                local_file = os.path.join(root, file)
                relative_path = os.path.relpath(local_file, model_dir)
                minio_key = f"{output_path}/{relative_path}"
                upload_to_minio(local_file, 'models', minio_key)

        return {
            'status': 'success',
            'model_path': f"s3://models/{output_path}",
            'metrics': metrics
        }

    except Exception as e:
        return {
            'status': 'error',
            'error': str(e)
        }
    finally:
        # Cleanup
        shutil.rmtree(work_dir, ignore_errors=True)


def handler(job):
    """RunPod serverless handler."""
    job_input = job['input']
    return train_ddi_model(job_input)


# RunPod serverless entrypoint, guarded so that importing this module
# (e.g. for local testing) does not start the worker loop.
if __name__ == "__main__":
    runpod.serverless.start({'handler': handler})
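The serverless entrypoint is guarded by __main__, so the training path can be smoke-tested outside RunPod by calling train_ddi_model directly. A minimal sketch, assuming it runs inside the training image (or an environment with the same dependencies), a reachable MinIO instance, and an existing dataset key; the endpoint and credentials below are illustrative assumptions:

    # local_smoke_test.py (hypothetical, not part of this commit)
    import os

    # Assumed local MinIO; match your deployment.
    os.environ.setdefault('MINIO_ENDPOINT', 'http://localhost:9000')
    os.environ.setdefault('MINIO_ACCESS_KEY', 'minioadmin')
    os.environ.setdefault('MINIO_SECRET_KEY', 'minioadmin123!')

    from handler import train_ddi_model

    # One cheap epoch against an existing key in the 'datasets' bucket.
    result = train_ddi_model({
        'model_name': 'emilyalsentzer/Bio_ClinicalBERT',
        'dataset_path': 'ddi_train_v1.json',
        'epochs': 1,
        'batch_size': 4,
        'output_path': 'models/ddi_model_smoke',
    })
    print(result)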
pipelines/ddi_training_runpod.py (Normal file, 255 lines)
@@ -0,0 +1,255 @@
"""
DDI Training Pipeline with RunPod GPU

Fully automated pipeline that:
1. Creates and uploads a sample DDI dataset to MinIO
   (a stand-in for CCDA/FHIR clinical-data preprocessing)
2. Triggers RunPod serverless GPU training
3. Registers the trained model
"""
import os
from kfp import dsl
from kfp import compiler


@dsl.component(
    base_image="python:3.11-slim",
    packages_to_install=["boto3", "requests"]
)
def create_sample_dataset(
    minio_endpoint: str,
    minio_access_key: str,
    minio_secret_key: str,
    output_path: str = "ddi_train.json"
) -> str:
    """Create a sample DDI training dataset for testing."""
    import json
    import boto3

    # Sample DDI training data (drug pairs with interaction labels)
    # Labels: 0=none, 1=minor, 2=moderate, 3=major, 4=contraindicated
    sample_data = [
        {"text": "Patient taking warfarin and aspirin together", "label": 3},
        {"text": "Metformin administered with lisinopril", "label": 0},
        {"text": "Concurrent use of simvastatin and amiodarone", "label": 3},
        {"text": "Patient prescribed omeprazole with clopidogrel", "label": 2},
        {"text": "Fluoxetine and tramadol co-administration", "label": 4},
        {"text": "Atorvastatin given with diltiazem", "label": 2},
        {"text": "Methotrexate and NSAIDs used together", "label": 3},
        {"text": "Levothyroxine taken with calcium supplements", "label": 1},
        {"text": "Ciprofloxacin and theophylline interaction", "label": 3},
        {"text": "ACE inhibitor with potassium supplement", "label": 2},
        # Add more samples for better training
        {"text": "Digoxin and amiodarone combination therapy", "label": 3},
        {"text": "SSRIs with MAO inhibitors", "label": 4},
        {"text": "Lithium and ACE inhibitors together", "label": 3},
        {"text": "Benzodiazepines with opioids", "label": 4},
        {"text": "Metronidazole and alcohol consumption", "label": 4},
    ]

    # Upload to MinIO
    s3 = boto3.client(
        's3',
        endpoint_url=minio_endpoint,
        aws_access_key_id=minio_access_key,
        aws_secret_access_key=minio_secret_key,
        region_name='us-east-1'
    )

    data_json = json.dumps(sample_data)
    s3.put_object(
        Bucket='datasets',
        Key=output_path,
        Body=data_json.encode('utf-8'),
        ContentType='application/json'
    )

    print(f"Uploaded sample dataset to datasets/{output_path}")
    return output_path


@dsl.component(
    base_image="python:3.11-slim",
    packages_to_install=["requests"]
)
def trigger_runpod_training(
    runpod_api_key: str,
    runpod_endpoint_id: str,
    minio_endpoint: str,
    minio_access_key: str,
    minio_secret_key: str,
    dataset_path: str,
    model_name: str = "emilyalsentzer/Bio_ClinicalBERT",
    epochs: int = 3,
    learning_rate: float = 2e-5,
    output_model_path: str = "ddi_model_v1"
) -> str:
    """Trigger RunPod serverless training job."""
    import requests

    # RunPod API endpoint (runsync blocks until the job finishes)
    url = f"https://api.runpod.ai/v2/{runpod_endpoint_id}/runsync"

    headers = {
        "Authorization": f"Bearer {runpod_api_key}",
        "Content-Type": "application/json"
    }

    payload = {
        "input": {
            "model_name": model_name,
            "dataset_path": dataset_path,
            "epochs": epochs,
            "learning_rate": learning_rate,
            "batch_size": 16,
            "output_path": output_model_path,
            # MinIO credentials for the worker
            "minio_endpoint": minio_endpoint,
            "minio_access_key": minio_access_key,
            "minio_secret_key": minio_secret_key
        }
    }

    print("Triggering RunPod training job...")
    print(f"Model: {model_name}")
    print(f"Dataset: {dataset_path}")
    print(f"Epochs: {epochs}")

    response = requests.post(url, headers=headers, json=payload, timeout=3600)
    result = response.json()

    if response.status_code != 200:
        raise Exception(f"RunPod API error: {result}")

    if result.get('status') == 'FAILED':
        raise Exception(f"Training failed: {result.get('error')}")

    output = result.get('output', {})
    print("Training complete!")
    print(f"Model path: {output.get('model_path')}")
    print(f"Metrics: {output.get('metrics')}")

    return output.get('model_path', f"s3://models/{output_model_path}")


@dsl.component(
    base_image="python:3.11-slim",
    packages_to_install=["boto3"]
)
def register_model(
    model_path: str,
    minio_endpoint: str,
    minio_access_key: str,
    minio_secret_key: str,
    model_name: str = "ddi-detector",
    version: str = "v1"
) -> str:
    """Register the trained model in the model registry."""
    import boto3
    import json
    from datetime import datetime

    s3 = boto3.client(
        's3',
        endpoint_url=minio_endpoint,
        aws_access_key_id=minio_access_key,
        aws_secret_access_key=minio_secret_key,
        region_name='us-east-1'
    )

    # Create model registry entry
    registry_entry = {
        "name": model_name,
        "version": version,
        "path": model_path,
        "created_at": datetime.utcnow().isoformat(),
        "framework": "transformers",
        "task": "sequence-classification",
        "labels": ["none", "minor", "moderate", "major", "contraindicated"]
    }

    registry_key = f"registry/{model_name}/{version}/metadata.json"
    s3.put_object(
        Bucket='models',
        Key=registry_key,
        Body=json.dumps(registry_entry).encode('utf-8'),
        ContentType='application/json'
    )

    print(f"Model registered: {model_name} v{version}")
    print(f"Registry path: models/{registry_key}")

    return f"models/{registry_key}"


@dsl.pipeline(
    name="ddi-training-runpod",
    description="Train DDI detection model using RunPod serverless GPU"
)
def ddi_training_pipeline(
    # RunPod settings
    runpod_endpoint_id: str = "YOUR_ENDPOINT_ID",

    # Model settings
    model_name: str = "emilyalsentzer/Bio_ClinicalBERT",
    epochs: int = 3,
    learning_rate: float = 2e-5,
    model_version: str = "v1",

    # MinIO settings (these will be injected from secrets)
    minio_endpoint: str = "https://minio.walleye-frog.ts.net",
):
    """
    Full DDI training pipeline:
    1. Create/upload sample dataset
    2. Trigger RunPod GPU training
    3. Register trained model
    """
    # These would come from k8s secrets in production. Note that these values
    # are resolved when the pipeline is compiled and baked into the compiled
    # pipeline spec, not read at run time.
    minio_access_key = "minioadmin"
    minio_secret_key = "minioadmin123!"
    runpod_api_key = os.environ.get("RUNPOD_API_KEY", "")

    # Step 1: Create sample dataset
    dataset_task = create_sample_dataset(
        minio_endpoint=minio_endpoint,
        minio_access_key=minio_access_key,
        minio_secret_key=minio_secret_key,
        output_path=f"ddi_train_{model_version}.json"
    )

    # Step 2: Trigger RunPod training
    training_task = trigger_runpod_training(
        runpod_api_key=runpod_api_key,
        runpod_endpoint_id=runpod_endpoint_id,
        minio_endpoint=minio_endpoint,
        minio_access_key=minio_access_key,
        minio_secret_key=minio_secret_key,
        dataset_path=dataset_task.output,
        model_name=model_name,
        epochs=epochs,
        learning_rate=learning_rate,
        output_model_path=f"ddi_model_{model_version}"
    )

    # Step 3: Register model
    register_task = register_model(
        model_path=training_task.output,
        minio_endpoint=minio_endpoint,
        minio_access_key=minio_access_key,
        minio_secret_key=minio_secret_key,
        model_name="ddi-detector",
        version=model_version
    )


if __name__ == "__main__":
    compiler.Compiler().compile(
        pipeline_func=ddi_training_pipeline,
        package_path="ddi_training_runpod.yaml"
    )
    print("Pipeline compiled to ddi_training_runpod.yaml")
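Once compiled, the package can be submitted with the KFP SDK. A minimal sketch, assuming a Kubeflow Pipelines endpoint at http://localhost:8080 (an illustrative host) and that RUNPOD_API_KEY was exported before compiling:

    # submit_run.py (hypothetical, not part of this commit)
    import kfp

    client = kfp.Client(host="http://localhost:8080")  # assumed KFP endpoint
    run = client.create_run_from_pipeline_package(
        "ddi_training_runpod.yaml",
        arguments={
            "runpod_endpoint_id": "YOUR_ENDPOINT_ID",  # your RunPod serverless endpoint
            "model_version": "v1",
        },
    )
    print(f"Submitted run: {run.run_id}")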