From 45b96e209472f1cc4127bb588ab6052ea301b58d Mon Sep 17 00:00:00 2001
From: Greg Hendrickson
Date: Tue, 3 Feb 2026 04:38:54 +0000
Subject: [PATCH] feat: Switch to Llama 3.1 8B (Bedrock-compatible)

- Default model is now meta-llama/Llama-3.1-8B-Instruct
- Added multi-model chat format support:
  - Llama 3 format
  - Mistral/Mixtral format
  - Qwen format
  - Gemma format
- Trained model can be imported into AWS Bedrock (Custom Model Import)
---
 components/runpod_trainer/handler.py | 58 +++++++++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 43 insertions(+), 15 deletions(-)

diff --git a/components/runpod_trainer/handler.py b/components/runpod_trainer/handler.py
index d035467..f8527ab 100644
--- a/components/runpod_trainer/handler.py
+++ b/components/runpod_trainer/handler.py
@@ -51,20 +51,48 @@ def get_curated_fallback() -> List[Dict]:
     return patterns * 50  # 200 samples
 
 
-def format_for_gemma(item: Dict) -> str:
-    """Format DDI item for Gemma instruction tuning."""
+def format_for_llm(item: Dict, model_name: str = "") -> str:
+    """Format DDI item for LLM instruction tuning. Auto-detects the chat format from the model name."""
     severity_name = DDI_SEVERITY.get(item['severity'], 'unknown')
 
-    return f"""<start_of_turn>user
-Analyze the drug interaction between {item['drug1']} and {item['drug2']}.<end_of_turn>
-<start_of_turn>model
-Interaction: {item['interaction_text']}
+    user_msg = f"Analyze the drug interaction between {item['drug1']} and {item['drug2']}."
+    assistant_msg = f"""Interaction: {item['interaction_text']}
 Severity: {severity_name.upper()}
-Recommendation: {"Avoid this combination" if item['severity'] >= 3 else "Monitor patient" if item['severity'] == 2 else "Generally safe"}<end_of_turn>"""
+Recommendation: {"Avoid this combination" if item['severity'] >= 3 else "Monitor patient" if item['severity'] == 2 else "Generally safe"}"""
+
+    # Llama 3 format
+    if 'llama' in model_name.lower():
+        return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
+
+{user_msg}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+{assistant_msg}<|eot_id|>"""
+
+    # Mistral/Mixtral format
+    elif 'mistral' in model_name.lower() or 'mixtral' in model_name.lower():
+        return f"""[INST] {user_msg} [/INST] {assistant_msg}"""
+
+    # Qwen format
+    elif 'qwen' in model_name.lower():
+        return f"""<|im_start|>user
+{user_msg}<|im_end|>
+<|im_start|>assistant
+{assistant_msg}<|im_end|>"""
+
+    # Gemma format
+    elif 'gemma' in model_name.lower():
+        return f"""<start_of_turn>user
+{user_msg}<end_of_turn>
+<start_of_turn>model
+{assistant_msg}<end_of_turn>"""
+
+    # Generic fallback
+    else:
+        return f"""### User: {user_msg}\n### Assistant: {assistant_msg}"""
 
 
-def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
-    """Train Gemma 3 with QLoRA for DDI classification."""
+def train_llm_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
+    """Train LLM with QLoRA for DDI classification. Supports Llama, Mistral, Qwen, Gemma."""
     import torch
     from transformers import (
         AutoTokenizer,
@@ -78,8 +106,8 @@ def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
     import tempfile
     import shutil
 
-    # Parameters
-    model_name = job_input.get('model_name', 'google/gemma-3-12b-it')
+    # Parameters - Default to Llama 3.1 8B (Bedrock-compatible)
+    model_name = job_input.get('model_name', 'meta-llama/Llama-3.1-8B-Instruct')
     max_samples = job_input.get('max_samples', 10000)
     epochs = job_input.get('epochs', 1)
     learning_rate = job_input.get('learning_rate', 2e-4)
@@ -101,8 +129,8 @@ def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
     print(f"Loading DrugBank DDI data (max {max_samples})...")
     raw_data = load_drugbank_data(max_samples=max_samples, severity_filter=severity_filter)
 
-    # Format for Gemma
-    formatted_data = [{"text": format_for_gemma(item)} for item in raw_data]
+    # Format for the target LLM
+    formatted_data = [{"text": format_for_llm(item, model_name)} for item in raw_data]
     dataset = Dataset.from_list(formatted_data)
     print(f"Dataset size: {len(dataset)}")
 
@@ -342,7 +370,7 @@ def handler(job):
     """RunPod serverless handler."""
     job_input = job.get('input', {})
 
-    model_name = job_input.get('model_name', 'google/gemma-3-12b-it')
+    model_name = job_input.get('model_name', 'meta-llama/Llama-3.1-8B-Instruct')
     use_lora = job_input.get('use_lora', True)
 
     # Auto-detect: use LoRA for large models
@@ -352,7 +380,7 @@ def handler(job):
         use_lora = False
 
     if use_lora:
-        return train_gemma_lora(job_input)
+        return train_llm_lora(job_input)
     else:
         return train_bert_classifier(job_input)
 
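
Testing notes (not part of the applied patch): a quick local sanity check of the
format_for_llm dispatch. This is a sketch under assumptions -- it assumes
handler.py imports cleanly outside RunPod, and the sample DDI item below is made
up; only its field names (drug1, drug2, interaction_text, severity) match what
format_for_llm reads in the first hunk above.

    # Illustrative only: exercise each branch of format_for_llm with a
    # hypothetical DDI item and one model name per supported family.
    from components.runpod_trainer.handler import format_for_llm

    sample = {
        "drug1": "warfarin",                                # made-up example item
        "drug2": "aspirin",
        "interaction_text": "Increased risk of bleeding.",
        "severity": 3,                                      # looked up in DDI_SEVERITY
    }

    for model in [
        "meta-llama/Llama-3.1-8B-Instruct",    # Llama 3 header/eot tokens
        "mistralai/Mistral-7B-Instruct-v0.3",  # [INST] ... [/INST]
        "Qwen/Qwen2.5-7B-Instruct",            # ChatML <|im_start|> blocks
        "google/gemma-2-9b-it",                # <start_of_turn> blocks
        "some-org/unknown-model",              # generic ### User / ### Assistant
    ]:
        print(f"--- {model} ---")
        print(format_for_llm(sample, model))

Note the branch order: the 'llama' check runs first, so any model name containing
"llama" gets the Llama 3 template regardless of what else is in the name.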
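For reference, a sketch of the job payload the handler accepts, using only the
keys visible in this diff via job_input.get(...) (severity_filter is also read
upstream of the hunks shown); the values below are the defaults from the diff:

    # Example RunPod request body. Omitting model_name falls back to the new
    # Llama 3.1 8B default; use_lora=False routes to train_bert_classifier()
    # instead of train_llm_lora().
    job = {
        "input": {
            "model_name": "meta-llama/Llama-3.1-8B-Instruct",
            "max_samples": 10000,
            "epochs": 1,
            "learning_rate": 2e-4,
            "use_lora": True,
        }
    }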