""" RunPod Serverless Handler for DDI Model Training Supports both BERT-style classification and LLM fine-tuning with LoRA. Uses 176K real DrugBank DDI samples with drug names. Saves trained models to S3. """ import os import json import runpod import boto3 from datetime import datetime from typing import Dict, Any, List, Optional def upload_to_s3(local_path: str, s3_bucket: str, s3_prefix: str, aws_credentials: Dict) -> str: """Upload a directory to S3 and return the S3 URI.""" import tarfile import tempfile # Create tar.gz of the model directory timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S") tar_name = f"model_{timestamp}.tar.gz" tar_path = os.path.join(tempfile.gettempdir(), tar_name) print(f"Creating archive: {tar_path}") with tarfile.open(tar_path, "w:gz") as tar: tar.add(local_path, arcname="model") # Upload to S3 s3_client = boto3.client( 's3', aws_access_key_id=aws_credentials.get('aws_access_key_id'), aws_secret_access_key=aws_credentials.get('aws_secret_access_key'), aws_session_token=aws_credentials.get('aws_session_token'), region_name=aws_credentials.get('aws_region', 'us-east-1') ) s3_key = f"{s3_prefix}/{tar_name}" print(f"Uploading to s3://{s3_bucket}/{s3_key}") s3_client.upload_file(tar_path, s3_bucket, s3_key) # Cleanup os.remove(tar_path) return f"s3://{s3_bucket}/{s3_key}" # DDI severity labels DDI_SEVERITY = { 1: "minor", 2: "moderate", 3: "major", 4: "contraindicated" } def load_drugbank_data(max_samples: int = None, severity_filter: List[int] = None) -> List[Dict]: """Load real DrugBank DDI data from bundled file.""" data_path = os.environ.get('DDI_DATA_PATH', '/app/data/drugbank_ddi_complete.jsonl') if not os.path.exists(data_path): print(f"WARNING: Data file not found at {data_path}, using curated fallback") return get_curated_fallback() data = [] with open(data_path) as f: for line in f: item = json.loads(line) if severity_filter and item['severity'] not in severity_filter: continue data.append(item) if max_samples and len(data) >= max_samples: break return data def get_curated_fallback() -> List[Dict]: """Fallback curated data if main file not available.""" patterns = [ {"drug1": "fluoxetine", "drug2": "tramadol", "interaction_text": "fluoxetine may increase the risk of serotonin syndrome when combined with tramadol", "severity": 4}, {"drug1": "warfarin", "drug2": "aspirin", "interaction_text": "warfarin may increase the risk of bleeding when combined with aspirin", "severity": 3}, {"drug1": "simvastatin", "drug2": "amlodipine", "interaction_text": "The serum concentration of simvastatin can be increased when combined with amlodipine", "severity": 2}, {"drug1": "metformin", "drug2": "lisinopril", "interaction_text": "metformin and lisinopril have no significant interaction", "severity": 1}, ] return patterns * 50 # 200 samples def format_for_llm(item: Dict, model_name: str = "") -> str: """Format DDI item for LLM instruction tuning. Auto-detects format based on model.""" severity_name = DDI_SEVERITY.get(item['severity'], 'unknown') user_msg = f"Analyze the drug interaction between {item['drug1']} and {item['drug2']}." 
assistant_msg = f"""Interaction: {item['interaction_text']} Severity: {severity_name.upper()} Recommendation: {"Avoid this combination" if item['severity'] >= 3 else "Monitor patient" if item['severity'] == 2 else "Generally safe"}""" # Llama 3 format if 'llama' in model_name.lower(): return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|> {user_msg}<|eot_id|><|start_header_id|>assistant<|end_header_id|> {assistant_msg}<|eot_id|>""" # Mistral/Mixtral format elif 'mistral' in model_name.lower() or 'mixtral' in model_name.lower(): return f"""[INST] {user_msg} [/INST] {assistant_msg}""" # Qwen format elif 'qwen' in model_name.lower(): return f"""<|im_start|>user {user_msg}<|im_end|> <|im_start|>assistant {assistant_msg}<|im_end|>""" # Gemma format elif 'gemma' in model_name.lower(): return f"""user {user_msg} model {assistant_msg}""" # Generic fallback else: return f"""### User: {user_msg}\n### Assistant: {assistant_msg}""" def train_llm_lora(job_input: Dict[str, Any]) -> Dict[str, Any]: """Train LLM with QLoRA for DDI classification. Supports Llama, Mistral, Qwen, Gemma.""" import torch from transformers import ( AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, ) from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training from trl import SFTTrainer from datasets import Dataset import tempfile import shutil # Parameters - Default to Llama 3.1 8B (Bedrock-compatible) model_name = job_input.get('model_name', 'meta-llama/Llama-3.1-8B-Instruct') max_samples = job_input.get('max_samples', 10000) epochs = job_input.get('epochs', 1) learning_rate = job_input.get('learning_rate', 2e-4) batch_size = job_input.get('batch_size', 2) lora_r = job_input.get('lora_r', 16) lora_alpha = job_input.get('lora_alpha', 32) max_seq_length = job_input.get('max_seq_length', 512) severity_filter = job_input.get('severity_filter', None) # e.g., [3, 4] for major/contraindicated only work_dir = tempfile.mkdtemp() try: print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}") print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB") print(f"Model: {model_name}") print(f"LoRA r={lora_r}, alpha={lora_alpha}") # Load training data print(f"Loading DrugBank DDI data (max {max_samples})...") raw_data = load_drugbank_data(max_samples=max_samples, severity_filter=severity_filter) # Format for the target LLM formatted_data = [{"text": format_for_llm(item, model_name)} for item in raw_data] dataset = Dataset.from_list(formatted_data) print(f"Dataset size: {len(dataset)}") # Severity distribution from collections import Counter sev_dist = Counter(item['severity'] for item in raw_data) print(f"Severity distribution: {dict(sev_dist)}") # QLoRA config - 4-bit quantization bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, ) # Load model print(f"Loading {model_name} with 4-bit quantization...") model = AutoModelForCausalLM.from_pretrained( model_name, quantization_config=bnb_config, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True, ) tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "right" # Prepare for k-bit training model = prepare_model_for_kbit_training(model) # LoRA config lora_config = LoraConfig( r=lora_r, lora_alpha=lora_alpha, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", 
"down_proj"], lora_dropout=0.05, bias="none", task_type="CAUSAL_LM", ) model = get_peft_model(model, lora_config) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) total_params = sum(p.numel() for p in model.parameters()) print(f"Trainable params: {trainable_params:,} / {total_params:,} ({100*trainable_params/total_params:.2f}%)") # Training arguments training_args = TrainingArguments( output_dir=work_dir, num_train_epochs=epochs, per_device_train_batch_size=batch_size, gradient_accumulation_steps=8, learning_rate=learning_rate, weight_decay=0.01, warmup_ratio=0.1, logging_steps=25, save_strategy="no", bf16=True, gradient_checkpointing=True, optim="paged_adamw_8bit", report_to="none", max_grad_norm=0.3, ) # SFT Trainer trainer = SFTTrainer( model=model, train_dataset=dataset, args=training_args, peft_config=lora_config, processing_class=tokenizer, max_seq_length=max_seq_length, ) # Train print("Starting LoRA fine-tuning...") train_result = trainer.train() # Metrics metrics = { 'train_loss': float(train_result.training_loss), 'epochs': epochs, 'model_name': model_name, 'samples': len(raw_data), 'lora_r': lora_r, 'lora_alpha': lora_alpha, 'trainable_params': trainable_params, 'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU', 'vram_gb': torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0, 'data_source': 'drugbank_176k', 'severity_dist': dict(sev_dist), } print(f"Training complete! Loss: {metrics['train_loss']:.4f}") # Save LoRA adapter to S3 if credentials provided s3_uri = None s3_bucket = job_input.get('s3_bucket') if s3_bucket: save_dir = os.path.join(work_dir, 'lora_adapter') print(f"Saving LoRA adapter to {save_dir}...") model.save_pretrained(save_dir) tokenizer.save_pretrained(save_dir) aws_creds = { 'aws_access_key_id': job_input.get('aws_access_key_id'), 'aws_secret_access_key': job_input.get('aws_secret_access_key'), 'aws_session_token': job_input.get('aws_session_token'), 'aws_region': job_input.get('aws_region', 'us-east-1'), } model_short = model_name.split('/')[-1] s3_prefix = job_input.get('s3_prefix', f'ddi-models/lora-{model_short}') s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds) metrics['s3_uri'] = s3_uri print(f"LoRA adapter uploaded to {s3_uri}") return { 'status': 'success', 'metrics': metrics, 'model_uri': s3_uri, 'message': f'LLM fine-tuned on {len(raw_data):,} real DrugBank DDI samples' } except Exception as e: import traceback return { 'status': 'error', 'error': str(e), 'traceback': traceback.format_exc() } finally: shutil.rmtree(work_dir, ignore_errors=True) if torch.cuda.is_available(): torch.cuda.empty_cache() def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]: """Train BERT-style classifier for DDI severity prediction.""" import torch from transformers import ( AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer ) from datasets import Dataset from sklearn.model_selection import train_test_split import tempfile import shutil model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT') max_samples = job_input.get('max_samples', 50000) epochs = job_input.get('epochs', 3) learning_rate = job_input.get('learning_rate', 2e-5) batch_size = job_input.get('batch_size', 16) eval_split = job_input.get('eval_split', 0.1) work_dir = tempfile.mkdtemp() try: print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}") print(f"Model: {model_name}") # Load data raw_data = 

def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """Train a BERT-style classifier for DDI severity prediction."""
    import torch
    from transformers import (
        AutoTokenizer,
        AutoModelForSequenceClassification,
        TrainingArguments,
        Trainer
    )
    from datasets import Dataset
    from sklearn.model_selection import train_test_split
    import tempfile
    import shutil

    model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
    max_samples = job_input.get('max_samples', 50000)
    epochs = job_input.get('epochs', 3)
    learning_rate = job_input.get('learning_rate', 2e-5)
    batch_size = job_input.get('batch_size', 16)
    eval_split = job_input.get('eval_split', 0.1)

    work_dir = tempfile.mkdtemp()

    try:
        print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
        print(f"Model: {model_name}")

        # Load data
        raw_data = load_drugbank_data(max_samples=max_samples)

        # Create text + label format; shift severity to 0-indexed (1-4 -> 0-3)
        training_data = [{
            "text": f"{d['drug1']} and {d['drug2']}: {d['interaction_text']}",
            "label": d['severity'] - 1  # 0-indexed
        } for d in raw_data]
        print(f"Loaded {len(training_data)} samples")

        # Split
        if eval_split > 0:
            train_data, eval_data = train_test_split(
                training_data, test_size=eval_split, random_state=42,
                stratify=[d['label'] for d in training_data]
            )
        else:
            train_data, eval_data = training_data, None

        train_dataset = Dataset.from_list(train_data)
        eval_dataset = Dataset.from_list(eval_data) if eval_data else None

        # Load model (4 classes: minor, moderate, major, contraindicated)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=4
        )

        def tokenize(examples):
            return tokenizer(examples['text'], padding='max_length',
                             truncation=True, max_length=256)

        train_dataset = train_dataset.map(tokenize, batched=True)
        if eval_dataset:
            eval_dataset = eval_dataset.map(tokenize, batched=True)

        training_args = TrainingArguments(
            output_dir=work_dir,
            num_train_epochs=epochs,
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            eval_strategy='epoch' if eval_dataset else 'no',
            save_strategy='no',
            fp16=torch.cuda.is_available(),
            report_to='none',
        )

        def compute_metrics(eval_pred):
            from sklearn.metrics import accuracy_score, f1_score
            preds = eval_pred.predictions.argmax(-1)
            return {
                'accuracy': accuracy_score(eval_pred.label_ids, preds),
                'f1_weighted': f1_score(eval_pred.label_ids, preds, average='weighted'),
            }

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics if eval_dataset else None,
        )

        train_result = trainer.train()

        metrics = {
            'train_loss': float(train_result.training_loss),
            'epochs': epochs,
            'model_name': model_name,
            'samples': len(train_data),
            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
            'data_source': 'drugbank_176k',
        }
        if eval_dataset:
            eval_result = trainer.evaluate()
            metrics.update({
                'eval_accuracy': float(eval_result['eval_accuracy']),
                'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
            })

        # Save model to S3 if credentials provided
        s3_uri = None
        s3_bucket = job_input.get('s3_bucket')
        if s3_bucket:
            save_dir = os.path.join(work_dir, 'saved_model')
            print(f"Saving model to {save_dir}...")
            trainer.save_model(save_dir)
            tokenizer.save_pretrained(save_dir)

            aws_creds = {
                'aws_access_key_id': job_input.get('aws_access_key_id'),
                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
                'aws_session_token': job_input.get('aws_session_token'),
                'aws_region': job_input.get('aws_region', 'us-east-1'),
            }
            s3_prefix = job_input.get('s3_prefix', 'ddi-models/bert')
            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
            metrics['s3_uri'] = s3_uri
            print(f"Model uploaded to {s3_uri}")

        return {'status': 'success', 'metrics': metrics, 'model_uri': s3_uri,
                'message': 'BERT classifier trained on DrugBank data'}

    except Exception as e:
        import traceback
        return {'status': 'error', 'error': str(e), 'traceback': traceback.format_exc()}
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)
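
# Minimal inference sketch for the saved severity classifier ("./model" is a
# hypothetical local path; assumes the tar.gz from upload_to_s3 was extracted):
#   from transformers import pipeline
#   clf = pipeline("text-classification", model="./model")
#   clf("warfarin and aspirin: warfarin may increase the risk of bleeding when combined with aspirin")
#   # -> [{"label": "LABEL_2", "score": ...}]  # LABEL_2 maps back to severity 3 ("major")
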

def train_ade_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """
    Train an Adverse Drug Event (ADE) binary classifier.

    Dataset: ade-benchmark-corpus/ade_corpus_v2 (30K samples)
    Labels: 0 = No ADE, 1 = ADE present
    """
    import torch
    import tempfile
    import shutil
    from datasets import load_dataset, Dataset
    from transformers import (
        AutoTokenizer,
        AutoModelForSequenceClassification,
        TrainingArguments,
        Trainer
    )
    from sklearn.model_selection import train_test_split

    work_dir = tempfile.mkdtemp()

    try:
        model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
        max_samples = job_input.get('max_samples', 10000)
        epochs = job_input.get('epochs', 3)
        batch_size = job_input.get('batch_size', 16)
        learning_rate = job_input.get('learning_rate', 2e-5)
        eval_split = job_input.get('eval_split', 0.1)

        print("Loading ADE Corpus V2 dataset...")
        dataset = load_dataset("ade-benchmark-corpus/ade_corpus_v2",
                               "Ade_corpus_v2_classification")

        # Prepare data
        training_data = []
        for item in dataset['train']:
            if max_samples and len(training_data) >= max_samples:
                break
            training_data.append({
                'text': item['text'],
                'label': item['label']  # 0 or 1
            })
        print(f"Loaded {len(training_data)} ADE samples")

        # Split
        if eval_split > 0:
            train_data, eval_data = train_test_split(
                training_data, test_size=eval_split, random_state=42,
                stratify=[d['label'] for d in training_data]
            )
        else:
            train_data, eval_data = training_data, None

        train_dataset = Dataset.from_list(train_data)
        eval_dataset = Dataset.from_list(eval_data) if eval_data else None

        # Load model (binary: ADE / No ADE)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=2
        )

        def tokenize(examples):
            return tokenizer(examples['text'], padding='max_length',
                             truncation=True, max_length=256)

        train_dataset = train_dataset.map(tokenize, batched=True)
        if eval_dataset:
            eval_dataset = eval_dataset.map(tokenize, batched=True)

        training_args = TrainingArguments(
            output_dir=work_dir,
            num_train_epochs=epochs,
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            eval_strategy='epoch' if eval_dataset else 'no',
            save_strategy='no',
            fp16=torch.cuda.is_available(),
            report_to='none',
        )

        def compute_metrics(eval_pred):
            from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
            preds = eval_pred.predictions.argmax(-1)
            return {
                'accuracy': accuracy_score(eval_pred.label_ids, preds),
                'f1': f1_score(eval_pred.label_ids, preds),
                'precision': precision_score(eval_pred.label_ids, preds),
                'recall': recall_score(eval_pred.label_ids, preds),
            }

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics if eval_dataset else None,
        )

        train_result = trainer.train()

        metrics = {
            'task': 'ade_classification',
            'train_loss': float(train_result.training_loss),
            'epochs': epochs,
            'model_name': model_name,
            'samples': len(train_data),
            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
            'data_source': 'ade_corpus_v2',
        }
        if eval_dataset:
            eval_result = trainer.evaluate()
            metrics.update({
                'eval_accuracy': float(eval_result['eval_accuracy']),
                'eval_f1': float(eval_result['eval_f1']),
                'eval_precision': float(eval_result['eval_precision']),
                'eval_recall': float(eval_result['eval_recall']),
            })

        # Save to S3
        s3_uri = None
        s3_bucket = job_input.get('s3_bucket')
        if s3_bucket:
            save_dir = os.path.join(work_dir, 'saved_model')
            trainer.save_model(save_dir)
            tokenizer.save_pretrained(save_dir)

            aws_creds = {
                'aws_access_key_id': job_input.get('aws_access_key_id'),
                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
                'aws_session_token': job_input.get('aws_session_token'),
                'aws_region': job_input.get('aws_region', 'us-east-1'),
            }
            s3_prefix = job_input.get('s3_prefix', 'ade-models/bert')
            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
            metrics['s3_uri'] = s3_uri

        return {'status': 'success', 'metrics': metrics, 'model_uri': s3_uri,
                'message': 'ADE classifier trained on ADE Corpus V2'}

    except Exception as e:
        import traceback
        return {'status': 'error', 'error': str(e), 'traceback': traceback.format_exc()}
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)
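
# Illustrative ADE job payload (keys match the ones read above; bucket and
# prefix values are placeholders):
#   {"input": {"task": "ade", "max_samples": 10000, "epochs": 3,
#              "s3_bucket": "my-model-bucket", "s3_prefix": "ade-models/bert"}}
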

def train_triage_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """
    Train a Medical Triage classifier.

    Dataset: shubham212/Medical_Triage_Classification
    Labels: triage urgency levels
    """
    import torch
    import tempfile
    import shutil
    from datasets import load_dataset, Dataset
    from transformers import (
        AutoTokenizer,
        AutoModelForSequenceClassification,
        TrainingArguments,
        Trainer
    )
    from sklearn.model_selection import train_test_split

    work_dir = tempfile.mkdtemp()

    try:
        model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
        max_samples = job_input.get('max_samples', 5000)
        epochs = job_input.get('epochs', 3)
        batch_size = job_input.get('batch_size', 8)
        learning_rate = job_input.get('learning_rate', 2e-5)
        eval_split = job_input.get('eval_split', 0.1)

        print("Loading Medical Triage dataset...")
        dataset = load_dataset("shubham212/Medical_Triage_Classification")

        # Get unique labels
        labels = sorted(set(item['label'] for item in dataset['train']))
        label2id = {l: i for i, l in enumerate(labels)}
        id2label = {i: l for l, i in label2id.items()}
        num_labels = len(labels)
        print(f"Found {num_labels} triage levels: {labels}")

        training_data = []
        for item in dataset['train']:
            if max_samples and len(training_data) >= max_samples:
                break
            training_data.append({
                'text': item['text'],
                'label': label2id[item['label']]
            })
        print(f"Loaded {len(training_data)} triage samples")

        if eval_split > 0:
            train_data, eval_data = train_test_split(
                training_data, test_size=eval_split, random_state=42,
                stratify=[d['label'] for d in training_data]
            )
        else:
            train_data, eval_data = training_data, None

        train_dataset = Dataset.from_list(train_data)
        eval_dataset = Dataset.from_list(eval_data) if eval_data else None

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels,
            id2label=id2label, label2id=label2id
        )

        def tokenize(examples):
            return tokenizer(examples['text'], padding='max_length',
                             truncation=True, max_length=512)

        train_dataset = train_dataset.map(tokenize, batched=True)
        if eval_dataset:
            eval_dataset = eval_dataset.map(tokenize, batched=True)

        training_args = TrainingArguments(
            output_dir=work_dir,
            num_train_epochs=epochs,
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            eval_strategy='epoch' if eval_dataset else 'no',
            save_strategy='no',
            fp16=torch.cuda.is_available(),
            report_to='none',
        )

        def compute_metrics(eval_pred):
            from sklearn.metrics import accuracy_score, f1_score
            preds = eval_pred.predictions.argmax(-1)
            return {
                'accuracy': accuracy_score(eval_pred.label_ids, preds),
                'f1_weighted': f1_score(eval_pred.label_ids, preds, average='weighted'),
            }

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics if eval_dataset else None,
        )

        train_result = trainer.train()

        metrics = {
            'task': 'triage_classification',
            'train_loss': float(train_result.training_loss),
            'epochs': epochs,
            'model_name': model_name,
            'samples': len(train_data),
            'num_labels': num_labels,
            'labels': id2label,
            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
            'data_source': 'medical_triage_classification',
        }
        if eval_dataset:
            eval_result = trainer.evaluate()
            metrics.update({
                'eval_accuracy': float(eval_result['eval_accuracy']),
                'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
            })

        s3_uri = None
        s3_bucket = job_input.get('s3_bucket')
        if s3_bucket:
            save_dir = os.path.join(work_dir, 'saved_model')
            trainer.save_model(save_dir)
            tokenizer.save_pretrained(save_dir)

            aws_creds = {
                'aws_access_key_id': job_input.get('aws_access_key_id'),
                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
                'aws_session_token': job_input.get('aws_session_token'),
                'aws_region': job_input.get('aws_region', 'us-east-1'),
            }
            s3_prefix = job_input.get('s3_prefix', 'triage-models/bert')
            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
            metrics['s3_uri'] = s3_uri

        return {'status': 'success', 'metrics': metrics, 'model_uri': s3_uri,
                'message': 'Triage classifier trained'}

    except Exception as e:
        import traceback
        return {'status': 'error', 'error': str(e), 'traceback': traceback.format_exc()}
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)
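
# Because id2label/label2id are baked into the model config above, inference
# returns the original triage level strings rather than LABEL_N indices, e.g.
# (hypothetical input text and local path):
#   from transformers import pipeline
#   pipeline("text-classification", model="./model")("severe chest pain, short of breath")
#   # -> [{"label": "<one of the discovered triage levels>", "score": ...}]
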

def train_symptom_disease_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """
    Train a Symptom-to-Disease classifier.

    Dataset: shanover/disease_symptoms_prec_full
    Task: predict disease from symptoms
    """
    import torch
    import tempfile
    import shutil
    from datasets import load_dataset, Dataset
    from transformers import (
        AutoTokenizer,
        AutoModelForSequenceClassification,
        TrainingArguments,
        Trainer
    )
    from sklearn.model_selection import train_test_split

    work_dir = tempfile.mkdtemp()

    try:
        model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
        max_samples = job_input.get('max_samples', 5000)
        epochs = job_input.get('epochs', 3)
        batch_size = job_input.get('batch_size', 16)
        learning_rate = job_input.get('learning_rate', 2e-5)
        eval_split = job_input.get('eval_split', 0.1)

        print("Loading Symptom-Disease dataset...")
        dataset = load_dataset("shanover/disease_symptoms_prec_full")

        # Build label mapping from diseases
        diseases = sorted(set(item['disease'] for item in dataset['train']))
        label2id = {d: i for i, d in enumerate(diseases)}
        id2label = {i: d for d, i in label2id.items()}
        num_labels = len(diseases)
        print(f"Found {num_labels} diseases")

        training_data = []
        for item in dataset['train']:
            if max_samples and len(training_data) >= max_samples:
                break
            # Format symptoms as natural text
            symptoms = item['symptoms'].replace('_', ' ').replace(',', ', ')
            training_data.append({
                'text': f"Patient presents with: {symptoms}",
                'label': label2id[item['disease']]
            })
        print(f"Loaded {len(training_data)} symptom-disease samples")

        if eval_split > 0:
            train_data, eval_data = train_test_split(
                training_data, test_size=eval_split, random_state=42
            )
        else:
            train_data, eval_data = training_data, None

        train_dataset = Dataset.from_list(train_data)
        eval_dataset = Dataset.from_list(eval_data) if eval_data else None

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels,
            id2label=id2label, label2id=label2id
        )

        def tokenize(examples):
            return tokenizer(examples['text'], padding='max_length',
                             truncation=True, max_length=256)

        train_dataset = train_dataset.map(tokenize, batched=True)
        if eval_dataset:
            eval_dataset = eval_dataset.map(tokenize, batched=True)

        training_args = TrainingArguments(
            output_dir=work_dir,
            num_train_epochs=epochs,
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            eval_strategy='epoch' if eval_dataset else 'no',
            save_strategy='no',
            fp16=torch.cuda.is_available(),
            report_to='none',
        )

        def compute_metrics(eval_pred):
            from sklearn.metrics import accuracy_score, f1_score, top_k_accuracy_score
            preds = eval_pred.predictions.argmax(-1)
            metrics = {
                'accuracy': accuracy_score(eval_pred.label_ids, preds),
                'f1_weighted': f1_score(eval_pred.label_ids, preds, average='weighted'),
            }
            # Top-5 accuracy (important for diagnosis); this can fail when the
            # eval split does not cover every class, so errors are swallowed
            try:
                metrics['top5_accuracy'] = top_k_accuracy_score(
                    eval_pred.label_ids, eval_pred.predictions, k=5
                )
            except Exception:
                pass
            return metrics

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics if eval_dataset else None,
        )

        train_result = trainer.train()

        metrics = {
            'task': 'symptom_disease_classification',
            'train_loss': float(train_result.training_loss),
            'epochs': epochs,
            'model_name': model_name,
            'samples': len(train_data),
            'num_diseases': num_labels,
            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
            'data_source': 'disease_symptoms_prec_full',
        }
        if eval_dataset:
            eval_result = trainer.evaluate()
            metrics.update({
                'eval_accuracy': float(eval_result['eval_accuracy']),
                'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
            })
            if 'eval_top5_accuracy' in eval_result:
                metrics['eval_top5_accuracy'] = float(eval_result['eval_top5_accuracy'])

        s3_uri = None
        s3_bucket = job_input.get('s3_bucket')
        if s3_bucket:
            save_dir = os.path.join(work_dir, 'saved_model')
            trainer.save_model(save_dir)
            tokenizer.save_pretrained(save_dir)

            # Save label mapping
            with open(os.path.join(save_dir, 'disease_labels.json'), 'w') as f:
                json.dump({'id2label': id2label, 'label2id': label2id}, f)

            aws_creds = {
                'aws_access_key_id': job_input.get('aws_access_key_id'),
                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
                'aws_session_token': job_input.get('aws_session_token'),
                'aws_region': job_input.get('aws_region', 'us-east-1'),
            }
            s3_prefix = job_input.get('s3_prefix', 'symptom-disease-models/bert')
            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
            metrics['s3_uri'] = s3_uri

        return {'status': 'success', 'metrics': metrics, 'model_uri': s3_uri,
                'message': 'Symptom-Disease classifier trained'}

    except Exception as e:
        import traceback
        return {'status': 'error', 'error': str(e), 'traceback': traceback.format_exc()}
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)
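
# Sketch of retrieving a trained model from S3. upload_to_s3 writes
# "{prefix}/model_{timestamp}.tar.gz" with a top-level "model" directory;
# the bucket, key, and timestamp below are placeholders:
#   import boto3, tarfile
#   boto3.client("s3").download_file(
#       "my-model-bucket", "ddi-models/bert/model_20250101_000000.tar.gz",
#       "model.tar.gz")
#   with tarfile.open("model.tar.gz") as tar:
#       tar.extractall(".")  # model lands in ./model
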

def handler(job):
    """RunPod serverless handler with multi-task support."""
    job_input = job.get('input', {})

    # Task routing
    task = job_input.get('task', 'ddi')

    if task == 'ade':
        return train_ade_classifier(job_input)
    elif task == 'triage':
        return train_triage_classifier(job_input)
    elif task == 'symptom_disease':
        return train_symptom_disease_classifier(job_input)
    elif task in ('ddi', 'bert'):
        # Original DDI training
        model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
        if 'bert' in model_name.lower():
            return train_bert_classifier(job_input)
        else:
            return train_llm_lora(job_input)
    elif task == 'llm':
        return train_llm_lora(job_input)
    else:
        # Auto-detect based on model
        model_name = job_input.get('model_name', 'meta-llama/Llama-3.1-8B-Instruct')
        if any(x in model_name.lower() for x in ['gemma', 'llama', 'mistral', 'qwen']):
            return train_llm_lora(job_input)
        else:
            return train_bert_classifier(job_input)
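
# Accepted values for job_input["task"]:
#   "ade"             -> train_ade_classifier (ADE Corpus V2, binary)
#   "triage"          -> train_triage_classifier (triage urgency levels)
#   "symptom_disease" -> train_symptom_disease_classifier
#   "ddi" / "bert"    -> BERT severity classifier, or LoRA if model_name is an LLM
#   "llm"             -> train_llm_lora (QLoRA fine-tuning)
#   anything else     -> auto-detect from model_name
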

# RunPod serverless entrypoint
runpod.serverless.start({'handler': handler})
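
# Illustrative invocation of a deployed endpoint (endpoint ID and API key are
# placeholders):
#   curl -X POST https://api.runpod.ai/v2/<ENDPOINT_ID>/run \
#        -H "Authorization: Bearer $RUNPOD_API_KEY" \
#        -H "Content-Type: application/json" \
#        -d '{"input": {"task": "ddi", "max_samples": 50000}}'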