""" RunPod Serverless Handler for DDI Model Training Supports both BERT-style classification and LLM fine-tuning with LoRA. Uses 176K real DrugBank DDI samples with drug names. """ import os import json import runpod from typing import Dict, Any, List, Optional # DDI severity labels DDI_SEVERITY = { 1: "minor", 2: "moderate", 3: "major", 4: "contraindicated" } def load_drugbank_data(max_samples: int = None, severity_filter: List[int] = None) -> List[Dict]: """Load real DrugBank DDI data from bundled file.""" data_path = os.environ.get('DDI_DATA_PATH', '/app/data/drugbank_ddi_complete.jsonl') if not os.path.exists(data_path): print(f"WARNING: Data file not found at {data_path}, using curated fallback") return get_curated_fallback() data = [] with open(data_path) as f: for line in f: item = json.loads(line) if severity_filter and item['severity'] not in severity_filter: continue data.append(item) if max_samples and len(data) >= max_samples: break return data def get_curated_fallback() -> List[Dict]: """Fallback curated data if main file not available.""" patterns = [ {"drug1": "fluoxetine", "drug2": "tramadol", "interaction_text": "fluoxetine may increase the risk of serotonin syndrome when combined with tramadol", "severity": 4}, {"drug1": "warfarin", "drug2": "aspirin", "interaction_text": "warfarin may increase the risk of bleeding when combined with aspirin", "severity": 3}, {"drug1": "simvastatin", "drug2": "amlodipine", "interaction_text": "The serum concentration of simvastatin can be increased when combined with amlodipine", "severity": 2}, {"drug1": "metformin", "drug2": "lisinopril", "interaction_text": "metformin and lisinopril have no significant interaction", "severity": 1}, ] return patterns * 50 # 200 samples def format_for_gemma(item: Dict) -> str: """Format DDI item for Gemma instruction tuning.""" severity_name = DDI_SEVERITY.get(item['severity'], 'unknown') return f"""user Analyze the drug interaction between {item['drug1']} and {item['drug2']}. 
<start_of_turn>model
Interaction: {item['interaction_text']}
Severity: {severity_name.upper()}
Recommendation: {"Avoid this combination" if item['severity'] >= 3 else "Monitor patient" if item['severity'] == 2 else "Generally safe"}<end_of_turn>"""


def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """Train Gemma 3 with QLoRA for DDI classification."""
    import torch
    from transformers import (
        AutoTokenizer,
        AutoModelForCausalLM,
        BitsAndBytesConfig,
        TrainingArguments,
    )
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    from trl import SFTTrainer
    from datasets import Dataset
    import tempfile
    import shutil

    # Parameters
    model_name = job_input.get('model_name', 'google/gemma-3-12b-it')
    max_samples = job_input.get('max_samples', 10000)
    epochs = job_input.get('epochs', 1)
    learning_rate = job_input.get('learning_rate', 2e-4)
    batch_size = job_input.get('batch_size', 2)
    lora_r = job_input.get('lora_r', 16)
    lora_alpha = job_input.get('lora_alpha', 32)
    max_seq_length = job_input.get('max_seq_length', 512)
    severity_filter = job_input.get('severity_filter', None)  # e.g., [3, 4] for major/contraindicated only

    work_dir = tempfile.mkdtemp()

    try:
        print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
        if torch.cuda.is_available():
            print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        print(f"Model: {model_name}")
        print(f"LoRA r={lora_r}, alpha={lora_alpha}")

        # Load training data
        print(f"Loading DrugBank DDI data (max {max_samples})...")
        raw_data = load_drugbank_data(max_samples=max_samples, severity_filter=severity_filter)

        # Format for Gemma
        formatted_data = [{"text": format_for_gemma(item)} for item in raw_data]
        dataset = Dataset.from_list(formatted_data)
        print(f"Dataset size: {len(dataset)}")

        # Severity distribution
        from collections import Counter
        sev_dist = Counter(item['severity'] for item in raw_data)
        print(f"Severity distribution: {dict(sev_dist)}")

        # QLoRA config - 4-bit quantization
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

        # Load model
        print(f"Loading {model_name} with 4-bit quantization...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"

        # Prepare for k-bit training
        model = prepare_model_for_kbit_training(model)

        # LoRA config
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)

        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Trainable params: {trainable_params:,} / {total_params:,} ({100*trainable_params/total_params:.2f}%)")

        # Training arguments
        training_args = TrainingArguments(
            output_dir=work_dir,
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=8,
            learning_rate=learning_rate,
            weight_decay=0.01,
            warmup_ratio=0.1,
            logging_steps=25,
            save_strategy="no",
            bf16=True,
            gradient_checkpointing=True,
            optim="paged_adamw_8bit",
            report_to="none",
            max_grad_norm=0.3,
        )

        # SFT Trainer
        trainer = SFTTrainer(
            model=model,
            train_dataset=dataset,
            args=training_args,
            peft_config=lora_config,
            processing_class=tokenizer,
            max_seq_length=max_seq_length,
        )

        # Train
        print("Starting LoRA fine-tuning...")
        train_result = trainer.train()

        # Metrics
        metrics = {
            'train_loss': float(train_result.training_loss),
            'epochs': epochs,
            'model_name': model_name,
            'samples': len(raw_data),
            'lora_r': lora_r,
            'lora_alpha': lora_alpha,
            'trainable_params': trainable_params,
            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
            'vram_gb': torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0,
            'data_source': 'drugbank_176k',
            'severity_dist': dict(sev_dist),
        }

        print(f"Training complete! Loss: {metrics['train_loss']:.4f}")

        return {
            'status': 'success',
            'metrics': metrics,
            'message': f'Gemma 3 fine-tuned on {len(raw_data):,} real DrugBank DDI samples'
        }

    except Exception as e:
        import traceback
        return {
            'status': 'error',
            'error': str(e),
            'traceback': traceback.format_exc()
        }

    finally:
        shutil.rmtree(work_dir, ignore_errors=True)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """Train a BERT-style classifier for DDI severity prediction."""
    import torch
    from transformers import (
        AutoTokenizer,
        AutoModelForSequenceClassification,
        TrainingArguments,
        Trainer
    )
    from datasets import Dataset
    from sklearn.model_selection import train_test_split
    import tempfile
    import shutil

    model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
    max_samples = job_input.get('max_samples', 50000)
    epochs = job_input.get('epochs', 3)
    learning_rate = job_input.get('learning_rate', 2e-5)
    batch_size = job_input.get('batch_size', 16)
    eval_split = job_input.get('eval_split', 0.1)

    work_dir = tempfile.mkdtemp()

    try:
        print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
        print(f"Model: {model_name}")

        # Load data
        raw_data = load_drugbank_data(max_samples=max_samples)

        # Create text + label format
        # Shift severity to 0-indexed (1-4 -> 0-3)
        training_data = [{
            "text": f"{d['drug1']} and {d['drug2']}: {d['interaction_text']}",
            "label": d['severity'] - 1  # 0-indexed
        } for d in raw_data]
        print(f"Loaded {len(training_data)} samples")

        # Split
        if eval_split > 0:
            train_data, eval_data = train_test_split(
                training_data,
                test_size=eval_split,
                random_state=42,
                stratify=[d['label'] for d in training_data]
            )
        else:
            train_data, eval_data = training_data, None

        train_dataset = Dataset.from_list(train_data)
        eval_dataset = Dataset.from_list(eval_data) if eval_data else None

        # Load model (4 classes: minor, moderate, major, contraindicated)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=4
        )

        def tokenize(examples):
            return tokenizer(examples['text'], padding='max_length',
                             truncation=True, max_length=256)

        train_dataset = train_dataset.map(tokenize, batched=True)
        if eval_dataset:
            eval_dataset = eval_dataset.map(tokenize, batched=True)

        training_args = TrainingArguments(
            output_dir=work_dir,
            num_train_epochs=epochs,
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            eval_strategy='epoch' if eval_dataset else 'no',
            save_strategy='no',
            fp16=torch.cuda.is_available(),
            report_to='none',
        )

        def compute_metrics(eval_pred):
            from sklearn.metrics import accuracy_score, f1_score
            preds = eval_pred.predictions.argmax(-1)
            return {
                'accuracy': accuracy_score(eval_pred.label_ids, preds),
                'f1_weighted': f1_score(eval_pred.label_ids, preds, average='weighted'),
            }

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics if eval_dataset else None,
        )

        train_result = trainer.train()

        metrics = {
            'train_loss': float(train_result.training_loss),
            'epochs': epochs,
            'model_name': model_name,
            'samples': len(train_data),
            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
            'data_source': 'drugbank_176k',
        }

        if eval_dataset:
            eval_result = trainer.evaluate()
            metrics.update({
                'eval_accuracy': float(eval_result['eval_accuracy']),
                'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
            })

        return {'status': 'success', 'metrics': metrics,
                'message': 'BERT classifier trained on DrugBank data'}

    except Exception as e:
        import traceback
        return {'status': 'error', 'error': str(e), 'traceback': traceback.format_exc()}

    finally:
        shutil.rmtree(work_dir, ignore_errors=True)


def handler(job):
    """RunPod serverless handler."""
    job_input = job.get('input', {})

    model_name = job_input.get('model_name', 'google/gemma-3-12b-it')
    use_lora = job_input.get('use_lora', True)

    # Auto-detect: use LoRA for large models
    if any(x in model_name.lower() for x in ['gemma', 'llama', 'mistral', 'qwen']):
        use_lora = True
    elif 'bert' in model_name.lower():
        use_lora = False

    if use_lora:
        return train_gemma_lora(job_input)
    else:
        return train_bert_classifier(job_input)


# RunPod serverless entrypoint
runpod.serverless.start({'handler': handler})
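
# Example request payload (illustrative sketch only; the values below are
# hypothetical, but the keys mirror the job_input fields read in
# train_gemma_lora / train_bert_classifier above). The endpoint receives
# this as the job's "input" object:
#
# {
#   "input": {
#     "model_name": "google/gemma-3-12b-it",
#     "max_samples": 5000,
#     "epochs": 1,
#     "lora_r": 16,
#     "lora_alpha": 32,
#     "severity_filter": [3, 4]
#   }
# }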