feat: Switch to Llama 3.1 8B (Bedrock-compatible)

- Default model now meta-llama/Llama-3.1-8B-Instruct
- Added multi-model chat format support:
  - Llama 3 format
  - Mistral/Mixtral format
  - Qwen format
  - Gemma format
- Trained model can be imported to AWS Bedrock
This commit is contained in:
2026-02-03 04:38:54 +00:00
parent 67a1095100
commit 45b96e2094

View File

@@ -51,20 +51,48 @@ def get_curated_fallback() -> List[Dict]:
return patterns * 50 # 200 samples
def format_for_gemma(item: Dict) -> str:
"""Format DDI item for Gemma instruction tuning."""
def format_for_llm(item: Dict, model_name: str = "") -> str:
"""Format DDI item for LLM instruction tuning. Auto-detects format based on model."""
severity_name = DDI_SEVERITY.get(item['severity'], 'unknown')
return f"""<start_of_turn>user
Analyze the drug interaction between {item['drug1']} and {item['drug2']}.<end_of_turn>
<start_of_turn>model
Interaction: {item['interaction_text']}
user_msg = f"Analyze the drug interaction between {item['drug1']} and {item['drug2']}."
assistant_msg = f"""Interaction: {item['interaction_text']}
Severity: {severity_name.upper()}
Recommendation: {"Avoid this combination" if item['severity'] >= 3 else "Monitor patient" if item['severity'] == 2 else "Generally safe"}<end_of_turn>"""
Recommendation: {"Avoid this combination" if item['severity'] >= 3 else "Monitor patient" if item['severity'] == 2 else "Generally safe"}"""
# Llama 3 format
if 'llama' in model_name.lower():
return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
{user_msg}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{assistant_msg}<|eot_id|>"""
# Mistral/Mixtral format
elif 'mistral' in model_name.lower() or 'mixtral' in model_name.lower():
return f"""<s>[INST] {user_msg} [/INST] {assistant_msg}</s>"""
# Qwen format
elif 'qwen' in model_name.lower():
return f"""<|im_start|>user
{user_msg}<|im_end|>
<|im_start|>assistant
{assistant_msg}<|im_end|>"""
# Gemma format
elif 'gemma' in model_name.lower():
return f"""<start_of_turn>user
{user_msg}<end_of_turn>
<start_of_turn>model
{assistant_msg}<end_of_turn>"""
# Generic fallback
else:
return f"""### User: {user_msg}\n### Assistant: {assistant_msg}"""
def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
"""Train Gemma 3 with QLoRA for DDI classification."""
def train_llm_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
"""Train LLM with QLoRA for DDI classification. Supports Llama, Mistral, Qwen, Gemma."""
import torch
from transformers import (
AutoTokenizer,
@@ -78,8 +106,8 @@ def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
import tempfile
import shutil
# Parameters
model_name = job_input.get('model_name', 'google/gemma-3-12b-it')
# Parameters - Default to Llama 3.1 8B (Bedrock-compatible)
model_name = job_input.get('model_name', 'meta-llama/Llama-3.1-8B-Instruct')
max_samples = job_input.get('max_samples', 10000)
epochs = job_input.get('epochs', 1)
learning_rate = job_input.get('learning_rate', 2e-4)
@@ -101,8 +129,8 @@ def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
print(f"Loading DrugBank DDI data (max {max_samples})...")
raw_data = load_drugbank_data(max_samples=max_samples, severity_filter=severity_filter)
# Format for Gemma
formatted_data = [{"text": format_for_gemma(item)} for item in raw_data]
# Format for the target LLM
formatted_data = [{"text": format_for_llm(item, model_name)} for item in raw_data]
dataset = Dataset.from_list(formatted_data)
print(f"Dataset size: {len(dataset)}")
@@ -342,7 +370,7 @@ def handler(job):
"""RunPod serverless handler."""
job_input = job.get('input', {})
model_name = job_input.get('model_name', 'google/gemma-3-12b-it')
model_name = job_input.get('model_name', 'meta-llama/Llama-3.1-8B-Instruct')
use_lora = job_input.get('use_lora', True)
# Auto-detect: use LoRA for large models
@@ -352,7 +380,7 @@ def handler(job):
use_lora = False
if use_lora:
return train_gemma_lora(job_input)
return train_llm_lora(job_input)
else:
return train_bert_classifier(job_input)