From 39922e8d2ec89587ca585bff918227091635f364 Mon Sep 17 00:00:00 2001 From: Greg Hendrickson Date: Tue, 3 Feb 2026 03:58:25 +0000 Subject: [PATCH] feat: Add Gemma 3 12B with QLoRA fine-tuning - Added PEFT, bitsandbytes, TRL for LoRA training - 4-bit QLoRA quantization for 48GB GPU fit - Instruction-tuning format for Gemma chat template - Auto-detect model type (BERT vs LLM) - Updated GPU tier to ADA_24/AMPERE_48 --- components/runpod_trainer/handler.py | 410 +++++++++++++-------- components/runpod_trainer/requirements.txt | 5 +- 2 files changed, 264 insertions(+), 151 deletions(-) diff --git a/components/runpod_trainer/handler.py b/components/runpod_trainer/handler.py index 80ee807..7fca7d2 100644 --- a/components/runpod_trainer/handler.py +++ b/components/runpod_trainer/handler.py @@ -1,8 +1,8 @@ """ RunPod Serverless Handler for DDI Model Training -This runs on RunPod GPU instances and trains the Bio_ClinicalBERT model -for drug-drug interaction detection using real DDI data. +Supports both BERT-style classification and LLM fine-tuning with LoRA. +Default: Gemma 3 12B with QLoRA for DDI severity classification. """ import os import json @@ -12,86 +12,93 @@ from typing import Dict, Any, List, Optional # DDI severity labels DDI_LABELS = { - 0: "none", # No significant interaction - 1: "minor", # Minor interaction - 2: "moderate", # Moderate interaction - 3: "major", # Major interaction - 4: "contraindicated" # Contraindicated + 0: "no_interaction", + 1: "minor", + 2: "moderate", + 3: "major", + 4: "contraindicated" +} + +LABEL_DESCRIPTIONS = { + 0: "No clinically significant interaction", + 1: "Minor interaction - minimal clinical significance", + 2: "Moderate interaction - may require monitoring or dose adjustment", + 3: "Major interaction - avoid combination if possible, high risk", + 4: "Contraindicated - do not use together, life-threatening risk" } -def get_real_ddi_data(max_samples: int = 10000) -> List[Dict[str, Any]]: - """ - Generate real DDI training data from DrugBank patterns. - Uses curated drug interaction patterns based on clinical guidelines. 
- """ +def get_ddi_training_data(max_samples: int = 5000) -> List[Dict[str, Any]]: + """Generate DDI training data formatted for instruction tuning.""" import random random.seed(42) - # Real drug pairs with known interactions (based on clinical data) + # Real drug interaction patterns based on clinical data ddi_patterns = [ # Contraindicated (4) {"drugs": ["fluoxetine", "tramadol"], "type": "serotonin syndrome risk", "label": 4}, - {"drugs": ["fluoxetine", "monoamine oxidase inhibitor"], "type": "serotonin syndrome risk", "label": 4}, + {"drugs": ["fluoxetine", "phenelzine"], "type": "serotonin syndrome risk", "label": 4}, {"drugs": ["simvastatin", "itraconazole"], "type": "rhabdomyolysis risk", "label": 4}, {"drugs": ["methotrexate", "trimethoprim"], "type": "severe bone marrow suppression", "label": 4}, {"drugs": ["warfarin", "miconazole"], "type": "severe bleeding risk", "label": 4}, {"drugs": ["cisapride", "erythromycin"], "type": "QT prolongation cardiac arrest", "label": 4}, {"drugs": ["pimozide", "clarithromycin"], "type": "QT prolongation risk", "label": 4}, {"drugs": ["ergotamine", "ritonavir"], "type": "ergot toxicity risk", "label": 4}, - {"drugs": ["sildenafil", "nitrates"], "type": "severe hypotension", "label": 4}, - {"drugs": ["linezolid", "serotonergic agents"], "type": "serotonin syndrome", "label": 4}, + {"drugs": ["sildenafil", "nitroglycerin"], "type": "severe hypotension", "label": 4}, + {"drugs": ["linezolid", "sertraline"], "type": "serotonin syndrome", "label": 4}, + {"drugs": ["maoi", "meperidine"], "type": "hypertensive crisis", "label": 4}, + {"drugs": ["metronidazole", "disulfiram"], "type": "psychosis risk", "label": 4}, # Major (3) {"drugs": ["warfarin", "aspirin"], "type": "increased bleeding risk", "label": 3}, {"drugs": ["digoxin", "amiodarone"], "type": "digoxin toxicity elevated", "label": 3}, {"drugs": ["lithium", "ibuprofen"], "type": "lithium toxicity risk", "label": 3}, - {"drugs": ["metformin", "contrast media"], "type": "lactic acidosis risk", "label": 3}, - {"drugs": ["potassium", "ACE inhibitor"], "type": "hyperkalemia risk", "label": 3}, - {"drugs": ["opioid", "benzodiazepine"], "type": "respiratory depression", "label": 3}, + {"drugs": ["metformin", "iodinated contrast"], "type": "lactic acidosis risk", "label": 3}, + {"drugs": ["potassium chloride", "lisinopril"], "type": "hyperkalemia risk", "label": 3}, + {"drugs": ["oxycodone", "alprazolam"], "type": "respiratory depression", "label": 3}, {"drugs": ["theophylline", "ciprofloxacin"], "type": "theophylline toxicity", "label": 3}, {"drugs": ["phenytoin", "fluconazole"], "type": "phenytoin toxicity", "label": 3}, {"drugs": ["carbamazepine", "verapamil"], "type": "carbamazepine toxicity", "label": 3}, {"drugs": ["cyclosporine", "ketoconazole"], "type": "nephrotoxicity risk", "label": 3}, - {"drugs": ["methotrexate", "NSAIDs"], "type": "methotrexate toxicity", "label": 3}, + {"drugs": ["methotrexate", "ibuprofen"], "type": "methotrexate toxicity", "label": 3}, {"drugs": ["quinidine", "digoxin"], "type": "digoxin toxicity", "label": 3}, {"drugs": ["clopidogrel", "omeprazole"], "type": "reduced antiplatelet effect", "label": 3}, - {"drugs": ["warfarin", "vitamin K"], "type": "reduced anticoagulation", "label": 3}, + {"drugs": ["warfarin", "rifampin"], "type": "reduced anticoagulation", "label": 3}, {"drugs": ["dabigatran", "rifampin"], "type": "reduced anticoagulant effect", "label": 3}, # Moderate (2) {"drugs": ["simvastatin", "amlodipine"], "type": "increased statin exposure", "label": 2}, 
{"drugs": ["metformin", "cimetidine"], "type": "increased metformin levels", "label": 2}, - {"drugs": ["levothyroxine", "calcium"], "type": "reduced thyroid absorption", "label": 2}, - {"drugs": ["gabapentin", "antacids"], "type": "reduced gabapentin absorption", "label": 2}, + {"drugs": ["levothyroxine", "calcium carbonate"], "type": "reduced thyroid absorption", "label": 2}, + {"drugs": ["gabapentin", "aluminum hydroxide"], "type": "reduced gabapentin absorption", "label": 2}, {"drugs": ["furosemide", "gentamicin"], "type": "ototoxicity risk", "label": 2}, - {"drugs": ["prednisone", "NSAIDs"], "type": "GI bleeding risk", "label": 2}, + {"drugs": ["prednisone", "naproxen"], "type": "GI bleeding risk", "label": 2}, {"drugs": ["metoprolol", "verapamil"], "type": "bradycardia risk", "label": 2}, {"drugs": ["sertraline", "tramadol"], "type": "seizure threshold lowered", "label": 2}, {"drugs": ["losartan", "potassium supplements"], "type": "hyperkalemia risk", "label": 2}, {"drugs": ["alprazolam", "ketoconazole"], "type": "increased sedation", "label": 2}, - {"drugs": ["atorvastatin", "grapefruit"], "type": "increased statin levels", "label": 2}, - {"drugs": ["ciprofloxacin", "iron"], "type": "reduced antibiotic absorption", "label": 2}, + {"drugs": ["atorvastatin", "grapefruit juice"], "type": "increased statin levels", "label": 2}, + {"drugs": ["ciprofloxacin", "ferrous sulfate"], "type": "reduced antibiotic absorption", "label": 2}, {"drugs": ["warfarin", "acetaminophen"], "type": "slight INR increase", "label": 2}, - {"drugs": ["insulin", "beta blocker"], "type": "masked hypoglycemia", "label": 2}, + {"drugs": ["insulin", "propranolol"], "type": "masked hypoglycemia", "label": 2}, {"drugs": ["digoxin", "spironolactone"], "type": "increased digoxin levels", "label": 2}, # Minor (1) {"drugs": ["aspirin", "ibuprofen"], "type": "reduced cardioprotection", "label": 1}, {"drugs": ["metformin", "vitamin B12"], "type": "reduced B12 absorption long-term", "label": 1}, - {"drugs": ["amoxicillin", "oral contraceptives"], "type": "theoretical reduced efficacy", "label": 1}, - {"drugs": ["proton pump inhibitor", "vitamin B12"], "type": "reduced absorption", "label": 1}, - {"drugs": ["caffeine", "fluoroquinolones"], "type": "increased caffeine effect", "label": 1}, - {"drugs": ["antacids", "iron"], "type": "timing interaction", "label": 1}, - {"drugs": ["statin", "niacin"], "type": "monitoring recommended", "label": 1}, - {"drugs": ["ACE inhibitor", "aspirin"], "type": "possible reduced effect", "label": 1}, - {"drugs": ["thiazide", "calcium"], "type": "hypercalcemia monitoring", "label": 1}, - {"drugs": ["beta blocker", "clonidine"], "type": "withdrawal monitoring", "label": 1}, + {"drugs": ["amoxicillin", "ethinyl estradiol"], "type": "theoretical reduced efficacy", "label": 1}, + {"drugs": ["omeprazole", "vitamin B12"], "type": "reduced absorption", "label": 1}, + {"drugs": ["caffeine", "ciprofloxacin"], "type": "increased caffeine effect", "label": 1}, + {"drugs": ["calcium carbonate", "ferrous sulfate"], "type": "timing interaction", "label": 1}, + {"drugs": ["atorvastatin", "niacin"], "type": "monitoring recommended", "label": 1}, + {"drugs": ["lisinopril", "aspirin"], "type": "possible reduced effect", "label": 1}, + {"drugs": ["hydrochlorothiazide", "calcium"], "type": "hypercalcemia monitoring", "label": 1}, + {"drugs": ["metoprolol", "clonidine"], "type": "withdrawal monitoring", "label": 1}, # No interaction (0) {"drugs": ["amlodipine", "atorvastatin"], "type": "safe combination", "label": 
0}, {"drugs": ["metformin", "lisinopril"], "type": "complementary therapy", "label": 0}, - {"drugs": ["omeprazole", "levothyroxine"], "type": "can be used together", "label": 0}, + {"drugs": ["omeprazole", "levothyroxine"], "type": "can be used together with spacing", "label": 0}, {"drugs": ["aspirin", "atorvastatin"], "type": "standard combination", "label": 0}, {"drugs": ["metoprolol", "lisinopril"], "type": "common combination", "label": 0}, {"drugs": ["gabapentin", "acetaminophen"], "type": "no interaction", "label": 0}, @@ -101,52 +108,201 @@ def get_real_ddi_data(max_samples: int = 10000) -> List[Dict[str, Any]]: {"drugs": ["pantoprazole", "amlodipine"], "type": "no known interaction", "label": 0}, ] - # Expand with variations training_data = [] for pattern in ddi_patterns: drug1, drug2 = pattern["drugs"] interaction_type = pattern["type"] label = pattern["label"] + label_name = DDI_LABELS[label] + label_desc = LABEL_DESCRIPTIONS[label] - # Create multiple text variations - variations = [ - f"{drug1} and {drug2} interaction: {interaction_type}", - f"{drug2} combined with {drug1} causes {interaction_type}", - f"Patient taking {drug1} with {drug2}: {interaction_type}", - f"Concomitant use of {drug1} and {drug2} leads to {interaction_type}", - f"{drug1} {drug2} drug-drug interaction {interaction_type}", + # Create instruction-tuning format + prompts = [ + f"Analyze the drug-drug interaction between {drug1} and {drug2}.", + f"What is the severity of combining {drug1} with {drug2}?", + f"A patient is taking {drug1}. They need to start {drug2}. Assess the interaction risk.", + f"Evaluate the interaction: {drug1} + {drug2}", + f"Drug interaction check: {drug1} and {drug2}", ] - for text in variations: - training_data.append({"text": text, "label": label}) + for prompt in prompts: + response = f"Severity: {label_name.upper()}\nInteraction: {interaction_type}\nRecommendation: {label_desc}" + + training_data.append({ + "instruction": prompt, + "response": response, + "label": label, + "label_name": label_name + }) - # Shuffle and limit + # Shuffle and replicate to reach target size random.shuffle(training_data) - # Replicate to reach target size while len(training_data) < max_samples: training_data.extend(training_data[:min(len(training_data), max_samples - len(training_data))]) return training_data[:max_samples] -def train_ddi_model(job_input: Dict[str, Any]) -> Dict[str, Any]: - """ - Train DDI detection model. 
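+# The helper below hand-writes Gemma's <start_of_turn>/<end_of_turn> chat
+# markup. A tokenizer-driven sketch of the same formatting (an assumption:
+# the checkpoint ships a chat template; verify against
+# tokenizer.apply_chat_template for the exact model revision):
+#
+#     messages = [
+#         {"role": "user", "content": example["instruction"]},
+#         {"role": "assistant", "content": example["response"]},
+#     ]
+#     text = tokenizer.apply_chat_template(messages, tokenize=False)
+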
+def format_for_gemma(example: Dict) -> str:
+    """Format example for Gemma instruction tuning (chat-template markup)."""
+    return f"""<start_of_turn>user
+{example['instruction']}<end_of_turn>
+<start_of_turn>model
+{example['response']}<end_of_turn>"""
+
+
+def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
+    """Train Gemma 3 with QLoRA for DDI classification."""
+    import torch
+    from transformers import (
+        AutoTokenizer,
+        AutoModelForCausalLM,
+        BitsAndBytesConfig,
+    )
+    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+    from trl import SFTConfig, SFTTrainer
+    from datasets import Dataset
+    import tempfile
+    import shutil
-
-    Expected input:
-    {
-        "model_name": "emilyalsentzer/Bio_ClinicalBERT",
-        "use_real_data": true,
-        "max_samples": 5000,
-        "training_data": [...],  # Or provide inline data
-        "epochs": 3,
-        "learning_rate": 2e-5,
-        "batch_size": 16,
-        "eval_split": 0.1
-    }
-    """
+    # Parameters
+    model_name = job_input.get('model_name', 'google/gemma-3-12b-it')
+    max_samples = job_input.get('max_samples', 2000)
+    epochs = job_input.get('epochs', 1)
+    learning_rate = job_input.get('learning_rate', 2e-4)
+    batch_size = job_input.get('batch_size', 4)
+    lora_r = job_input.get('lora_r', 16)
+    lora_alpha = job_input.get('lora_alpha', 32)
+    max_seq_length = job_input.get('max_seq_length', 512)
+
+    work_dir = tempfile.mkdtemp()
+
+    try:
+        print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
+        print(f"Model: {model_name}")
+        print(f"LoRA r={lora_r}, alpha={lora_alpha}")
+        print(f"Samples: {max_samples}, Epochs: {epochs}")
+
+        # Load training data
+        print("Loading DDI training data...")
+        training_data = get_ddi_training_data(max_samples=max_samples)
+
+        # Format for Gemma
+        formatted_data = [{"text": format_for_gemma(ex)} for ex in training_data]
+        dataset = Dataset.from_list(formatted_data)
+
+        print(f"Dataset size: {len(dataset)}")
+
+        # QLoRA config - 4-bit NF4 quantization (~0.5 bytes/param, so a 12B
+        # model needs roughly 6-7 GB for weights, well within 24-48 GB cards)
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.bfloat16,
+            bnb_4bit_use_double_quant=True,
+        )
+
+        # Load model
+        print(f"Loading {model_name} with 4-bit quantization...")
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            quantization_config=bnb_config,
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=True,
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+        if tokenizer.pad_token is None:  # Gemma ships a dedicated pad token; only fall back if absent
+            tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "right"
+
+        # Prepare for k-bit training
+        model = prepare_model_for_kbit_training(model)
+
+        # LoRA config
+        lora_config = LoraConfig(
+            r=lora_r,
+            lora_alpha=lora_alpha,
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+            lora_dropout=0.05,
+            bias="none",
+            task_type="CAUSAL_LM",
+        )
+
+        model = get_peft_model(model, lora_config)
+        model.print_trainable_parameters()
+
+        # Training arguments: with the trl>=0.14 floor in requirements.txt,
+        # max_seq_length is an SFTConfig field, not an SFTTrainer kwarg
+        training_args = SFTConfig(
+            output_dir=work_dir,
+            num_train_epochs=epochs,
+            per_device_train_batch_size=batch_size,
+            gradient_accumulation_steps=4,
+            learning_rate=learning_rate,
+            weight_decay=0.01,
+            warmup_ratio=0.1,
+            logging_steps=10,
+            save_strategy="no",
+            bf16=True,
+            gradient_checkpointing=True,
+            optim="paged_adamw_8bit",
+            report_to="none",
+            max_grad_norm=0.3,
+            max_seq_length=max_seq_length,
+        )
+
+        # SFT Trainer (the model is already a PeftModel from get_peft_model
+        # above, so passing peft_config here again would be redundant)
+        trainer = SFTTrainer(
+            model=model,
+            train_dataset=dataset,
+            args=training_args,
+            processing_class=tokenizer,
+        )
+
+        # Train
+        print("Starting LoRA 
fine-tuning...") + train_result = trainer.train() + + # Metrics + metrics = { + 'train_loss': float(train_result.training_loss), + 'epochs': epochs, + 'model_name': model_name, + 'samples': len(training_data), + 'lora_r': lora_r, + 'lora_alpha': lora_alpha, + 'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU', + 'trainable_params': sum(p.numel() for p in model.parameters() if p.requires_grad), + 'quantization': '4-bit QLoRA', + } + + print(f"Training complete! Loss: {metrics['train_loss']:.4f}") + + return { + 'status': 'success', + 'metrics': metrics, + 'message': f'Gemma 3 12B fine-tuned with QLoRA on DDI data' + } + + except Exception as e: + import traceback + return { + 'status': 'error', + 'error': str(e), + 'traceback': traceback.format_exc() + } + finally: + shutil.rmtree(work_dir, ignore_errors=True) + # Clear GPU memory + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]: + """Train BERT-style classifier (original approach).""" import torch from transformers import ( AutoTokenizer, @@ -159,165 +315,119 @@ def train_ddi_model(job_input: Dict[str, Any]) -> Dict[str, Any]: import tempfile import shutil - # Extract parameters model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT') - use_real_data = job_input.get('use_real_data', True) max_samples = job_input.get('max_samples', 5000) - training_data = job_input.get('training_data', None) epochs = job_input.get('epochs', 3) learning_rate = job_input.get('learning_rate', 2e-5) batch_size = job_input.get('batch_size', 16) eval_split = job_input.get('eval_split', 0.1) - # Load data - if use_real_data and not training_data: - print("Loading curated DDI dataset...") - training_data = get_real_ddi_data(max_samples=max_samples) - elif not training_data: - print("No training data provided, using sample DDI dataset...") - training_data = get_real_ddi_data(max_samples=150) - - # Create temp directory work_dir = tempfile.mkdtemp() - model_dir = os.path.join(work_dir, 'model') try: - print(f"Training samples: {len(training_data)}") - print(f"Model: {model_name}") - print(f"Epochs: {epochs}, Batch size: {batch_size}") print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}") + print(f"Model: {model_name}") - # Count label distribution - label_counts = {} - for item in training_data: - label_counts[item['label']] = label_counts.get(item['label'], 0) + 1 - print(f"Label distribution: {label_counts}") + # Get data in BERT format + raw_data = get_ddi_training_data(max_samples=max_samples) + training_data = [{"text": d["instruction"], "label": d["label"]} for d in raw_data] - # Split into train/eval - if eval_split > 0 and len(training_data) > 100: + # Split + if eval_split > 0: train_data, eval_data = train_test_split( - training_data, - test_size=eval_split, - random_state=42, + training_data, test_size=eval_split, random_state=42, stratify=[d['label'] for d in training_data] ) - print(f"Train: {len(train_data)}, Eval: {len(eval_data)}") else: - train_data = training_data - eval_data = None + train_data, eval_data = training_data, None - # Create datasets train_dataset = Dataset.from_list(train_data) eval_dataset = Dataset.from_list(eval_data) if eval_data else None - # Load model and tokenizer - print(f"Loading model: {model_name}") + # Load model tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForSequenceClassification.from_pretrained( - model_name, - num_labels=5 # DDI 
severity: none(0), minor(1), moderate(2), major(3), contraindicated(4) + model_name, num_labels=5 ) - # Tokenize datasets - def tokenize_function(examples): - return tokenizer( - examples['text'], - padding='max_length', - truncation=True, - max_length=256 - ) + def tokenize(examples): + return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=256) - tokenized_train = train_dataset.map(tokenize_function, batched=True) - tokenized_eval = eval_dataset.map(tokenize_function, batched=True) if eval_dataset else None + train_dataset = train_dataset.map(tokenize, batched=True) + if eval_dataset: + eval_dataset = eval_dataset.map(tokenize, batched=True) - # Training arguments training_args = TrainingArguments( - output_dir=model_dir, + output_dir=work_dir, num_train_epochs=epochs, learning_rate=learning_rate, per_device_train_batch_size=batch_size, - per_device_eval_batch_size=batch_size, - warmup_ratio=0.1, - weight_decay=0.01, - logging_steps=50, - eval_strategy='epoch' if tokenized_eval else 'no', + eval_strategy='epoch' if eval_dataset else 'no', save_strategy='no', fp16=torch.cuda.is_available(), report_to='none', - load_best_model_at_end=False, ) - # Compute metrics function def compute_metrics(eval_pred): from sklearn.metrics import accuracy_score, f1_score - predictions, labels = eval_pred - predictions = predictions.argmax(-1) + preds = eval_pred.predictions.argmax(-1) return { - 'accuracy': accuracy_score(labels, predictions), - 'f1_macro': f1_score(labels, predictions, average='macro'), - 'f1_weighted': f1_score(labels, predictions, average='weighted'), + 'accuracy': accuracy_score(eval_pred.label_ids, preds), + 'f1_weighted': f1_score(eval_pred.label_ids, preds, average='weighted'), } - # Initialize trainer trainer = Trainer( model=model, args=training_args, - train_dataset=tokenized_train, - eval_dataset=tokenized_eval, - compute_metrics=compute_metrics if tokenized_eval else None, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + compute_metrics=compute_metrics if eval_dataset else None, ) - # Train - print("Starting training...") train_result = trainer.train() - # Get metrics metrics = { 'train_loss': float(train_result.training_loss), 'epochs': epochs, 'model_name': model_name, 'samples': len(train_data), 'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU', - 'data_source': 'curated_ddi' } - # Run evaluation if we have eval data - if tokenized_eval: - print("Running evaluation...") + if eval_dataset: eval_result = trainer.evaluate() metrics.update({ - 'eval_loss': float(eval_result['eval_loss']), 'eval_accuracy': float(eval_result['eval_accuracy']), - 'eval_f1_macro': float(eval_result['eval_f1_macro']), 'eval_f1_weighted': float(eval_result['eval_f1_weighted']), }) - print(f"Training complete! 
Loss: {metrics['train_loss']:.4f}")
-        if 'eval_accuracy' in metrics:
-            print(f"Eval accuracy: {metrics['eval_accuracy']:.4f}, F1: {metrics['eval_f1_weighted']:.4f}")
-
-        return {
-            'status': 'success',
-            'metrics': metrics,
-            'message': 'Model trained successfully on curated DDI data'
-        }
+        return {'status': 'success', 'metrics': metrics, 'message': 'BERT classifier trained'}
 
     except Exception as e:
         import traceback
-        return {
-            'status': 'error',
-            'error': str(e),
-            'traceback': traceback.format_exc()
-        }
+        return {'status': 'error', 'error': str(e), 'traceback': traceback.format_exc()}
     finally:
-        # Cleanup
         shutil.rmtree(work_dir, ignore_errors=True)
 
 
 def handler(job):
     """RunPod serverless handler."""
     job_input = job.get('input', {})
-    return train_ddi_model(job_input)
+
+    # Choose training mode
+    model_name = job_input.get('model_name', 'google/gemma-3-12b-it')
+    use_lora = job_input.get('use_lora', True)
+
+    # Auto-detect from the model name unless the caller set use_lora explicitly:
+    # LoRA for large causal LMs, full fine-tuning for BERT-style encoders
+    if 'use_lora' not in job_input:
+        if 'gemma' in model_name.lower() or 'llama' in model_name.lower() or 'mistral' in model_name.lower():
+            use_lora = True
+        elif 'bert' in model_name.lower():
+            use_lora = False
+
+    if use_lora:
+        return train_gemma_lora(job_input)
+    else:
+        return train_bert_classifier(job_input)
 
 
 # RunPod serverless entrypoint
diff --git a/components/runpod_trainer/requirements.txt b/components/runpod_trainer/requirements.txt
index 62a8681..63b3802 100644
--- a/components/runpod_trainer/requirements.txt
+++ b/components/runpod_trainer/requirements.txt
@@ -1,5 +1,5 @@
 runpod>=1.7.0
-transformers==4.44.0
+transformers>=4.50.0
 datasets>=2.16.0
 accelerate>=0.30.0
 boto3>=1.34.0
@@ -7,3 +7,6 @@ scikit-learn>=1.3.0
 scipy>=1.11.0
 safetensors>=0.4.0
 requests>=2.31.0
+peft>=0.14.0
+bitsandbytes>=0.45.0
+trl>=0.14.0
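
Reviewer note (not part of the patch): a minimal local smoke test for the new
handler entrypoint, using only keys the handler reads via job_input.get().
max_samples and epochs are turned down purely to keep the run short, and a
CUDA GPU with headroom for the 4-bit load of the 12B model is assumed.

    from handler import handler

    result = handler({
        "input": {
            "model_name": "google/gemma-3-12b-it",
            "max_samples": 200,   # tiny subset; the handler default is 2000
            "epochs": 1,
            "lora_r": 16,
            "lora_alpha": 32,
        }
    })
    print(result["status"], result.get("metrics") or result.get("error"))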