diff --git a/components/runpod_trainer/handler.py b/components/runpod_trainer/handler.py
index 2ad2f18..55fb8dc 100644
--- a/components/runpod_trainer/handler.py
+++ b/components/runpod_trainer/handler.py
@@ -445,23 +445,483 @@ def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
         shutil.rmtree(work_dir, ignore_errors=True)


+def train_ade_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Train an Adverse Drug Event (ADE) binary classifier.
+    Dataset: ade-benchmark-corpus/ade_corpus_v2 (30K samples)
+    Labels: 0 = No ADE, 1 = ADE Present
+    """
+    import torch
+    import tempfile
+    import shutil
+    from datasets import load_dataset, Dataset
+    from transformers import (
+        AutoTokenizer, AutoModelForSequenceClassification,
+        TrainingArguments, Trainer
+    )
+    from sklearn.model_selection import train_test_split
+
+    work_dir = tempfile.mkdtemp()
+
+    try:
+        model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
+        max_samples = job_input.get('max_samples', 10000)
+        epochs = job_input.get('epochs', 3)
+        batch_size = job_input.get('batch_size', 16)
+        learning_rate = job_input.get('learning_rate', 2e-5)
+        eval_split = job_input.get('eval_split', 0.1)
+
+        print("Loading ADE Corpus V2 dataset...")
+        dataset = load_dataset("ade-benchmark-corpus/ade_corpus_v2", "Ade_corpus_v2_classification")
+
+        # Prepare data
+        training_data = []
+        for item in dataset['train']:
+            if max_samples and len(training_data) >= max_samples:
+                break
+            training_data.append({
+                'text': item['text'],
+                'label': item['label']  # 0 or 1
+            })
+
+        print(f"Loaded {len(training_data)} ADE samples")
+
+        # Stratified split so both classes appear in the eval set
+        if eval_split > 0:
+            train_data, eval_data = train_test_split(
+                training_data, test_size=eval_split, random_state=42,
+                stratify=[d['label'] for d in training_data]
+            )
+        else:
+            train_data, eval_data = training_data, None
+
+        train_dataset = Dataset.from_list(train_data)
+        eval_dataset = Dataset.from_list(eval_data) if eval_data else None
+
+        # Load model (binary: ADE / No ADE)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model = AutoModelForSequenceClassification.from_pretrained(
+            model_name, num_labels=2
+        )
+
+        def tokenize(examples):
+            return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=256)
+
+        train_dataset = train_dataset.map(tokenize, batched=True)
+        if eval_dataset:
+            eval_dataset = eval_dataset.map(tokenize, batched=True)
+
+        training_args = TrainingArguments(
+            output_dir=work_dir,
+            num_train_epochs=epochs,
+            learning_rate=learning_rate,
+            per_device_train_batch_size=batch_size,
+            eval_strategy='epoch' if eval_dataset else 'no',
+            save_strategy='no',
+            fp16=torch.cuda.is_available(),
+            report_to='none',
+        )
+
+        def compute_metrics(eval_pred):
+            from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
+            preds = eval_pred.predictions.argmax(-1)
+            return {
+                'accuracy': accuracy_score(eval_pred.label_ids, preds),
+                'f1': f1_score(eval_pred.label_ids, preds),
+                'precision': precision_score(eval_pred.label_ids, preds),
+                'recall': recall_score(eval_pred.label_ids, preds),
+            }
+
+        trainer = Trainer(
+            model=model,
+            args=training_args,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            compute_metrics=compute_metrics if eval_dataset else None,
+        )
+
+        train_result = trainer.train()
+
+        metrics = {
+            'task': 'ade_classification',
+            'train_loss': float(train_result.training_loss),
+            'epochs': epochs,
+            'model_name': model_name,
+            'samples': len(train_data),
+            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
+            'data_source': 'ade_corpus_v2',
+        }
+
+        if eval_dataset:
+            eval_result = trainer.evaluate()
+            metrics.update({
+                'eval_accuracy': float(eval_result['eval_accuracy']),
+                'eval_f1': float(eval_result['eval_f1']),
+                'eval_precision': float(eval_result['eval_precision']),
+                'eval_recall': float(eval_result['eval_recall']),
+            })
+
+        # Save to S3
+        s3_uri = None
+        s3_bucket = job_input.get('s3_bucket')
+        if s3_bucket:
+            save_dir = os.path.join(work_dir, 'saved_model')
+            trainer.save_model(save_dir)
+            tokenizer.save_pretrained(save_dir)
+
+            aws_creds = {
+                'aws_access_key_id': job_input.get('aws_access_key_id'),
+                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
+                'aws_session_token': job_input.get('aws_session_token'),
+                'aws_region': job_input.get('aws_region', 'us-east-1'),
+            }
+            s3_prefix = job_input.get('s3_prefix', 'ade-models/bert')
+            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
+            metrics['s3_uri'] = s3_uri
+
+        return {'status': 'success', 'metrics': metrics, 'model_uri': s3_uri, 'message': 'ADE classifier trained on ADE Corpus V2'}
+
+    except Exception as e:
+        import traceback
+        return {'status': 'error', 'error': str(e), 'traceback': traceback.format_exc()}
+    finally:
+        shutil.rmtree(work_dir, ignore_errors=True)
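The new task can be sanity-checked without RunPod by calling the handler directly. A minimal sketch, assuming components/runpod_trainer is on sys.path and that handler.py's module-level helpers (os, upload_to_s3) are present as in the surrounding diff:

    # Local smoke test for the new 'ade' task: tiny 1-epoch run, no S3 upload.
    from handler import handler  # assumes components/runpod_trainer is importable

    result = handler({
        "input": {
            "task": "ade",
            "model_name": "emilyalsentzer/Bio_ClinicalBERT",
            "max_samples": 200,   # keep the run small
            "epochs": 1,
            "batch_size": 8,
            "eval_split": 0.1,
            # no 's3_bucket': the trained model is discarded after metrics are returned
        }
    })
    print(result["status"], result.get("metrics"))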
+
+
+def train_triage_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Train a Medical Triage classifier.
+    Dataset: shubham212/Medical_Triage_Classification
+    Labels: triage urgency levels
+    """
+    import torch
+    import tempfile
+    import shutil
+    from datasets import load_dataset, Dataset
+    from transformers import (
+        AutoTokenizer, AutoModelForSequenceClassification,
+        TrainingArguments, Trainer
+    )
+    from sklearn.model_selection import train_test_split
+
+    work_dir = tempfile.mkdtemp()
+
+    try:
+        model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
+        max_samples = job_input.get('max_samples', 5000)
+        epochs = job_input.get('epochs', 3)
+        batch_size = job_input.get('batch_size', 8)
+        learning_rate = job_input.get('learning_rate', 2e-5)
+        eval_split = job_input.get('eval_split', 0.1)
+
+        print("Loading Medical Triage dataset...")
+        dataset = load_dataset("shubham212/Medical_Triage_Classification")
+
+        # Build the label vocabulary from the raw string labels
+        labels = sorted(set(item['label'] for item in dataset['train']))
+        label2id = {l: i for i, l in enumerate(labels)}
+        id2label = {i: l for l, i in label2id.items()}
+        num_labels = len(labels)
+
+        print(f"Found {num_labels} triage levels: {labels}")
+
+        training_data = []
+        for item in dataset['train']:
+            if max_samples and len(training_data) >= max_samples:
+                break
+            training_data.append({
+                'text': item['text'],
+                'label': label2id[item['label']]
+            })
+
+        print(f"Loaded {len(training_data)} triage samples")
+
+        if eval_split > 0:
+            train_data, eval_data = train_test_split(
+                training_data, test_size=eval_split, random_state=42,
+                stratify=[d['label'] for d in training_data]
+            )
+        else:
+            train_data, eval_data = training_data, None
+
+        train_dataset = Dataset.from_list(train_data)
+        eval_dataset = Dataset.from_list(eval_data) if eval_data else None
+
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model = AutoModelForSequenceClassification.from_pretrained(
+            model_name, num_labels=num_labels, id2label=id2label, label2id=label2id
+        )
+
+        def tokenize(examples):
+            return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=512)
+
+        train_dataset = train_dataset.map(tokenize, batched=True)
+        if eval_dataset:
+            eval_dataset = eval_dataset.map(tokenize, batched=True)
+
+        training_args = TrainingArguments(
+            output_dir=work_dir,
+            num_train_epochs=epochs,
+            learning_rate=learning_rate,
+            per_device_train_batch_size=batch_size,
+            eval_strategy='epoch' if eval_dataset else 'no',
+            save_strategy='no',
+            fp16=torch.cuda.is_available(),
+            report_to='none',
+        )
+
+        def compute_metrics(eval_pred):
+            from sklearn.metrics import accuracy_score, f1_score
+            preds = eval_pred.predictions.argmax(-1)
+            return {
+                'accuracy': accuracy_score(eval_pred.label_ids, preds),
+                'f1_weighted': f1_score(eval_pred.label_ids, preds, average='weighted'),
+            }
+
+        trainer = Trainer(
+            model=model,
+            args=training_args,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            compute_metrics=compute_metrics if eval_dataset else None,
+        )
+
+        train_result = trainer.train()
+
+        metrics = {
+            'task': 'triage_classification',
+            'train_loss': float(train_result.training_loss),
+            'epochs': epochs,
+            'model_name': model_name,
+            'samples': len(train_data),
+            'num_labels': num_labels,
+            'labels': id2label,
+            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
+            'data_source': 'medical_triage_classification',
+        }
+
+        if eval_dataset:
+            eval_result = trainer.evaluate()
+            metrics.update({
+                'eval_accuracy': float(eval_result['eval_accuracy']),
+                'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
+            })
+
+        s3_uri = None
+        s3_bucket = job_input.get('s3_bucket')
+        if s3_bucket:
+            save_dir = os.path.join(work_dir, 'saved_model')
+            trainer.save_model(save_dir)
+            tokenizer.save_pretrained(save_dir)
+
+            aws_creds = {
+                'aws_access_key_id': job_input.get('aws_access_key_id'),
+                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
+                'aws_session_token': job_input.get('aws_session_token'),
+                'aws_region': job_input.get('aws_region', 'us-east-1'),
+            }
+            s3_prefix = job_input.get('s3_prefix', 'triage-models/bert')
+            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
+            metrics['s3_uri'] = s3_uri
+
+        return {'status': 'success', 'metrics': metrics, 'model_uri': s3_uri, 'message': 'Triage classifier trained'}
+
+    except Exception as e:
+        import traceback
+        return {'status': 'error', 'error': str(e), 'traceback': traceback.format_exc()}
+    finally:
+        shutil.rmtree(work_dir, ignore_errors=True)
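Because id2label/label2id are passed to from_pretrained above, the label names are baked into the saved config.json, so downstream consumers get readable triage levels with no extra plumbing. A sketch, assuming the trained model was downloaded to ./saved_model (hypothetical path) and the input text is illustrative:

    from transformers import pipeline

    # Emits the human-readable level (e.g. 'Urgent') instead of 'LABEL_2',
    # because id2label was written into the model config at training time.
    clf = pipeline("text-classification", model="./saved_model")
    print(clf("55-year-old with crushing chest pain radiating to the left arm"))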
+
+
+def train_symptom_disease_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Train a Symptom-to-Disease classifier.
+    Dataset: shanover/disease_symptoms_prec_full
+    Task: predict the disease from a symptom description
+    """
+    import json
+    import torch
+    import tempfile
+    import shutil
+    from datasets import load_dataset, Dataset
+    from transformers import (
+        AutoTokenizer, AutoModelForSequenceClassification,
+        TrainingArguments, Trainer
+    )
+    from sklearn.model_selection import train_test_split
+
+    work_dir = tempfile.mkdtemp()
+
+    try:
+        model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
+        max_samples = job_input.get('max_samples', 5000)
+        epochs = job_input.get('epochs', 3)
+        batch_size = job_input.get('batch_size', 16)
+        learning_rate = job_input.get('learning_rate', 2e-5)
+        eval_split = job_input.get('eval_split', 0.1)
+
+        print("Loading Symptom-Disease dataset...")
+        dataset = load_dataset("shanover/disease_symptoms_prec_full")
+
+        # Build label mapping from diseases
+        diseases = sorted(set(item['disease'] for item in dataset['train']))
+        label2id = {d: i for i, d in enumerate(diseases)}
+        id2label = {i: d for d, i in label2id.items()}
+        num_labels = len(diseases)
+
+        print(f"Found {num_labels} diseases")
+
+        training_data = []
+        for item in dataset['train']:
+            if max_samples and len(training_data) >= max_samples:
+                break
+            # Format symptoms as natural text
+            symptoms = item['symptoms'].replace('_', ' ').replace(',', ', ')
+            training_data.append({
+                'text': f"Patient presents with: {symptoms}",
+                'label': label2id[item['disease']]
+            })
+
+        print(f"Loaded {len(training_data)} symptom-disease samples")
+
+        if eval_split > 0:
+            train_data, eval_data = train_test_split(
+                training_data, test_size=eval_split, random_state=42
+            )
+        else:
+            train_data, eval_data = training_data, None
+
+        train_dataset = Dataset.from_list(train_data)
+        eval_dataset = Dataset.from_list(eval_data) if eval_data else None
+
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model = AutoModelForSequenceClassification.from_pretrained(
+            model_name, num_labels=num_labels, id2label=id2label, label2id=label2id
+        )
+
+        def tokenize(examples):
+            return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=256)
+
+        train_dataset = train_dataset.map(tokenize, batched=True)
+        if eval_dataset:
+            eval_dataset = eval_dataset.map(tokenize, batched=True)
+
+        training_args = TrainingArguments(
+            output_dir=work_dir,
+            num_train_epochs=epochs,
+            learning_rate=learning_rate,
+            per_device_train_batch_size=batch_size,
+            eval_strategy='epoch' if eval_dataset else 'no',
+            save_strategy='no',
+            fp16=torch.cuda.is_available(),
+            report_to='none',
+        )
+
+        def compute_metrics(eval_pred):
+            from sklearn.metrics import accuracy_score, f1_score, top_k_accuracy_score
+            preds = eval_pred.predictions.argmax(-1)
+            metrics = {
+                'accuracy': accuracy_score(eval_pred.label_ids, preds),
+                'f1_weighted': f1_score(eval_pred.label_ids, preds, average='weighted'),
+            }
+            # Top-5 accuracy (important for diagnosis). Pass the full label set so a
+            # small eval split that is missing some diseases does not raise ValueError.
+            try:
+                metrics['top5_accuracy'] = top_k_accuracy_score(
+                    eval_pred.label_ids, eval_pred.predictions, k=5,
+                    labels=list(range(num_labels))
+                )
+            except ValueError:
+                pass  # undefined, e.g. fewer than 5 classes
+            return metrics
+
+        trainer = Trainer(
+            model=model,
+            args=training_args,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            compute_metrics=compute_metrics if eval_dataset else None,
+        )
+
+        train_result = trainer.train()
+
+        metrics = {
+            'task': 'symptom_disease_classification',
+            'train_loss': float(train_result.training_loss),
+            'epochs': epochs,
+            'model_name': model_name,
+            'samples': len(train_data),
+            'num_diseases': num_labels,
+            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
+            'data_source': 'disease_symptoms_prec_full',
+        }
+
+        if eval_dataset:
+            eval_result = trainer.evaluate()
+            metrics.update({
+                'eval_accuracy': float(eval_result['eval_accuracy']),
+                'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
+            })
+            if 'eval_top5_accuracy' in eval_result:
+                metrics['eval_top5_accuracy'] = float(eval_result['eval_top5_accuracy'])
+
+        s3_uri = None
+        s3_bucket = job_input.get('s3_bucket')
+        if s3_bucket:
+            save_dir = os.path.join(work_dir, 'saved_model')
+            trainer.save_model(save_dir)
+            tokenizer.save_pretrained(save_dir)
+
+            # Save label mapping
+            with open(os.path.join(save_dir, 'disease_labels.json'), 'w') as f:
+                json.dump({'id2label': id2label, 'label2id': label2id}, f)
+
+            aws_creds = {
+                'aws_access_key_id': job_input.get('aws_access_key_id'),
+                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
+                'aws_session_token': job_input.get('aws_session_token'),
+                'aws_region': job_input.get('aws_region', 'us-east-1'),
+            }
+            s3_prefix = job_input.get('s3_prefix', 'symptom-disease-models/bert')
+            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
+            metrics['s3_uri'] = s3_uri
+
+        return {'status': 'success', 'metrics': metrics, 'model_uri': s3_uri, 'message': 'Symptom-Disease classifier trained'}
+
+    except Exception as e:
+        import traceback
+        return {'status': 'error', 'error': str(e), 'traceback': traceback.format_exc()}
+    finally:
+        shutil.rmtree(work_dir, ignore_errors=True)
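The symptom formatting above is a simple string cleanup. A worked example with an illustrative raw row (the symptom names here are hypothetical):

    # What the formatting above produces for a typical comma/underscore-delimited row.
    raw = "itching,skin_rash,nodal_skin_eruptions"
    symptoms = raw.replace('_', ' ').replace(',', ', ')
    text = f"Patient presents with: {symptoms}"
    # -> "Patient presents with: itching, skin rash, nodal skin eruptions"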
+
+
 def handler(job):
-    """RunPod serverless handler."""
+    """RunPod serverless handler with multi-task support."""
     job_input = job.get('input', {})

-    model_name = job_input.get('model_name', 'meta-llama/Llama-3.1-8B-Instruct')
-    use_lora = job_input.get('use_lora', True)
+    # Task routing
+    task = job_input.get('task', 'ddi')

-    # Auto-detect: use LoRA for large models
-    if any(x in model_name.lower() for x in ['gemma', 'llama', 'mistral', 'qwen']):
-        use_lora = True
-    elif 'bert' in model_name.lower():
-        use_lora = False
-
-    if use_lora:
+    if task == 'ade':
+        return train_ade_classifier(job_input)
+    elif task == 'triage':
+        return train_triage_classifier(job_input)
+    elif task == 'symptom_disease':
+        return train_symptom_disease_classifier(job_input)
+    elif task == 'ddi' or task == 'bert':
+        # Original DDI training
+        model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
+        if 'bert' in model_name.lower():
+            return train_bert_classifier(job_input)
+        else:
+            return train_llm_lora(job_input)
+    elif task == 'llm':
         return train_llm_lora(job_input)
     else:
-        return train_bert_classifier(job_input)
+        # Auto-detect based on model name
+        model_name = job_input.get('model_name', 'meta-llama/Llama-3.1-8B-Instruct')
+        if any(x in model_name.lower() for x in ['gemma', 'llama', 'mistral', 'qwen']):
+            return train_llm_lora(job_input)
+        else:
+            return train_bert_classifier(job_input)


 # RunPod serverless entrypoint
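The routing accepts five task values; minimal payloads, as the updated handler expects them (a sketch — every field beyond "task" falls back to the defaults shown in the trainers above):

    # Example job payloads accepted by the updated handler.
    jobs = [
        {"input": {"task": "ade"}},              # ADE Corpus V2, binary classification
        {"input": {"task": "triage"}},           # triage urgency levels
        {"input": {"task": "symptom_disease"}},  # disease from symptom text
        {"input": {"task": "ddi"}},              # original DDI path (BERT vs LoRA by model name)
        {"input": {"task": "llm", "model_name": "meta-llama/Llama-3.1-8B-Instruct"}},  # LoRA fine-tune
    ]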
diff --git a/pipelines/healthcare_training.py b/pipelines/healthcare_training.py
new file mode 100644
index 0000000..0e037ae
--- /dev/null
+++ b/pipelines/healthcare_training.py
@@ -0,0 +1,376 @@
+"""
+Healthcare ML Training Pipelines
+
+Multi-task training pipelines for:
+- Adverse Drug Event (ADE) Classification
+- Medical Triage Classification
+- Symptom-to-Disease Prediction
+- Drug-Drug Interaction (DDI) Classification
+
+All tasks run on RunPod serverless GPU infrastructure.
+"""
+from kfp import dsl
+from kfp import compiler
+
+
+# ============================================================================
+# ADE (Adverse Drug Event) Classification Pipeline
+# ============================================================================
+@dsl.component(
+    base_image="python:3.11-slim",
+    packages_to_install=["requests"]
+)
+def train_ade_model(
+    runpod_api_key: str,
+    runpod_endpoint: str,
+    model_name: str,
+    max_samples: int,
+    epochs: int,
+    batch_size: int,
+    s3_bucket: str,
+    aws_access_key_id: str,
+    aws_secret_access_key: str,
+    aws_session_token: str,
+) -> dict:
+    """Train ADE classifier on RunPod serverless GPU."""
+    import requests
+    import time
+
+    response = requests.post(
+        f"https://api.runpod.ai/v2/{runpod_endpoint}/run",
+        headers={"Authorization": f"Bearer {runpod_api_key}"},
+        json={
+            "input": {
+                "task": "ade",
+                "model_name": model_name,
+                "max_samples": max_samples,
+                "epochs": epochs,
+                "batch_size": batch_size,
+                "eval_split": 0.1,
+                "s3_bucket": s3_bucket,
+                "s3_prefix": "ade-models/bert",
+                "aws_access_key_id": aws_access_key_id,
+                "aws_secret_access_key": aws_secret_access_key,
+                "aws_session_token": aws_session_token,
+            }
+        }
+    )
+    response.raise_for_status()  # fail fast on a bad submit instead of a KeyError below
+
+    job_id = response.json()["id"]
+    print(f"RunPod job submitted: {job_id}")
+
+    # Poll for completion
+    while True:
+        status = requests.get(
+            f"https://api.runpod.ai/v2/{runpod_endpoint}/status/{job_id}",
+            headers={"Authorization": f"Bearer {runpod_api_key}"}
+        ).json()
+
+        if status["status"] == "COMPLETED":
+            return status["output"]
+        elif status["status"] in ("FAILED", "CANCELLED", "TIMED_OUT"):
+            raise Exception(f"Training failed: {status}")
+
+        time.sleep(10)
+
+
+@dsl.pipeline(name="ade-classification-pipeline")
+def ade_classification_pipeline(
+    runpod_api_key: str,
+    runpod_endpoint: str = "k57do7afav01es",
+    model_name: str = "emilyalsentzer/Bio_ClinicalBERT",
+    max_samples: int = 10000,
+    epochs: int = 3,
+    batch_size: int = 16,
+    s3_bucket: str = "",
+    aws_access_key_id: str = "",
+    aws_secret_access_key: str = "",
+    aws_session_token: str = "",
+):
+    """
+    Adverse Drug Event Classification Pipeline
+
+    Trains Bio_ClinicalBERT on ADE Corpus V2 (30K samples).
+    Binary classification: ADE present / No ADE.
+    """
+    train_task = train_ade_model(
+        runpod_api_key=runpod_api_key,
+        runpod_endpoint=runpod_endpoint,
+        model_name=model_name,
+        max_samples=max_samples,
+        epochs=epochs,
+        batch_size=batch_size,
+        s3_bucket=s3_bucket,
+        aws_access_key_id=aws_access_key_id,
+        aws_secret_access_key=aws_secret_access_key,
+        aws_session_token=aws_session_token,
+    )
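The components poll until COMPLETED with no upper bound. If a hung job is a concern, the loop can be bounded; a minimal sketch against the same /status endpoint (wait_for_runpod is a hypothetical helper, not part of the diff):

    import time
    import requests

    def wait_for_runpod(endpoint: str, job_id: str, api_key: str,
                        poll_s: int = 10, timeout_s: int = 3600) -> dict:
        """Poll a RunPod serverless job, failing after timeout_s seconds."""
        deadline = time.time() + timeout_s
        while time.time() < deadline:
            status = requests.get(
                f"https://api.runpod.ai/v2/{endpoint}/status/{job_id}",
                headers={"Authorization": f"Bearer {api_key}"},
            ).json()
            if status["status"] == "COMPLETED":
                return status["output"]
            if status["status"] in ("FAILED", "CANCELLED", "TIMED_OUT"):
                raise RuntimeError(f"Training failed: {status}")
            time.sleep(poll_s)
        raise TimeoutError(f"RunPod job {job_id} did not finish within {timeout_s}s")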
+""" +from kfp import dsl +from kfp import compiler + + +# ============================================================================ +# ADE (Adverse Drug Event) Classification Pipeline +# ============================================================================ +@dsl.component( + base_image="python:3.11-slim", + packages_to_install=["requests"] +) +def train_ade_model( + runpod_api_key: str, + runpod_endpoint: str, + model_name: str, + max_samples: int, + epochs: int, + batch_size: int, + s3_bucket: str, + aws_access_key_id: str, + aws_secret_access_key: str, + aws_session_token: str, +) -> dict: + """Train ADE classifier on RunPod serverless GPU.""" + import requests + import time + + response = requests.post( + f"https://api.runpod.ai/v2/{runpod_endpoint}/run", + headers={"Authorization": f"Bearer {runpod_api_key}"}, + json={ + "input": { + "task": "ade", + "model_name": model_name, + "max_samples": max_samples, + "epochs": epochs, + "batch_size": batch_size, + "eval_split": 0.1, + "s3_bucket": s3_bucket, + "s3_prefix": "ade-models/bert", + "aws_access_key_id": aws_access_key_id, + "aws_secret_access_key": aws_secret_access_key, + "aws_session_token": aws_session_token, + } + } + ) + + job_id = response.json()["id"] + print(f"RunPod job submitted: {job_id}") + + # Poll for completion + while True: + status = requests.get( + f"https://api.runpod.ai/v2/{runpod_endpoint}/status/{job_id}", + headers={"Authorization": f"Bearer {runpod_api_key}"} + ).json() + + if status["status"] == "COMPLETED": + return status["output"] + elif status["status"] == "FAILED": + raise Exception(f"Training failed: {status}") + + time.sleep(10) + + +@dsl.pipeline(name="ade-classification-pipeline") +def ade_classification_pipeline( + runpod_api_key: str, + runpod_endpoint: str = "k57do7afav01es", + model_name: str = "emilyalsentzer/Bio_ClinicalBERT", + max_samples: int = 10000, + epochs: int = 3, + batch_size: int = 16, + s3_bucket: str = "", + aws_access_key_id: str = "", + aws_secret_access_key: str = "", + aws_session_token: str = "", +): + """ + Adverse Drug Event Classification Pipeline + + Trains Bio_ClinicalBERT on ADE Corpus V2 (30K samples) + Binary classification: ADE present / No ADE + """ + train_task = train_ade_model( + runpod_api_key=runpod_api_key, + runpod_endpoint=runpod_endpoint, + model_name=model_name, + max_samples=max_samples, + epochs=epochs, + batch_size=batch_size, + s3_bucket=s3_bucket, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + ) + + +# ============================================================================ +# Medical Triage Classification Pipeline +# ============================================================================ +@dsl.component( + base_image="python:3.11-slim", + packages_to_install=["requests"] +) +def train_triage_model( + runpod_api_key: str, + runpod_endpoint: str, + model_name: str, + max_samples: int, + epochs: int, + batch_size: int, + s3_bucket: str, + aws_access_key_id: str, + aws_secret_access_key: str, + aws_session_token: str, +) -> dict: + """Train Medical Triage classifier on RunPod.""" + import requests + import time + + response = requests.post( + f"https://api.runpod.ai/v2/{runpod_endpoint}/run", + headers={"Authorization": f"Bearer {runpod_api_key}"}, + json={ + "input": { + "task": "triage", + "model_name": model_name, + "max_samples": max_samples, + "epochs": epochs, + "batch_size": batch_size, + "eval_split": 0.1, + "s3_bucket": s3_bucket, + "s3_prefix": 
"triage-models/bert", + "aws_access_key_id": aws_access_key_id, + "aws_secret_access_key": aws_secret_access_key, + "aws_session_token": aws_session_token, + } + } + ) + + job_id = response.json()["id"] + print(f"RunPod job submitted: {job_id}") + + while True: + status = requests.get( + f"https://api.runpod.ai/v2/{runpod_endpoint}/status/{job_id}", + headers={"Authorization": f"Bearer {runpod_api_key}"} + ).json() + + if status["status"] == "COMPLETED": + return status["output"] + elif status["status"] == "FAILED": + raise Exception(f"Training failed: {status}") + + time.sleep(10) + + +@dsl.pipeline(name="triage-classification-pipeline") +def triage_classification_pipeline( + runpod_api_key: str, + runpod_endpoint: str = "k57do7afav01es", + model_name: str = "emilyalsentzer/Bio_ClinicalBERT", + max_samples: int = 5000, + epochs: int = 3, + batch_size: int = 8, + s3_bucket: str = "", + aws_access_key_id: str = "", + aws_secret_access_key: str = "", + aws_session_token: str = "", +): + """ + Medical Triage Classification Pipeline + + Trains classifier for ER triage urgency levels. + Multi-class: Emergency, Urgent, Standard, etc. + """ + train_task = train_triage_model( + runpod_api_key=runpod_api_key, + runpod_endpoint=runpod_endpoint, + model_name=model_name, + max_samples=max_samples, + epochs=epochs, + batch_size=batch_size, + s3_bucket=s3_bucket, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + ) + + +# ============================================================================ +# Symptom-to-Disease Classification Pipeline +# ============================================================================ +@dsl.component( + base_image="python:3.11-slim", + packages_to_install=["requests"] +) +def train_symptom_disease_model( + runpod_api_key: str, + runpod_endpoint: str, + model_name: str, + max_samples: int, + epochs: int, + batch_size: int, + s3_bucket: str, + aws_access_key_id: str, + aws_secret_access_key: str, + aws_session_token: str, +) -> dict: + """Train Symptom-to-Disease classifier on RunPod.""" + import requests + import time + + response = requests.post( + f"https://api.runpod.ai/v2/{runpod_endpoint}/run", + headers={"Authorization": f"Bearer {runpod_api_key}"}, + json={ + "input": { + "task": "symptom_disease", + "model_name": model_name, + "max_samples": max_samples, + "epochs": epochs, + "batch_size": batch_size, + "eval_split": 0.1, + "s3_bucket": s3_bucket, + "s3_prefix": "symptom-disease-models/bert", + "aws_access_key_id": aws_access_key_id, + "aws_secret_access_key": aws_secret_access_key, + "aws_session_token": aws_session_token, + } + } + ) + + job_id = response.json()["id"] + print(f"RunPod job submitted: {job_id}") + + while True: + status = requests.get( + f"https://api.runpod.ai/v2/{runpod_endpoint}/status/{job_id}", + headers={"Authorization": f"Bearer {runpod_api_key}"} + ).json() + + if status["status"] == "COMPLETED": + return status["output"] + elif status["status"] == "FAILED": + raise Exception(f"Training failed: {status}") + + time.sleep(10) + + +@dsl.pipeline(name="symptom-disease-classification-pipeline") +def symptom_disease_pipeline( + runpod_api_key: str, + runpod_endpoint: str = "k57do7afav01es", + model_name: str = "emilyalsentzer/Bio_ClinicalBERT", + max_samples: int = 5000, + epochs: int = 3, + batch_size: int = 16, + s3_bucket: str = "", + aws_access_key_id: str = "", + aws_secret_access_key: str = "", + aws_session_token: str = "", +): + """ + Symptom-to-Disease 
+
+
+# ============================================================================
+# Full Healthcare Training Pipeline (All Tasks)
+# ============================================================================
+@dsl.pipeline(name="healthcare-multi-task-pipeline")
+def healthcare_multi_task_pipeline(
+    runpod_api_key: str,
+    runpod_endpoint: str = "k57do7afav01es",
+    model_name: str = "emilyalsentzer/Bio_ClinicalBERT",
+    s3_bucket: str = "",
+    aws_access_key_id: str = "",
+    aws_secret_access_key: str = "",
+    aws_session_token: str = "",
+):
+    """
+    Train all healthcare models in parallel.
+
+    Outputs:
+    - ADE classifier (s3://bucket/ade-models/...)
+    - Triage classifier (s3://bucket/triage-models/...)
+    - Symptom-Disease classifier (s3://bucket/symptom-disease-models/...)
+    """
+    # The three tasks have no dependencies on each other, so KFP runs them in parallel
+    ade_task = train_ade_model(
+        runpod_api_key=runpod_api_key,
+        runpod_endpoint=runpod_endpoint,
+        model_name=model_name,
+        max_samples=10000,
+        epochs=3,
+        batch_size=16,
+        s3_bucket=s3_bucket,
+        aws_access_key_id=aws_access_key_id,
+        aws_secret_access_key=aws_secret_access_key,
+        aws_session_token=aws_session_token,
+    )
+
+    triage_task = train_triage_model(
+        runpod_api_key=runpod_api_key,
+        runpod_endpoint=runpod_endpoint,
+        model_name=model_name,
+        max_samples=5000,
+        epochs=3,
+        batch_size=8,
+        s3_bucket=s3_bucket,
+        aws_access_key_id=aws_access_key_id,
+        aws_secret_access_key=aws_secret_access_key,
+        aws_session_token=aws_session_token,
+    )
+
+    symptom_task = train_symptom_disease_model(
+        runpod_api_key=runpod_api_key,
+        runpod_endpoint=runpod_endpoint,
+        model_name=model_name,
+        max_samples=5000,
+        epochs=3,
+        batch_size=16,
+        s3_bucket=s3_bucket,
+        aws_access_key_id=aws_access_key_id,
+        aws_secret_access_key=aws_secret_access_key,
+        aws_session_token=aws_session_token,
+    )
+
+
+if __name__ == "__main__":
+    # Compile all pipelines to YAML
+    compiler.Compiler().compile(
+        ade_classification_pipeline,
+        "ade_classification_pipeline.yaml"
+    )
+    compiler.Compiler().compile(
+        triage_classification_pipeline,
+        "triage_classification_pipeline.yaml"
+    )
+    compiler.Compiler().compile(
+        symptom_disease_pipeline,
+        "symptom_disease_pipeline.yaml"
+    )
+    compiler.Compiler().compile(
+        healthcare_multi_task_pipeline,
+        "healthcare_multi_task_pipeline.yaml"
+    )
+    print("All pipelines compiled!")
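The compiled YAMLs can then be submitted to a Kubeflow Pipelines host. A minimal sketch, assuming a reachable KFP endpoint (the host URL and argument values here are placeholders, not part of the diff):

    from kfp import Client

    client = Client(host="http://localhost:8080")  # e.g. a port-forwarded KFP API server
    run = client.create_run_from_pipeline_package(
        "ade_classification_pipeline.yaml",
        arguments={
            "runpod_api_key": "<RUNPOD_API_KEY>",  # placeholder; consider a K8s secret
            "s3_bucket": "<bucket>",
        },
    )
    print(run.run_id)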