Mirror of https://github.com/ghndrx/kubeflow-pipelines.git (synced 2026-02-10 06:45:13 +00:00)
feat: add ADE, Triage, and Symptom-Disease training pipelines

New tasks supported:
- task=ade: Adverse Drug Event classification (ADE Corpus V2, 30K samples)
- task=triage: Medical Triage classification (urgency levels)
- task=symptom_disease: Symptom-to-Disease prediction (40+ diseases)

All use HuggingFace datasets, Bio_ClinicalBERT, and S3 model storage.
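For reference, the handler selects a task from the serverless job payload. A minimal sketch of an input for the new ADE task (the hyperparameter values mirror the defaults in train_ade_classifier below; the bucket name is a placeholder):

    # Example RunPod job payload for the new task routing (bucket is a placeholder).
    job = {
        "input": {
            "task": "ade",  # or "triage" / "symptom_disease"
            "model_name": "emilyalsentzer/Bio_ClinicalBERT",
            "max_samples": 10000,
            "epochs": 3,
            "batch_size": 16,
            "eval_split": 0.1,
            "s3_bucket": "my-model-bucket",
            "s3_prefix": "ade-models/bert",
        }
    }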
@@ -445,23 +445,483 @@ def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
        shutil.rmtree(work_dir, ignore_errors=True)


def train_ade_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """
    Train Adverse Drug Event (ADE) binary classifier.

    Dataset: ade-benchmark-corpus/ade_corpus_v2 (30K samples)
    Labels: 0=No ADE, 1=ADE Present
    """
    import torch
    import tempfile
    import shutil
    from datasets import load_dataset
    from transformers import (
        AutoTokenizer, AutoModelForSequenceClassification,
        TrainingArguments, Trainer
    )
    from sklearn.model_selection import train_test_split

    work_dir = tempfile.mkdtemp()

    try:
        model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
        max_samples = job_input.get('max_samples', 10000)
        epochs = job_input.get('epochs', 3)
        batch_size = job_input.get('batch_size', 16)
        learning_rate = job_input.get('learning_rate', 2e-5)
        eval_split = job_input.get('eval_split', 0.1)

        print("Loading ADE Corpus V2 dataset...")
        dataset = load_dataset("ade-benchmark-corpus/ade_corpus_v2", "Ade_corpus_v2_classification")

        # Prepare data
        training_data = []
        for item in dataset['train']:
            if max_samples and len(training_data) >= max_samples:
                break
            training_data.append({
                'text': item['text'],
                'label': item['label']  # 0 or 1
            })

        print(f"Loaded {len(training_data)} ADE samples")

        # Split
        if eval_split > 0:
            train_data, eval_data = train_test_split(
                training_data, test_size=eval_split, random_state=42,
                stratify=[d['label'] for d in training_data]
            )
        else:
            train_data, eval_data = training_data, None

        from datasets import Dataset
        train_dataset = Dataset.from_list(train_data)
        eval_dataset = Dataset.from_list(eval_data) if eval_data else None

        # Load model (binary: ADE / No ADE)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=2
        )

        def tokenize(examples):
            return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=256)

        train_dataset = train_dataset.map(tokenize, batched=True)
        if eval_dataset:
            eval_dataset = eval_dataset.map(tokenize, batched=True)

        training_args = TrainingArguments(
            output_dir=work_dir,
            num_train_epochs=epochs,
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            eval_strategy='epoch' if eval_dataset else 'no',
            save_strategy='no',
            fp16=torch.cuda.is_available(),
            report_to='none',
        )

        def compute_metrics(eval_pred):
            from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
            preds = eval_pred.predictions.argmax(-1)
            return {
                'accuracy': accuracy_score(eval_pred.label_ids, preds),
                'f1': f1_score(eval_pred.label_ids, preds),
                'precision': precision_score(eval_pred.label_ids, preds),
                'recall': recall_score(eval_pred.label_ids, preds),
            }

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics if eval_dataset else None,
        )

        train_result = trainer.train()

        metrics = {
            'task': 'ade_classification',
            'train_loss': float(train_result.training_loss),
            'epochs': epochs,
            'model_name': model_name,
            'samples': len(train_data),
            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
            'data_source': 'ade_corpus_v2',
        }

        if eval_dataset:
            eval_result = trainer.evaluate()
            metrics.update({
                'eval_accuracy': float(eval_result['eval_accuracy']),
                'eval_f1': float(eval_result['eval_f1']),
                'eval_precision': float(eval_result['eval_precision']),
                'eval_recall': float(eval_result['eval_recall']),
            })

        # Save to S3
        s3_uri = None
        s3_bucket = job_input.get('s3_bucket')
        if s3_bucket:
            save_dir = os.path.join(work_dir, 'saved_model')
            trainer.save_model(save_dir)
            tokenizer.save_pretrained(save_dir)

            aws_creds = {
                'aws_access_key_id': job_input.get('aws_access_key_id'),
                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
                'aws_session_token': job_input.get('aws_session_token'),
                'aws_region': job_input.get('aws_region', 'us-east-1'),
            }
            s3_prefix = job_input.get('s3_prefix', 'ade-models/bert')
            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
            metrics['s3_uri'] = s3_uri

        return {'status': 'success', 'metrics': metrics, 'model_uri': s3_uri, 'message': 'ADE classifier trained on ADE Corpus V2'}

    except Exception as e:
        import traceback
        return {'status': 'error', 'error': str(e), 'traceback': traceback.format_exc()}
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)


def train_triage_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """
    Train Medical Triage classifier.

    Dataset: shubham212/Medical_Triage_Classification
    Labels: Triage urgency levels
    """
    import torch
    import tempfile
    import shutil
    from datasets import load_dataset, Dataset
    from transformers import (
        AutoTokenizer, AutoModelForSequenceClassification,
        TrainingArguments, Trainer
    )
    from sklearn.model_selection import train_test_split

    work_dir = tempfile.mkdtemp()

    try:
        model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
        max_samples = job_input.get('max_samples', 5000)
        epochs = job_input.get('epochs', 3)
        batch_size = job_input.get('batch_size', 8)
        learning_rate = job_input.get('learning_rate', 2e-5)
        eval_split = job_input.get('eval_split', 0.1)

        print("Loading Medical Triage dataset...")
        dataset = load_dataset("shubham212/Medical_Triage_Classification")

        # Get unique labels
        labels = sorted(set(item['label'] for item in dataset['train']))
        label2id = {l: i for i, l in enumerate(labels)}
        id2label = {i: l for l, i in label2id.items()}
        num_labels = len(labels)

        print(f"Found {num_labels} triage levels: {labels}")

        training_data = []
        for item in dataset['train']:
            if max_samples and len(training_data) >= max_samples:
                break
            training_data.append({
                'text': item['text'],
                'label': label2id[item['label']]
            })

        print(f"Loaded {len(training_data)} triage samples")

        if eval_split > 0:
            train_data, eval_data = train_test_split(
                training_data, test_size=eval_split, random_state=42,
                stratify=[d['label'] for d in training_data]
            )
        else:
            train_data, eval_data = training_data, None

        train_dataset = Dataset.from_list(train_data)
        eval_dataset = Dataset.from_list(eval_data) if eval_data else None

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels, id2label=id2label, label2id=label2id
        )

        def tokenize(examples):
            return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=512)

        train_dataset = train_dataset.map(tokenize, batched=True)
        if eval_dataset:
            eval_dataset = eval_dataset.map(tokenize, batched=True)

        training_args = TrainingArguments(
            output_dir=work_dir,
            num_train_epochs=epochs,
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            eval_strategy='epoch' if eval_dataset else 'no',
            save_strategy='no',
            fp16=torch.cuda.is_available(),
            report_to='none',
        )

        def compute_metrics(eval_pred):
            from sklearn.metrics import accuracy_score, f1_score
            preds = eval_pred.predictions.argmax(-1)
            return {
                'accuracy': accuracy_score(eval_pred.label_ids, preds),
                'f1_weighted': f1_score(eval_pred.label_ids, preds, average='weighted'),
            }

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics if eval_dataset else None,
        )

        train_result = trainer.train()

        metrics = {
            'task': 'triage_classification',
            'train_loss': float(train_result.training_loss),
            'epochs': epochs,
            'model_name': model_name,
            'samples': len(train_data),
            'num_labels': num_labels,
            'labels': id2label,
            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
            'data_source': 'medical_triage_classification',
        }

        if eval_dataset:
            eval_result = trainer.evaluate()
            metrics.update({
                'eval_accuracy': float(eval_result['eval_accuracy']),
                'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
            })

        s3_uri = None
        s3_bucket = job_input.get('s3_bucket')
        if s3_bucket:
            save_dir = os.path.join(work_dir, 'saved_model')
            trainer.save_model(save_dir)
            tokenizer.save_pretrained(save_dir)

            aws_creds = {
                'aws_access_key_id': job_input.get('aws_access_key_id'),
                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
                'aws_session_token': job_input.get('aws_session_token'),
                'aws_region': job_input.get('aws_region', 'us-east-1'),
            }
            s3_prefix = job_input.get('s3_prefix', 'triage-models/bert')
            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
            metrics['s3_uri'] = s3_uri

        return {'status': 'success', 'metrics': metrics, 'model_uri': s3_uri, 'message': 'Triage classifier trained'}

    except Exception as e:
        import traceback
        return {'status': 'error', 'error': str(e), 'traceback': traceback.format_exc()}
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)


def train_symptom_disease_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """
    Train Symptom-to-Disease classifier.

    Dataset: shanover/disease_symptoms_prec_full
    Task: Predict disease from symptoms
    """
    import torch
    import tempfile
    import shutil
    from datasets import load_dataset, Dataset
    from transformers import (
        AutoTokenizer, AutoModelForSequenceClassification,
        TrainingArguments, Trainer
    )
    from sklearn.model_selection import train_test_split

    work_dir = tempfile.mkdtemp()

    try:
        model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
        max_samples = job_input.get('max_samples', 5000)
        epochs = job_input.get('epochs', 3)
        batch_size = job_input.get('batch_size', 16)
        learning_rate = job_input.get('learning_rate', 2e-5)
        eval_split = job_input.get('eval_split', 0.1)

        print("Loading Symptom-Disease dataset...")
        dataset = load_dataset("shanover/disease_symptoms_prec_full")

        # Build label mapping from diseases
        diseases = sorted(set(item['disease'] for item in dataset['train']))
        label2id = {d: i for i, d in enumerate(diseases)}
        id2label = {i: d for d, i in label2id.items()}
        num_labels = len(diseases)

        print(f"Found {num_labels} diseases")

        training_data = []
        for item in dataset['train']:
            if max_samples and len(training_data) >= max_samples:
                break
            # Format symptoms as natural text
            symptoms = item['symptoms'].replace('_', ' ').replace(',', ', ')
            training_data.append({
                'text': f"Patient presents with: {symptoms}",
                'label': label2id[item['disease']]
            })

        print(f"Loaded {len(training_data)} symptom-disease samples")

        if eval_split > 0:
            train_data, eval_data = train_test_split(
                training_data, test_size=eval_split, random_state=42
            )
        else:
            train_data, eval_data = training_data, None

        train_dataset = Dataset.from_list(train_data)
        eval_dataset = Dataset.from_list(eval_data) if eval_data else None

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels, id2label=id2label, label2id=label2id
        )

        def tokenize(examples):
            return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=256)

        train_dataset = train_dataset.map(tokenize, batched=True)
        if eval_dataset:
            eval_dataset = eval_dataset.map(tokenize, batched=True)

        training_args = TrainingArguments(
            output_dir=work_dir,
            num_train_epochs=epochs,
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            eval_strategy='epoch' if eval_dataset else 'no',
            save_strategy='no',
            fp16=torch.cuda.is_available(),
            report_to='none',
        )

        def compute_metrics(eval_pred):
            from sklearn.metrics import accuracy_score, f1_score, top_k_accuracy_score
            preds = eval_pred.predictions.argmax(-1)
            metrics = {
                'accuracy': accuracy_score(eval_pred.label_ids, preds),
                'f1_weighted': f1_score(eval_pred.label_ids, preds, average='weighted'),
            }
            # Top-5 accuracy (important for diagnosis)
            try:
                metrics['top5_accuracy'] = top_k_accuracy_score(
                    eval_pred.label_ids, eval_pred.predictions, k=5
                )
            except Exception:
                # top_k_accuracy_score can fail if the eval split is missing classes
                pass
            return metrics

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics if eval_dataset else None,
        )

        train_result = trainer.train()

        metrics = {
            'task': 'symptom_disease_classification',
            'train_loss': float(train_result.training_loss),
            'epochs': epochs,
            'model_name': model_name,
            'samples': len(train_data),
            'num_diseases': num_labels,
            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
            'data_source': 'disease_symptoms_prec_full',
        }

        if eval_dataset:
            eval_result = trainer.evaluate()
            metrics.update({
                'eval_accuracy': float(eval_result['eval_accuracy']),
                'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
            })
            if 'eval_top5_accuracy' in eval_result:
                metrics['eval_top5_accuracy'] = float(eval_result['eval_top5_accuracy'])

        s3_uri = None
        s3_bucket = job_input.get('s3_bucket')
        if s3_bucket:
            save_dir = os.path.join(work_dir, 'saved_model')
            trainer.save_model(save_dir)
            tokenizer.save_pretrained(save_dir)

            # Save label mapping
            with open(os.path.join(save_dir, 'disease_labels.json'), 'w') as f:
                json.dump({'id2label': id2label, 'label2id': label2id}, f)

            aws_creds = {
                'aws_access_key_id': job_input.get('aws_access_key_id'),
                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
                'aws_session_token': job_input.get('aws_session_token'),
                'aws_region': job_input.get('aws_region', 'us-east-1'),
            }
            s3_prefix = job_input.get('s3_prefix', 'symptom-disease-models/bert')
            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
            metrics['s3_uri'] = s3_uri

        return {'status': 'success', 'metrics': metrics, 'model_uri': s3_uri, 'message': 'Symptom-Disease classifier trained'}

    except Exception as e:
        import traceback
        return {'status': 'error', 'error': str(e), 'traceback': traceback.format_exc()}
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)


 def handler(job):
-    """RunPod serverless handler."""
+    """RunPod serverless handler with multi-task support."""
     job_input = job.get('input', {})
 
-    model_name = job_input.get('model_name', 'meta-llama/Llama-3.1-8B-Instruct')
-    use_lora = job_input.get('use_lora', True)
-
-    # Auto-detect: use LoRA for large models
-    if any(x in model_name.lower() for x in ['gemma', 'llama', 'mistral', 'qwen']):
-        use_lora = True
-    elif 'bert' in model_name.lower():
-        use_lora = False
-
-    if use_lora:
-        return train_llm_lora(job_input)
-    else:
-        return train_bert_classifier(job_input)
+    # Task routing
+    task = job_input.get('task', 'ddi')
+
+    if task == 'ade':
+        return train_ade_classifier(job_input)
+    elif task == 'triage':
+        return train_triage_classifier(job_input)
+    elif task == 'symptom_disease':
+        return train_symptom_disease_classifier(job_input)
+    elif task == 'ddi' or task == 'bert':
+        # Original DDI training
+        model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
+        if 'bert' in model_name.lower():
+            return train_bert_classifier(job_input)
+        else:
+            return train_llm_lora(job_input)
+    elif task == 'llm':
+        return train_llm_lora(job_input)
+    else:
+        # Auto-detect based on model
+        model_name = job_input.get('model_name', 'meta-llama/Llama-3.1-8B-Instruct')
+        if any(x in model_name.lower() for x in ['gemma', 'llama', 'mistral', 'qwen']):
+            return train_llm_lora(job_input)
+        else:
+            return train_bert_classifier(job_input)
 
 # RunPod serverless entrypoint
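Below this comment the file registers the handler with the RunPod serverless runtime; the diff truncates the call itself. With the standard runpod SDK, that registration typically looks like the following sketch (not shown verbatim in this diff):

    # Sketch of the usual RunPod serverless entrypoint.
    import runpod
    runpod.serverless.start({"handler": handler})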

pipelines/healthcare_training.py (new file, 376 lines)
@@ -0,0 +1,376 @@
"""
Healthcare ML Training Pipelines

Multi-task training pipelines for:
- Adverse Drug Event (ADE) Classification
- Medical Triage Classification
- Symptom-to-Disease Prediction
- Drug-Drug Interaction (DDI) Classification

All use RunPod serverless GPU infrastructure.
"""
from kfp import dsl
from kfp import compiler


# ============================================================================
# ADE (Adverse Drug Event) Classification Pipeline
# ============================================================================
@dsl.component(
    base_image="python:3.11-slim",
    packages_to_install=["requests"]
)
def train_ade_model(
    runpod_api_key: str,
    runpod_endpoint: str,
    model_name: str,
    max_samples: int,
    epochs: int,
    batch_size: int,
    s3_bucket: str,
    aws_access_key_id: str,
    aws_secret_access_key: str,
    aws_session_token: str,
) -> dict:
    """Train ADE classifier on RunPod serverless GPU."""
    import requests
    import time

    response = requests.post(
        f"https://api.runpod.ai/v2/{runpod_endpoint}/run",
        headers={"Authorization": f"Bearer {runpod_api_key}"},
        json={
            "input": {
                "task": "ade",
                "model_name": model_name,
                "max_samples": max_samples,
                "epochs": epochs,
                "batch_size": batch_size,
                "eval_split": 0.1,
                "s3_bucket": s3_bucket,
                "s3_prefix": "ade-models/bert",
                "aws_access_key_id": aws_access_key_id,
                "aws_secret_access_key": aws_secret_access_key,
                "aws_session_token": aws_session_token,
            }
        }
    )

    job_id = response.json()["id"]
    print(f"RunPod job submitted: {job_id}")

    # Poll for completion
    while True:
        status = requests.get(
            f"https://api.runpod.ai/v2/{runpod_endpoint}/status/{job_id}",
            headers={"Authorization": f"Bearer {runpod_api_key}"}
        ).json()

        if status["status"] == "COMPLETED":
            return status["output"]
        elif status["status"] == "FAILED":
            raise Exception(f"Training failed: {status}")

        time.sleep(10)


@dsl.pipeline(name="ade-classification-pipeline")
def ade_classification_pipeline(
    runpod_api_key: str,
    runpod_endpoint: str = "k57do7afav01es",
    model_name: str = "emilyalsentzer/Bio_ClinicalBERT",
    max_samples: int = 10000,
    epochs: int = 3,
    batch_size: int = 16,
    s3_bucket: str = "",
    aws_access_key_id: str = "",
    aws_secret_access_key: str = "",
    aws_session_token: str = "",
):
    """
    Adverse Drug Event Classification Pipeline

    Trains Bio_ClinicalBERT on ADE Corpus V2 (30K samples).
    Binary classification: ADE present / No ADE
    """
    train_task = train_ade_model(
        runpod_api_key=runpod_api_key,
        runpod_endpoint=runpod_endpoint,
        model_name=model_name,
        max_samples=max_samples,
        epochs=epochs,
        batch_size=batch_size,
        s3_bucket=s3_bucket,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        aws_session_token=aws_session_token,
    )


# ============================================================================
# Medical Triage Classification Pipeline
# ============================================================================
@dsl.component(
    base_image="python:3.11-slim",
    packages_to_install=["requests"]
)
def train_triage_model(
    runpod_api_key: str,
    runpod_endpoint: str,
    model_name: str,
    max_samples: int,
    epochs: int,
    batch_size: int,
    s3_bucket: str,
    aws_access_key_id: str,
    aws_secret_access_key: str,
    aws_session_token: str,
) -> dict:
    """Train Medical Triage classifier on RunPod."""
    import requests
    import time

    response = requests.post(
        f"https://api.runpod.ai/v2/{runpod_endpoint}/run",
        headers={"Authorization": f"Bearer {runpod_api_key}"},
        json={
            "input": {
                "task": "triage",
                "model_name": model_name,
                "max_samples": max_samples,
                "epochs": epochs,
                "batch_size": batch_size,
                "eval_split": 0.1,
                "s3_bucket": s3_bucket,
                "s3_prefix": "triage-models/bert",
                "aws_access_key_id": aws_access_key_id,
                "aws_secret_access_key": aws_secret_access_key,
                "aws_session_token": aws_session_token,
            }
        }
    )

    job_id = response.json()["id"]
    print(f"RunPod job submitted: {job_id}")

    # Poll for completion
    while True:
        status = requests.get(
            f"https://api.runpod.ai/v2/{runpod_endpoint}/status/{job_id}",
            headers={"Authorization": f"Bearer {runpod_api_key}"}
        ).json()

        if status["status"] == "COMPLETED":
            return status["output"]
        elif status["status"] == "FAILED":
            raise Exception(f"Training failed: {status}")

        time.sleep(10)


@dsl.pipeline(name="triage-classification-pipeline")
def triage_classification_pipeline(
    runpod_api_key: str,
    runpod_endpoint: str = "k57do7afav01es",
    model_name: str = "emilyalsentzer/Bio_ClinicalBERT",
    max_samples: int = 5000,
    epochs: int = 3,
    batch_size: int = 8,
    s3_bucket: str = "",
    aws_access_key_id: str = "",
    aws_secret_access_key: str = "",
    aws_session_token: str = "",
):
    """
    Medical Triage Classification Pipeline

    Trains a classifier for ER triage urgency levels.
    Multi-class: Emergency, Urgent, Standard, etc.
    """
    train_task = train_triage_model(
        runpod_api_key=runpod_api_key,
        runpod_endpoint=runpod_endpoint,
        model_name=model_name,
        max_samples=max_samples,
        epochs=epochs,
        batch_size=batch_size,
        s3_bucket=s3_bucket,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        aws_session_token=aws_session_token,
    )


# ============================================================================
# Symptom-to-Disease Classification Pipeline
# ============================================================================
@dsl.component(
    base_image="python:3.11-slim",
    packages_to_install=["requests"]
)
def train_symptom_disease_model(
    runpod_api_key: str,
    runpod_endpoint: str,
    model_name: str,
    max_samples: int,
    epochs: int,
    batch_size: int,
    s3_bucket: str,
    aws_access_key_id: str,
    aws_secret_access_key: str,
    aws_session_token: str,
) -> dict:
    """Train Symptom-to-Disease classifier on RunPod."""
    import requests
    import time

    response = requests.post(
        f"https://api.runpod.ai/v2/{runpod_endpoint}/run",
        headers={"Authorization": f"Bearer {runpod_api_key}"},
        json={
            "input": {
                "task": "symptom_disease",
                "model_name": model_name,
                "max_samples": max_samples,
                "epochs": epochs,
                "batch_size": batch_size,
                "eval_split": 0.1,
                "s3_bucket": s3_bucket,
                "s3_prefix": "symptom-disease-models/bert",
                "aws_access_key_id": aws_access_key_id,
                "aws_secret_access_key": aws_secret_access_key,
                "aws_session_token": aws_session_token,
            }
        }
    )

    job_id = response.json()["id"]
    print(f"RunPod job submitted: {job_id}")

    # Poll for completion
    while True:
        status = requests.get(
            f"https://api.runpod.ai/v2/{runpod_endpoint}/status/{job_id}",
            headers={"Authorization": f"Bearer {runpod_api_key}"}
        ).json()

        if status["status"] == "COMPLETED":
            return status["output"]
        elif status["status"] == "FAILED":
            raise Exception(f"Training failed: {status}")

        time.sleep(10)


@dsl.pipeline(name="symptom-disease-classification-pipeline")
def symptom_disease_pipeline(
    runpod_api_key: str,
    runpod_endpoint: str = "k57do7afav01es",
    model_name: str = "emilyalsentzer/Bio_ClinicalBERT",
    max_samples: int = 5000,
    epochs: int = 3,
    batch_size: int = 16,
    s3_bucket: str = "",
    aws_access_key_id: str = "",
    aws_secret_access_key: str = "",
    aws_session_token: str = "",
):
    """
    Symptom-to-Disease Classification Pipeline

    Predicts disease from symptom descriptions.
    Multi-class: 40+ disease categories
    """
    train_task = train_symptom_disease_model(
        runpod_api_key=runpod_api_key,
        runpod_endpoint=runpod_endpoint,
        model_name=model_name,
        max_samples=max_samples,
        epochs=epochs,
        batch_size=batch_size,
        s3_bucket=s3_bucket,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        aws_session_token=aws_session_token,
    )


# ============================================================================
# Full Healthcare Training Pipeline (All Tasks)
# ============================================================================
@dsl.pipeline(name="healthcare-multi-task-pipeline")
def healthcare_multi_task_pipeline(
    runpod_api_key: str,
    runpod_endpoint: str = "k57do7afav01es",
    model_name: str = "emilyalsentzer/Bio_ClinicalBERT",
    s3_bucket: str = "",
    aws_access_key_id: str = "",
    aws_secret_access_key: str = "",
    aws_session_token: str = "",
):
    """
    Train all healthcare models in parallel.

    Outputs:
    - ADE classifier (s3://bucket/ade-models/...)
    - Triage classifier (s3://bucket/triage-models/...)
    - Symptom-Disease classifier (s3://bucket/symptom-disease-models/...)
    """
    # Run all training tasks in parallel
    ade_task = train_ade_model(
        runpod_api_key=runpod_api_key,
        runpod_endpoint=runpod_endpoint,
        model_name=model_name,
        max_samples=10000,
        epochs=3,
        batch_size=16,
        s3_bucket=s3_bucket,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        aws_session_token=aws_session_token,
    )

    triage_task = train_triage_model(
        runpod_api_key=runpod_api_key,
        runpod_endpoint=runpod_endpoint,
        model_name=model_name,
        max_samples=5000,
        epochs=3,
        batch_size=8,
        s3_bucket=s3_bucket,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        aws_session_token=aws_session_token,
    )

    symptom_task = train_symptom_disease_model(
        runpod_api_key=runpod_api_key,
        runpod_endpoint=runpod_endpoint,
        model_name=model_name,
        max_samples=5000,
        epochs=3,
        batch_size=16,
        s3_bucket=s3_bucket,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        aws_session_token=aws_session_token,
    )


if __name__ == "__main__":
    # Compile pipelines
    compiler.Compiler().compile(
        ade_classification_pipeline,
        "ade_classification_pipeline.yaml"
    )
    compiler.Compiler().compile(
        triage_classification_pipeline,
        "triage_classification_pipeline.yaml"
    )
    compiler.Compiler().compile(
        symptom_disease_pipeline,
        "symptom_disease_pipeline.yaml"
    )
    compiler.Compiler().compile(
        healthcare_multi_task_pipeline,
        "healthcare_multi_task_pipeline.yaml"
    )
    print("All pipelines compiled!")
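Once compiled, one hedged way to launch a run is via the KFP SDK client; in this sketch the host URL, bucket, and API key are placeholders, not values from this commit:

    import kfp

    client = kfp.Client(host="http://localhost:8080")  # placeholder KFP endpoint
    client.create_run_from_pipeline_package(
        "ade_classification_pipeline.yaml",
        arguments={
            "runpod_api_key": "<RUNPOD_API_KEY>",  # placeholder secret
            "s3_bucket": "my-model-bucket",        # placeholder bucket
        },
    )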