feat: add ADE, Triage, and Symptom-Disease training pipelines

New tasks supported:
- task=ade: Adverse Drug Event classification (ADE Corpus V2, 30K samples)
- task=triage: Medical Triage classification (urgency levels)
- task=symptom_disease: Symptom-to-Disease prediction (40+ diseases)

All three pipelines load data via Hugging Face datasets, fine-tune Bio_ClinicalBERT by default, and upload trained models to S3.
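
Example request payloads (illustrative only; keys mirror the job_input.get() calls in the diff below, and the bucket name is a placeholder):

    {"input": {"task": "ade", "max_samples": 10000, "s3_bucket": "my-model-bucket"}}
    {"input": {"task": "triage", "epochs": 3, "batch_size": 8}}
    {"input": {"task": "symptom_disease", "eval_split": 0.1}}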
2026-02-03 16:20:55 +00:00
parent f8a0e00a7f
commit 0bf3837e78
2 changed files with 847 additions and 11 deletions


@@ -445,23 +445,483 @@ def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
        shutil.rmtree(work_dir, ignore_errors=True)

def train_ade_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """
    Train Adverse Drug Event (ADE) binary classifier.
    Dataset: ade-benchmark-corpus/ade_corpus_v2 (30K samples)
    Labels: 0=No ADE, 1=ADE Present
    """
    import torch
    import tempfile
    import shutil
    from datasets import load_dataset, Dataset
    from transformers import (
        AutoTokenizer, AutoModelForSequenceClassification,
        TrainingArguments, Trainer
    )
    from sklearn.model_selection import train_test_split

    work_dir = tempfile.mkdtemp()
    try:
        model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
        max_samples = job_input.get('max_samples', 10000)
        epochs = job_input.get('epochs', 3)
        batch_size = job_input.get('batch_size', 16)
        learning_rate = job_input.get('learning_rate', 2e-5)
        eval_split = job_input.get('eval_split', 0.1)

        print("Loading ADE Corpus V2 dataset...")
        dataset = load_dataset("ade-benchmark-corpus/ade_corpus_v2", "Ade_corpus_v2_classification")

        # Prepare data
        training_data = []
        for item in dataset['train']:
            if max_samples and len(training_data) >= max_samples:
                break
            training_data.append({
                'text': item['text'],
                'label': item['label']  # 0=No ADE, 1=ADE Present
            })
        print(f"Loaded {len(training_data)} ADE samples")

        # Stratified split keeps the ADE/No-ADE ratio stable across splits
        if eval_split > 0:
            train_data, eval_data = train_test_split(
                training_data, test_size=eval_split, random_state=42,
                stratify=[d['label'] for d in training_data]
            )
        else:
            train_data, eval_data = training_data, None

        train_dataset = Dataset.from_list(train_data)
        eval_dataset = Dataset.from_list(eval_data) if eval_data else None

        # Load model (binary: ADE / No ADE)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=2
        )

        def tokenize(examples):
            return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=256)

        train_dataset = train_dataset.map(tokenize, batched=True)
        if eval_dataset:
            eval_dataset = eval_dataset.map(tokenize, batched=True)

        training_args = TrainingArguments(
            output_dir=work_dir,
            num_train_epochs=epochs,
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            eval_strategy='epoch' if eval_dataset else 'no',
            save_strategy='no',
            fp16=torch.cuda.is_available(),
            report_to='none',
        )

        def compute_metrics(eval_pred):
            from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
            preds = eval_pred.predictions.argmax(-1)
            return {
                'accuracy': accuracy_score(eval_pred.label_ids, preds),
                'f1': f1_score(eval_pred.label_ids, preds),
                'precision': precision_score(eval_pred.label_ids, preds),
                'recall': recall_score(eval_pred.label_ids, preds),
            }

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics if eval_dataset else None,
        )
        train_result = trainer.train()

        metrics = {
            'task': 'ade_classification',
            'train_loss': float(train_result.training_loss),
            'epochs': epochs,
            'model_name': model_name,
            'samples': len(train_data),
            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
            'data_source': 'ade_corpus_v2',
        }
        if eval_dataset:
            eval_result = trainer.evaluate()
            metrics.update({
                'eval_accuracy': float(eval_result['eval_accuracy']),
                'eval_f1': float(eval_result['eval_f1']),
                'eval_precision': float(eval_result['eval_precision']),
                'eval_recall': float(eval_result['eval_recall']),
            })

        # Save to S3
        s3_uri = None
        s3_bucket = job_input.get('s3_bucket')
        if s3_bucket:
            save_dir = os.path.join(work_dir, 'saved_model')
            trainer.save_model(save_dir)
            tokenizer.save_pretrained(save_dir)
            aws_creds = {
                'aws_access_key_id': job_input.get('aws_access_key_id'),
                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
                'aws_session_token': job_input.get('aws_session_token'),
                'aws_region': job_input.get('aws_region', 'us-east-1'),
            }
            s3_prefix = job_input.get('s3_prefix', 'ade-models/bert')
            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
            metrics['s3_uri'] = s3_uri

        return {'status': 'success', 'metrics': metrics, 'model_uri': s3_uri, 'message': 'ADE classifier trained on ADE Corpus V2'}
    except Exception as e:
        import traceback
        return {'status': 'error', 'error': str(e), 'traceback': traceback.format_exc()}
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)
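
# --- Illustrative sketch, not part of this commit: loading the saved ADE model
# back for inference (assumes the saved_model artifact was downloaded from S3;
# the path and example sentence are hypothetical).
def _example_ade_inference(save_dir: str = './saved_model') -> float:
    import torch
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    tokenizer = AutoTokenizer.from_pretrained(save_dir)
    model = AutoModelForSequenceClassification.from_pretrained(save_dir)
    model.eval()
    text = "Patient developed a rash after starting amoxicillin."
    inputs = tokenizer(text, truncation=True, max_length=256, return_tensors='pt')
    with torch.no_grad():
        probs = torch.softmax(model(**inputs).logits, dim=-1)
    return float(probs[0, 1])  # probability of label 1 = "ADE Present"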
def train_triage_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """
    Train Medical Triage classifier.
    Dataset: shubham212/Medical_Triage_Classification
    Labels: Triage urgency levels
    """
    import torch
    import tempfile
    import shutil
    from datasets import load_dataset, Dataset
    from transformers import (
        AutoTokenizer, AutoModelForSequenceClassification,
        TrainingArguments, Trainer
    )
    from sklearn.model_selection import train_test_split

    work_dir = tempfile.mkdtemp()
    try:
        model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
        max_samples = job_input.get('max_samples', 5000)
        epochs = job_input.get('epochs', 3)
        batch_size = job_input.get('batch_size', 8)
        learning_rate = job_input.get('learning_rate', 2e-5)
        eval_split = job_input.get('eval_split', 0.1)

        print("Loading Medical Triage dataset...")
        dataset = load_dataset("shubham212/Medical_Triage_Classification")

        # Build label maps from the unique urgency levels in the data
        labels = sorted(set(item['label'] for item in dataset['train']))
        label2id = {l: i for i, l in enumerate(labels)}
        id2label = {i: l for l, i in label2id.items()}
        num_labels = len(labels)
        print(f"Found {num_labels} triage levels: {labels}")

        training_data = []
        for item in dataset['train']:
            if max_samples and len(training_data) >= max_samples:
                break
            training_data.append({
                'text': item['text'],
                'label': label2id[item['label']]
            })
        print(f"Loaded {len(training_data)} triage samples")

        if eval_split > 0:
            train_data, eval_data = train_test_split(
                training_data, test_size=eval_split, random_state=42,
                stratify=[d['label'] for d in training_data]
            )
        else:
            train_data, eval_data = training_data, None

        train_dataset = Dataset.from_list(train_data)
        eval_dataset = Dataset.from_list(eval_data) if eval_data else None

        # Passing id2label/label2id stores the mapping in the model config
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels, id2label=id2label, label2id=label2id
        )

        def tokenize(examples):
            return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=512)

        train_dataset = train_dataset.map(tokenize, batched=True)
        if eval_dataset:
            eval_dataset = eval_dataset.map(tokenize, batched=True)

        training_args = TrainingArguments(
            output_dir=work_dir,
            num_train_epochs=epochs,
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            eval_strategy='epoch' if eval_dataset else 'no',
            save_strategy='no',
            fp16=torch.cuda.is_available(),
            report_to='none',
        )

        def compute_metrics(eval_pred):
            from sklearn.metrics import accuracy_score, f1_score
            preds = eval_pred.predictions.argmax(-1)
            return {
                'accuracy': accuracy_score(eval_pred.label_ids, preds),
                'f1_weighted': f1_score(eval_pred.label_ids, preds, average='weighted'),
            }

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics if eval_dataset else None,
        )
        train_result = trainer.train()

        metrics = {
            'task': 'triage_classification',
            'train_loss': float(train_result.training_loss),
            'epochs': epochs,
            'model_name': model_name,
            'samples': len(train_data),
            'num_labels': num_labels,
            'labels': id2label,
            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
            'data_source': 'medical_triage_classification',
        }
        if eval_dataset:
            eval_result = trainer.evaluate()
            metrics.update({
                'eval_accuracy': float(eval_result['eval_accuracy']),
                'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
            })

        s3_uri = None
        s3_bucket = job_input.get('s3_bucket')
        if s3_bucket:
            save_dir = os.path.join(work_dir, 'saved_model')
            trainer.save_model(save_dir)
            tokenizer.save_pretrained(save_dir)
            aws_creds = {
                'aws_access_key_id': job_input.get('aws_access_key_id'),
                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
                'aws_session_token': job_input.get('aws_session_token'),
                'aws_region': job_input.get('aws_region', 'us-east-1'),
            }
            s3_prefix = job_input.get('s3_prefix', 'triage-models/bert')
            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
            metrics['s3_uri'] = s3_uri

        return {'status': 'success', 'metrics': metrics, 'model_uri': s3_uri, 'message': 'Triage classifier trained'}
    except Exception as e:
        import traceback
        return {'status': 'error', 'error': str(e), 'traceback': traceback.format_exc()}
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)
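
# --- Illustrative sketch, not part of this commit: because id2label/label2id
# are passed to from_pretrained() above, save_model() writes them into
# config.json, so the triage level names survive a reload (save_dir is
# hypothetical).
def _example_triage_labels(save_dir: str = './saved_model') -> dict:
    from transformers import AutoConfig
    config = AutoConfig.from_pretrained(save_dir)
    return config.id2label  # maps class index -> urgency level name from the dataset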
def train_symptom_disease_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """
    Train Symptom-to-Disease classifier.
    Dataset: shanover/disease_symptoms_prec_full
    Task: Predict disease from symptoms
    """
    import torch
    import tempfile
    import shutil
    from datasets import load_dataset, Dataset
    from transformers import (
        AutoTokenizer, AutoModelForSequenceClassification,
        TrainingArguments, Trainer
    )
    from sklearn.model_selection import train_test_split

    work_dir = tempfile.mkdtemp()
    try:
        model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
        max_samples = job_input.get('max_samples', 5000)
        epochs = job_input.get('epochs', 3)
        batch_size = job_input.get('batch_size', 16)
        learning_rate = job_input.get('learning_rate', 2e-5)
        eval_split = job_input.get('eval_split', 0.1)

        print("Loading Symptom-Disease dataset...")
        dataset = load_dataset("shanover/disease_symptoms_prec_full")

        # Build label mapping from diseases
        diseases = sorted(set(item['disease'] for item in dataset['train']))
        label2id = {d: i for i, d in enumerate(diseases)}
        id2label = {i: d for d, i in label2id.items()}
        num_labels = len(diseases)
        print(f"Found {num_labels} diseases")

        training_data = []
        for item in dataset['train']:
            if max_samples and len(training_data) >= max_samples:
                break
            # Format symptoms as natural text, e.g.
            # "skin_rash,itching" -> "Patient presents with: skin rash, itching"
            symptoms = item['symptoms'].replace('_', ' ').replace(',', ', ')
            training_data.append({
                'text': f"Patient presents with: {symptoms}",
                'label': label2id[item['disease']]
            })
        print(f"Loaded {len(training_data)} symptom-disease samples")

        # No stratification here: rare diseases may have too few samples
        # for a stratified split to succeed
        if eval_split > 0:
            train_data, eval_data = train_test_split(
                training_data, test_size=eval_split, random_state=42
            )
        else:
            train_data, eval_data = training_data, None

        train_dataset = Dataset.from_list(train_data)
        eval_dataset = Dataset.from_list(eval_data) if eval_data else None

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels, id2label=id2label, label2id=label2id
        )

        def tokenize(examples):
            return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=256)

        train_dataset = train_dataset.map(tokenize, batched=True)
        if eval_dataset:
            eval_dataset = eval_dataset.map(tokenize, batched=True)

        training_args = TrainingArguments(
            output_dir=work_dir,
            num_train_epochs=epochs,
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            eval_strategy='epoch' if eval_dataset else 'no',
            save_strategy='no',
            fp16=torch.cuda.is_available(),
            report_to='none',
        )

        def compute_metrics(eval_pred):
            from sklearn.metrics import accuracy_score, f1_score, top_k_accuracy_score
            preds = eval_pred.predictions.argmax(-1)
            metrics = {
                'accuracy': accuracy_score(eval_pred.label_ids, preds),
                'f1_weighted': f1_score(eval_pred.label_ids, preds, average='weighted'),
            }
            # Top-5 accuracy (important for diagnosis); this can raise when the
            # eval split does not contain every class, so fail soft
            try:
                metrics['top5_accuracy'] = top_k_accuracy_score(
                    eval_pred.label_ids, eval_pred.predictions, k=5
                )
            except Exception:
                pass
            return metrics

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics if eval_dataset else None,
        )
        train_result = trainer.train()

        metrics = {
            'task': 'symptom_disease_classification',
            'train_loss': float(train_result.training_loss),
            'epochs': epochs,
            'model_name': model_name,
            'samples': len(train_data),
            'num_diseases': num_labels,
            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
            'data_source': 'disease_symptoms_prec_full',
        }
        if eval_dataset:
            eval_result = trainer.evaluate()
            metrics.update({
                'eval_accuracy': float(eval_result['eval_accuracy']),
                'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
            })
            if 'eval_top5_accuracy' in eval_result:
                metrics['eval_top5_accuracy'] = float(eval_result['eval_top5_accuracy'])

        s3_uri = None
        s3_bucket = job_input.get('s3_bucket')
        if s3_bucket:
            save_dir = os.path.join(work_dir, 'saved_model')
            trainer.save_model(save_dir)
            tokenizer.save_pretrained(save_dir)
            # Save label mapping (note: json stringifies the integer keys of id2label)
            with open(os.path.join(save_dir, 'disease_labels.json'), 'w') as f:
                json.dump({'id2label': id2label, 'label2id': label2id}, f)
            aws_creds = {
                'aws_access_key_id': job_input.get('aws_access_key_id'),
                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
                'aws_session_token': job_input.get('aws_session_token'),
                'aws_region': job_input.get('aws_region', 'us-east-1'),
            }
            s3_prefix = job_input.get('s3_prefix', 'symptom-disease-models/bert')
            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
            metrics['s3_uri'] = s3_uri

        return {'status': 'success', 'metrics': metrics, 'model_uri': s3_uri, 'message': 'Symptom-Disease classifier trained'}
    except Exception as e:
        import traceback
        return {'status': 'error', 'error': str(e), 'traceback': traceback.format_exc()}
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)
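
# --- Illustrative sketch, not part of this commit: producing a ranked top-5
# differential at inference time, mirroring the top5_accuracy metric above
# (model and tokenizer are assumed to be loaded from the saved artifact).
def _example_top5_diagnoses(model, tokenizer, symptoms_text: str, k: int = 5):
    import torch
    # Mirror the training-time formatting of the symptom string
    inputs = tokenizer(f"Patient presents with: {symptoms_text}",
                       truncation=True, max_length=256, return_tensors='pt')
    with torch.no_grad():
        probs = torch.softmax(model(**inputs).logits[0], dim=-1)
    top = torch.topk(probs, k)
    return [(model.config.id2label[i.item()], float(p))
            for i, p in zip(top.indices, top.values)]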
def handler(job):
    """RunPod serverless handler with multi-task support."""
    job_input = job.get('input', {})
    # Task routing
    task = job_input.get('task', 'ddi')
    if task == 'ade':
        return train_ade_classifier(job_input)
    elif task == 'triage':
        return train_triage_classifier(job_input)
    elif task == 'symptom_disease':
        return train_symptom_disease_classifier(job_input)
    elif task == 'ddi' or task == 'bert':
        # Original DDI training
        model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
        if 'bert' in model_name.lower():
            return train_bert_classifier(job_input)
        else:
            return train_llm_lora(job_input)
    elif task == 'llm':
        return train_llm_lora(job_input)
    else:
        # Auto-detect based on model
        model_name = job_input.get('model_name', 'meta-llama/Llama-3.1-8B-Instruct')
        if any(x in model_name.lower() for x in ['gemma', 'llama', 'mistral', 'qwen']):
            return train_llm_lora(job_input)
        else:
            return train_bert_classifier(job_input)
# RunPod serverless entrypoint
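
For a quick local smoke test of the new routing (illustrative only; assumes this module is importable and the Hugging Face datasets are reachable):

    result = handler({'input': {'task': 'ade', 'max_samples': 200, 'epochs': 1}})
    print(result['status'], result.get('metrics') or result.get('error'))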