kubeflow-pipelines/components/runpod_trainer/handler.py
Greg Hendrickson 0bf3837e78 feat: add ADE, Triage, and Symptom-Disease training pipelines
New tasks supported:
- task=ade: Adverse Drug Event classification (ADE Corpus V2, 30K samples)
- task=triage: Medical Triage classification (urgency levels)
- task=symptom_disease: Symptom-to-Disease prediction (40+ diseases)

All use HuggingFace datasets, Bio_ClinicalBERT, and S3 model storage.
2026-02-03 16:20:55 +00:00
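
A request selects a pipeline via input.task ("ade", "triage", "symptom_disease", "ddi", or "llm"). A minimal request body for one of the new tasks might look like this (values illustrative; the bucket name is a placeholder):

{"input": {"task": "triage", "model_name": "emilyalsentzer/Bio_ClinicalBERT", "max_samples": 5000, "epochs": 3, "s3_bucket": "my-model-bucket"}}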


"""
RunPod Serverless Handler for DDI Model Training
Supports both BERT-style classification and LLM fine-tuning with LoRA.
Uses 176K real DrugBank DDI samples with drug names.
Saves trained models to S3.
"""
import os
import json
import runpod
import boto3
from datetime import datetime
from typing import Dict, Any, List, Optional
def upload_to_s3(local_path: str, s3_bucket: str, s3_prefix: str, aws_credentials: Dict) -> str:
    """Upload a directory to S3 and return the S3 URI."""
    import tarfile
    import tempfile
    # Create tar.gz of the model directory
    timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
    tar_name = f"model_{timestamp}.tar.gz"
    tar_path = os.path.join(tempfile.gettempdir(), tar_name)
    print(f"Creating archive: {tar_path}")
    with tarfile.open(tar_path, "w:gz") as tar:
        tar.add(local_path, arcname="model")
    # Upload to S3
    s3_client = boto3.client(
        's3',
        aws_access_key_id=aws_credentials.get('aws_access_key_id'),
        aws_secret_access_key=aws_credentials.get('aws_secret_access_key'),
        aws_session_token=aws_credentials.get('aws_session_token'),
        region_name=aws_credentials.get('aws_region', 'us-east-1')
    )
    s3_key = f"{s3_prefix}/{tar_name}"
    print(f"Uploading to s3://{s3_bucket}/{s3_key}")
    s3_client.upload_file(tar_path, s3_bucket, s3_key)
    # Cleanup
    os.remove(tar_path)
    return f"s3://{s3_bucket}/{s3_key}"
# DDI severity labels
DDI_SEVERITY = {
    1: "minor",
    2: "moderate",
    3: "major",
    4: "contraindicated"
}
def load_drugbank_data(max_samples: Optional[int] = None, severity_filter: Optional[List[int]] = None) -> List[Dict]:
    """Load real DrugBank DDI data from the bundled file."""
    data_path = os.environ.get('DDI_DATA_PATH', '/app/data/drugbank_ddi_complete.jsonl')
    if not os.path.exists(data_path):
        print(f"WARNING: Data file not found at {data_path}, using curated fallback")
        return get_curated_fallback()
    data = []
    with open(data_path) as f:
        for line in f:
            item = json.loads(line)
            if severity_filter and item['severity'] not in severity_filter:
                continue
            data.append(item)
            if max_samples and len(data) >= max_samples:
                break
    return data
def get_curated_fallback() -> List[Dict]:
    """Fallback curated data if the main file is not available."""
    patterns = [
        {"drug1": "fluoxetine", "drug2": "tramadol", "interaction_text": "fluoxetine may increase the risk of serotonin syndrome when combined with tramadol", "severity": 4},
        {"drug1": "warfarin", "drug2": "aspirin", "interaction_text": "warfarin may increase the risk of bleeding when combined with aspirin", "severity": 3},
        {"drug1": "simvastatin", "drug2": "amlodipine", "interaction_text": "The serum concentration of simvastatin can be increased when combined with amlodipine", "severity": 2},
        {"drug1": "metformin", "drug2": "lisinopril", "interaction_text": "metformin and lisinopril have no significant interaction", "severity": 1},
    ]
    return patterns * 50  # 200 samples
def format_for_llm(item: Dict, model_name: str = "") -> str:
    """Format DDI item for LLM instruction tuning. Auto-detects format based on model."""
    severity_name = DDI_SEVERITY.get(item['severity'], 'unknown')
    user_msg = f"Analyze the drug interaction between {item['drug1']} and {item['drug2']}."
    assistant_msg = f"""Interaction: {item['interaction_text']}
Severity: {severity_name.upper()}
Recommendation: {"Avoid this combination" if item['severity'] >= 3 else "Monitor patient" if item['severity'] == 2 else "Generally safe"}"""
    # Llama 3 format
    if 'llama' in model_name.lower():
        return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
{user_msg}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{assistant_msg}<|eot_id|>"""
    # Mistral/Mixtral format
    elif 'mistral' in model_name.lower() or 'mixtral' in model_name.lower():
        return f"""<s>[INST] {user_msg} [/INST] {assistant_msg}</s>"""
    # Qwen format
    elif 'qwen' in model_name.lower():
        return f"""<|im_start|>user
{user_msg}<|im_end|>
<|im_start|>assistant
{assistant_msg}<|im_end|>"""
    # Gemma format
    elif 'gemma' in model_name.lower():
        return f"""<start_of_turn>user
{user_msg}<end_of_turn>
<start_of_turn>model
{assistant_msg}<end_of_turn>"""
    # Generic fallback
    else:
        return f"""### User: {user_msg}\n### Assistant: {assistant_msg}"""
def train_llm_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """Train LLM with QLoRA for DDI classification. Supports Llama, Mistral, Qwen, Gemma."""
    import torch
    from transformers import (
        AutoTokenizer,
        AutoModelForCausalLM,
        BitsAndBytesConfig,
    )
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    from trl import SFTConfig, SFTTrainer
    from datasets import Dataset
    import tempfile
    import shutil
    # Parameters - default to Llama 3.1 8B (Bedrock-compatible)
    model_name = job_input.get('model_name', 'meta-llama/Llama-3.1-8B-Instruct')
    max_samples = job_input.get('max_samples', 10000)
    epochs = job_input.get('epochs', 1)
    learning_rate = job_input.get('learning_rate', 2e-4)
    batch_size = job_input.get('batch_size', 2)
    lora_r = job_input.get('lora_r', 16)
    lora_alpha = job_input.get('lora_alpha', 32)
    max_seq_length = job_input.get('max_seq_length', 512)
    severity_filter = job_input.get('severity_filter', None)  # e.g., [3, 4] for major/contraindicated only
    work_dir = tempfile.mkdtemp()
    try:
        # Guard the VRAM probe so a CPU-only worker doesn't crash here
        if torch.cuda.is_available():
            print(f"GPU: {torch.cuda.get_device_name(0)}")
            print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        else:
            print("GPU: CPU")
        print(f"Model: {model_name}")
        print(f"LoRA r={lora_r}, alpha={lora_alpha}")
        # Load training data
        print(f"Loading DrugBank DDI data (max {max_samples})...")
        raw_data = load_drugbank_data(max_samples=max_samples, severity_filter=severity_filter)
        # Format for the target LLM
        formatted_data = [{"text": format_for_llm(item, model_name)} for item in raw_data]
        dataset = Dataset.from_list(formatted_data)
        print(f"Dataset size: {len(dataset)}")
        # Severity distribution
        from collections import Counter
        sev_dist = Counter(item['severity'] for item in raw_data)
        print(f"Severity distribution: {dict(sev_dist)}")
        # QLoRA config - 4-bit quantization
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
        # Load model
        print(f"Loading {model_name} with 4-bit quantization...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"
        # Prepare for k-bit training
        model = prepare_model_for_kbit_training(model)
        # LoRA config
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Trainable params: {trainable_params:,} / {total_params:,} ({100*trainable_params/total_params:.2f}%)")
        # Training arguments. max_seq_length belongs on SFTConfig, not SFTTrainer
        # (assumes trl ~0.12-0.15; newer trl renames it to max_length).
        # Effective batch size = per_device_train_batch_size * gradient_accumulation_steps = 16.
        training_args = SFTConfig(
            output_dir=work_dir,
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=8,
            learning_rate=learning_rate,
            weight_decay=0.01,
            warmup_ratio=0.1,
            logging_steps=25,
            save_strategy="no",
            bf16=True,
            gradient_checkpointing=True,
            optim="paged_adamw_8bit",
            report_to="none",
            max_grad_norm=0.3,
            max_seq_length=max_seq_length,
        )
        # SFT Trainer
        trainer = SFTTrainer(
            model=model,
            train_dataset=dataset,
            args=training_args,
            peft_config=lora_config,
            processing_class=tokenizer,
        )
        # Train
        print("Starting LoRA fine-tuning...")
        train_result = trainer.train()
        # Metrics
        metrics = {
            'train_loss': float(train_result.training_loss),
            'epochs': epochs,
            'model_name': model_name,
            'samples': len(raw_data),
            'lora_r': lora_r,
            'lora_alpha': lora_alpha,
            'trainable_params': trainable_params,
            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
            'vram_gb': torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0,
            'data_source': 'drugbank_176k',
            'severity_dist': dict(sev_dist),
        }
        print(f"Training complete! Loss: {metrics['train_loss']:.4f}")
        # Save LoRA adapter to S3 if credentials provided
        s3_uri = None
        s3_bucket = job_input.get('s3_bucket')
        if s3_bucket:
            save_dir = os.path.join(work_dir, 'lora_adapter')
            print(f"Saving LoRA adapter to {save_dir}...")
            model.save_pretrained(save_dir)
            tokenizer.save_pretrained(save_dir)
            aws_creds = {
                'aws_access_key_id': job_input.get('aws_access_key_id'),
                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
                'aws_session_token': job_input.get('aws_session_token'),
                'aws_region': job_input.get('aws_region', 'us-east-1'),
            }
            model_short = model_name.split('/')[-1]
            s3_prefix = job_input.get('s3_prefix', f'ddi-models/lora-{model_short}')
            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
            metrics['s3_uri'] = s3_uri
            print(f"LoRA adapter uploaded to {s3_uri}")
        return {
            'status': 'success',
            'metrics': metrics,
            'model_uri': s3_uri,
            'message': f'LLM fine-tuned on {len(raw_data):,} real DrugBank DDI samples'
        }
    except Exception as e:
        import traceback
        return {
            'status': 'error',
            'error': str(e),
            'traceback': traceback.format_exc()
        }
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
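# A minimal job_input for this path might look like this (bucket and
# credentials are placeholders, not real values):
#   {"task": "llm", "model_name": "meta-llama/Llama-3.1-8B-Instruct",
#    "max_samples": 10000, "epochs": 1, "lora_r": 16, "lora_alpha": 32,
#    "severity_filter": [3, 4],
#    "s3_bucket": "my-model-bucket", "aws_access_key_id": "...",
#    "aws_secret_access_key": "...", "aws_region": "us-east-1"}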
def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """Train BERT-style classifier for DDI severity prediction."""
    import torch
    from transformers import (
        AutoTokenizer,
        AutoModelForSequenceClassification,
        TrainingArguments,
        Trainer
    )
    from datasets import Dataset
    from sklearn.model_selection import train_test_split
    import tempfile
    import shutil
    model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
    max_samples = job_input.get('max_samples', 50000)
    epochs = job_input.get('epochs', 3)
    learning_rate = job_input.get('learning_rate', 2e-5)
    batch_size = job_input.get('batch_size', 16)
    eval_split = job_input.get('eval_split', 0.1)
    work_dir = tempfile.mkdtemp()
    try:
        print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
        print(f"Model: {model_name}")
        # Load data
        raw_data = load_drugbank_data(max_samples=max_samples)
        # Create text + label format; shift severity to 0-indexed (1-4 -> 0-3)
        training_data = [{
            "text": f"{d['drug1']} and {d['drug2']}: {d['interaction_text']}",
            "label": d['severity'] - 1  # 0-indexed
        } for d in raw_data]
        print(f"Loaded {len(training_data)} samples")
        # Split
        if eval_split > 0:
            train_data, eval_data = train_test_split(
                training_data, test_size=eval_split, random_state=42,
                stratify=[d['label'] for d in training_data]
            )
        else:
            train_data, eval_data = training_data, None
        train_dataset = Dataset.from_list(train_data)
        eval_dataset = Dataset.from_list(eval_data) if eval_data else None
        # Load model (4 classes: minor, moderate, major, contraindicated)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=4
        )

        def tokenize(examples):
            return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=256)

        train_dataset = train_dataset.map(tokenize, batched=True)
        if eval_dataset:
            eval_dataset = eval_dataset.map(tokenize, batched=True)
        training_args = TrainingArguments(
            output_dir=work_dir,
            num_train_epochs=epochs,
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            eval_strategy='epoch' if eval_dataset else 'no',
            save_strategy='no',
            fp16=torch.cuda.is_available(),
            report_to='none',
        )

        def compute_metrics(eval_pred):
            from sklearn.metrics import accuracy_score, f1_score
            preds = eval_pred.predictions.argmax(-1)
            return {
                'accuracy': accuracy_score(eval_pred.label_ids, preds),
                'f1_weighted': f1_score(eval_pred.label_ids, preds, average='weighted'),
            }

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics if eval_dataset else None,
        )
        train_result = trainer.train()
        metrics = {
            'train_loss': float(train_result.training_loss),
            'epochs': epochs,
            'model_name': model_name,
            'samples': len(train_data),
            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
            'data_source': 'drugbank_176k',
        }
        if eval_dataset:
            eval_result = trainer.evaluate()
            metrics.update({
                'eval_accuracy': float(eval_result['eval_accuracy']),
                'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
            })
        # Save model to S3 if credentials provided
        s3_uri = None
        s3_bucket = job_input.get('s3_bucket')
        if s3_bucket:
            save_dir = os.path.join(work_dir, 'saved_model')
            print(f"Saving model to {save_dir}...")
            trainer.save_model(save_dir)
            tokenizer.save_pretrained(save_dir)
            aws_creds = {
                'aws_access_key_id': job_input.get('aws_access_key_id'),
                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
                'aws_session_token': job_input.get('aws_session_token'),
                'aws_region': job_input.get('aws_region', 'us-east-1'),
            }
            s3_prefix = job_input.get('s3_prefix', 'ddi-models/bert')
            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
            metrics['s3_uri'] = s3_uri
            print(f"Model uploaded to {s3_uri}")
        return {'status': 'success', 'metrics': metrics, 'model_uri': s3_uri, 'message': 'BERT classifier trained on DrugBank data'}
    except Exception as e:
        import traceback
        return {'status': 'error', 'error': str(e), 'traceback': traceback.format_exc()}
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)
def train_ade_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """
    Train Adverse Drug Event (ADE) binary classifier.
    Dataset: ade-benchmark-corpus/ade_corpus_v2 (30K samples)
    Labels: 0=No ADE, 1=ADE Present
    """
    import torch
    import tempfile
    import shutil
    from datasets import load_dataset, Dataset
    from transformers import (
        AutoTokenizer, AutoModelForSequenceClassification,
        TrainingArguments, Trainer
    )
    from sklearn.model_selection import train_test_split
    work_dir = tempfile.mkdtemp()
    try:
        model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
        max_samples = job_input.get('max_samples', 10000)
        epochs = job_input.get('epochs', 3)
        batch_size = job_input.get('batch_size', 16)
        learning_rate = job_input.get('learning_rate', 2e-5)
        eval_split = job_input.get('eval_split', 0.1)
        print("Loading ADE Corpus V2 dataset...")
        dataset = load_dataset("ade-benchmark-corpus/ade_corpus_v2", "Ade_corpus_v2_classification")
        # Prepare data
        training_data = []
        for item in dataset['train']:
            if max_samples and len(training_data) >= max_samples:
                break
            training_data.append({
                'text': item['text'],
                'label': item['label']  # 0 or 1
            })
        print(f"Loaded {len(training_data)} ADE samples")
        # Split
        if eval_split > 0:
            train_data, eval_data = train_test_split(
                training_data, test_size=eval_split, random_state=42,
                stratify=[d['label'] for d in training_data]
            )
        else:
            train_data, eval_data = training_data, None
        train_dataset = Dataset.from_list(train_data)
        eval_dataset = Dataset.from_list(eval_data) if eval_data else None
        # Load model (binary: ADE / No ADE)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=2
        )

        def tokenize(examples):
            return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=256)

        train_dataset = train_dataset.map(tokenize, batched=True)
        if eval_dataset:
            eval_dataset = eval_dataset.map(tokenize, batched=True)
        training_args = TrainingArguments(
            output_dir=work_dir,
            num_train_epochs=epochs,
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            eval_strategy='epoch' if eval_dataset else 'no',
            save_strategy='no',
            fp16=torch.cuda.is_available(),
            report_to='none',
        )

        def compute_metrics(eval_pred):
            from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
            preds = eval_pred.predictions.argmax(-1)
            return {
                'accuracy': accuracy_score(eval_pred.label_ids, preds),
                'f1': f1_score(eval_pred.label_ids, preds),
                'precision': precision_score(eval_pred.label_ids, preds),
                'recall': recall_score(eval_pred.label_ids, preds),
            }

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics if eval_dataset else None,
        )
        train_result = trainer.train()
        metrics = {
            'task': 'ade_classification',
            'train_loss': float(train_result.training_loss),
            'epochs': epochs,
            'model_name': model_name,
            'samples': len(train_data),
            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
            'data_source': 'ade_corpus_v2',
        }
        if eval_dataset:
            eval_result = trainer.evaluate()
            metrics.update({
                'eval_accuracy': float(eval_result['eval_accuracy']),
                'eval_f1': float(eval_result['eval_f1']),
                'eval_precision': float(eval_result['eval_precision']),
                'eval_recall': float(eval_result['eval_recall']),
            })
        # Save to S3
        s3_uri = None
        s3_bucket = job_input.get('s3_bucket')
        if s3_bucket:
            save_dir = os.path.join(work_dir, 'saved_model')
            trainer.save_model(save_dir)
            tokenizer.save_pretrained(save_dir)
            aws_creds = {
                'aws_access_key_id': job_input.get('aws_access_key_id'),
                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
                'aws_session_token': job_input.get('aws_session_token'),
                'aws_region': job_input.get('aws_region', 'us-east-1'),
            }
            s3_prefix = job_input.get('s3_prefix', 'ade-models/bert')
            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
            metrics['s3_uri'] = s3_uri
        return {'status': 'success', 'metrics': metrics, 'model_uri': s3_uri, 'message': 'ADE classifier trained on ADE Corpus V2'}
    except Exception as e:
        import traceback
        return {'status': 'error', 'error': str(e), 'traceback': traceback.format_exc()}
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)
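# On success the ADE path returns a payload shaped like (values illustrative):
#   {"status": "success",
#    "metrics": {"task": "ade_classification", "train_loss": ..., "eval_f1": ...,
#                "s3_uri": "s3://..."},
#    "model_uri": "s3://.../model_<timestamp>.tar.gz",
#    "message": "ADE classifier trained on ADE Corpus V2"}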
def train_triage_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """
    Train Medical Triage classifier.
    Dataset: shubham212/Medical_Triage_Classification
    Labels: Triage urgency levels
    """
    import torch
    import tempfile
    import shutil
    from datasets import load_dataset, Dataset
    from transformers import (
        AutoTokenizer, AutoModelForSequenceClassification,
        TrainingArguments, Trainer
    )
    from sklearn.model_selection import train_test_split
    work_dir = tempfile.mkdtemp()
    try:
        model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
        max_samples = job_input.get('max_samples', 5000)
        epochs = job_input.get('epochs', 3)
        batch_size = job_input.get('batch_size', 8)
        learning_rate = job_input.get('learning_rate', 2e-5)
        eval_split = job_input.get('eval_split', 0.1)
        print("Loading Medical Triage dataset...")
        dataset = load_dataset("shubham212/Medical_Triage_Classification")
        # Get unique labels
        labels = sorted(set(item['label'] for item in dataset['train']))
        label2id = {l: i for i, l in enumerate(labels)}
        id2label = {i: l for l, i in label2id.items()}
        num_labels = len(labels)
        print(f"Found {num_labels} triage levels: {labels}")
        training_data = []
        for item in dataset['train']:
            if max_samples and len(training_data) >= max_samples:
                break
            training_data.append({
                'text': item['text'],
                'label': label2id[item['label']]
            })
        print(f"Loaded {len(training_data)} triage samples")
        if eval_split > 0:
            train_data, eval_data = train_test_split(
                training_data, test_size=eval_split, random_state=42,
                stratify=[d['label'] for d in training_data]
            )
        else:
            train_data, eval_data = training_data, None
        train_dataset = Dataset.from_list(train_data)
        eval_dataset = Dataset.from_list(eval_data) if eval_data else None
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels, id2label=id2label, label2id=label2id
        )

        def tokenize(examples):
            return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=512)

        train_dataset = train_dataset.map(tokenize, batched=True)
        if eval_dataset:
            eval_dataset = eval_dataset.map(tokenize, batched=True)
        training_args = TrainingArguments(
            output_dir=work_dir,
            num_train_epochs=epochs,
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            eval_strategy='epoch' if eval_dataset else 'no',
            save_strategy='no',
            fp16=torch.cuda.is_available(),
            report_to='none',
        )

        def compute_metrics(eval_pred):
            from sklearn.metrics import accuracy_score, f1_score
            preds = eval_pred.predictions.argmax(-1)
            return {
                'accuracy': accuracy_score(eval_pred.label_ids, preds),
                'f1_weighted': f1_score(eval_pred.label_ids, preds, average='weighted'),
            }

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics if eval_dataset else None,
        )
        train_result = trainer.train()
        metrics = {
            'task': 'triage_classification',
            'train_loss': float(train_result.training_loss),
            'epochs': epochs,
            'model_name': model_name,
            'samples': len(train_data),
            'num_labels': num_labels,
            'labels': id2label,
            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
            'data_source': 'medical_triage_classification',
        }
        if eval_dataset:
            eval_result = trainer.evaluate()
            metrics.update({
                'eval_accuracy': float(eval_result['eval_accuracy']),
                'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
            })
        s3_uri = None
        s3_bucket = job_input.get('s3_bucket')
        if s3_bucket:
            save_dir = os.path.join(work_dir, 'saved_model')
            trainer.save_model(save_dir)
            tokenizer.save_pretrained(save_dir)
            aws_creds = {
                'aws_access_key_id': job_input.get('aws_access_key_id'),
                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
                'aws_session_token': job_input.get('aws_session_token'),
                'aws_region': job_input.get('aws_region', 'us-east-1'),
            }
            s3_prefix = job_input.get('s3_prefix', 'triage-models/bert')
            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
            metrics['s3_uri'] = s3_uri
        return {'status': 'success', 'metrics': metrics, 'model_uri': s3_uri, 'message': 'Triage classifier trained'}
    except Exception as e:
        import traceback
        return {'status': 'error', 'error': str(e), 'traceback': traceback.format_exc()}
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)
def train_symptom_disease_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """
    Train Symptom-to-Disease classifier.
    Dataset: shanover/disease_symptoms_prec_full
    Task: Predict disease from symptoms
    """
    import torch
    import tempfile
    import shutil
    from datasets import load_dataset, Dataset
    from transformers import (
        AutoTokenizer, AutoModelForSequenceClassification,
        TrainingArguments, Trainer
    )
    from sklearn.model_selection import train_test_split
    work_dir = tempfile.mkdtemp()
    try:
        model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
        max_samples = job_input.get('max_samples', 5000)
        epochs = job_input.get('epochs', 3)
        batch_size = job_input.get('batch_size', 16)
        learning_rate = job_input.get('learning_rate', 2e-5)
        eval_split = job_input.get('eval_split', 0.1)
        print("Loading Symptom-Disease dataset...")
        dataset = load_dataset("shanover/disease_symptoms_prec_full")
        # Build label mapping from diseases
        diseases = sorted(set(item['disease'] for item in dataset['train']))
        label2id = {d: i for i, d in enumerate(diseases)}
        id2label = {i: d for d, i in label2id.items()}
        num_labels = len(diseases)
        print(f"Found {num_labels} diseases")
        training_data = []
        for item in dataset['train']:
            if max_samples and len(training_data) >= max_samples:
                break
            # Format symptoms as natural text
            symptoms = item['symptoms'].replace('_', ' ').replace(',', ', ')
            training_data.append({
                'text': f"Patient presents with: {symptoms}",
                'label': label2id[item['disease']]
            })
        print(f"Loaded {len(training_data)} symptom-disease samples")
        # Note: unlike the other trainers this split is unstratified; with 40+
        # classes, rare diseases may not have enough samples per class to stratify
        if eval_split > 0:
            train_data, eval_data = train_test_split(
                training_data, test_size=eval_split, random_state=42
            )
        else:
            train_data, eval_data = training_data, None
        train_dataset = Dataset.from_list(train_data)
        eval_dataset = Dataset.from_list(eval_data) if eval_data else None
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels, id2label=id2label, label2id=label2id
        )

        def tokenize(examples):
            return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=256)

        train_dataset = train_dataset.map(tokenize, batched=True)
        if eval_dataset:
            eval_dataset = eval_dataset.map(tokenize, batched=True)
        training_args = TrainingArguments(
            output_dir=work_dir,
            num_train_epochs=epochs,
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            eval_strategy='epoch' if eval_dataset else 'no',
            save_strategy='no',
            fp16=torch.cuda.is_available(),
            report_to='none',
        )

        def compute_metrics(eval_pred):
            from sklearn.metrics import accuracy_score, f1_score, top_k_accuracy_score
            preds = eval_pred.predictions.argmax(-1)
            metrics = {
                'accuracy': accuracy_score(eval_pred.label_ids, preds),
                'f1_weighted': f1_score(eval_pred.label_ids, preds, average='weighted'),
            }
            # Top-5 accuracy (important for diagnosis); pass the full label set so
            # an eval split missing some classes does not raise, and fail soft anyway
            try:
                metrics['top5_accuracy'] = top_k_accuracy_score(
                    eval_pred.label_ids, eval_pred.predictions, k=5,
                    labels=list(range(num_labels))
                )
            except Exception:
                pass
            return metrics

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics if eval_dataset else None,
        )
        train_result = trainer.train()
        metrics = {
            'task': 'symptom_disease_classification',
            'train_loss': float(train_result.training_loss),
            'epochs': epochs,
            'model_name': model_name,
            'samples': len(train_data),
            'num_diseases': num_labels,
            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
            'data_source': 'disease_symptoms_prec_full',
        }
        if eval_dataset:
            eval_result = trainer.evaluate()
            metrics.update({
                'eval_accuracy': float(eval_result['eval_accuracy']),
                'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
            })
            if 'eval_top5_accuracy' in eval_result:
                metrics['eval_top5_accuracy'] = float(eval_result['eval_top5_accuracy'])
        s3_uri = None
        s3_bucket = job_input.get('s3_bucket')
        if s3_bucket:
            save_dir = os.path.join(work_dir, 'saved_model')
            trainer.save_model(save_dir)
            tokenizer.save_pretrained(save_dir)
            # Save label mapping (json serializes the int keys of id2label as strings)
            with open(os.path.join(save_dir, 'disease_labels.json'), 'w') as f:
                json.dump({'id2label': id2label, 'label2id': label2id}, f)
            aws_creds = {
                'aws_access_key_id': job_input.get('aws_access_key_id'),
                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
                'aws_session_token': job_input.get('aws_session_token'),
                'aws_region': job_input.get('aws_region', 'us-east-1'),
            }
            s3_prefix = job_input.get('s3_prefix', 'symptom-disease-models/bert')
            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
            metrics['s3_uri'] = s3_uri
        return {'status': 'success', 'metrics': metrics, 'model_uri': s3_uri, 'message': 'Symptom-Disease classifier trained'}
    except Exception as e:
        import traceback
        return {'status': 'error', 'error': str(e), 'traceback': traceback.format_exc()}
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)
def handler(job):
    """RunPod serverless handler with multi-task support."""
    job_input = job.get('input', {})
    # Task routing
    task = job_input.get('task', 'ddi')
    if task == 'ade':
        return train_ade_classifier(job_input)
    elif task == 'triage':
        return train_triage_classifier(job_input)
    elif task == 'symptom_disease':
        return train_symptom_disease_classifier(job_input)
    elif task in ('ddi', 'bert'):
        # Original DDI training: BERT-style models go to the classifier,
        # everything else to LoRA fine-tuning
        model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
        if 'bert' in model_name.lower():
            return train_bert_classifier(job_input)
        else:
            return train_llm_lora(job_input)
    elif task == 'llm':
        return train_llm_lora(job_input)
    else:
        # Auto-detect based on model
        model_name = job_input.get('model_name', 'meta-llama/Llama-3.1-8B-Instruct')
        if any(x in model_name.lower() for x in ['gemma', 'llama', 'mistral', 'qwen']):
            return train_llm_lora(job_input)
        else:
            return train_bert_classifier(job_input)

# RunPod serverless entrypoint
runpod.serverless.start({'handler': handler})
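# Local smoke test (a sketch; assumes the runpod SDK's --test_input flag and a
# small, CPU-tolerable configuration):
#   python handler.py --test_input '{"input": {"task": "ade", "max_samples": 200, "epochs": 1}}'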