mirror of
https://github.com/ghndrx/kubeflow-pipelines.git
synced 2026-02-10 06:45:13 +00:00
feat: Switch to Llama 3.1 8B (Bedrock-compatible)
- Default model now meta-llama/Llama-3.1-8B-Instruct - Added multi-model chat format support: - Llama 3 format - Mistral/Mixtral format - Qwen format - Gemma format - Trained model can be imported to AWS Bedrock
This commit is contained in:
@@ -51,20 +51,48 @@ def get_curated_fallback() -> List[Dict]:
|
||||
return patterns * 50 # 200 samples
|
||||
|
||||
|
||||
def format_for_gemma(item: Dict) -> str:
|
||||
"""Format DDI item for Gemma instruction tuning."""
|
||||
def format_for_llm(item: Dict, model_name: str = "") -> str:
|
||||
"""Format DDI item for LLM instruction tuning. Auto-detects format based on model."""
|
||||
severity_name = DDI_SEVERITY.get(item['severity'], 'unknown')
|
||||
|
||||
return f"""<start_of_turn>user
|
||||
Analyze the drug interaction between {item['drug1']} and {item['drug2']}.<end_of_turn>
|
||||
<start_of_turn>model
|
||||
Interaction: {item['interaction_text']}
|
||||
user_msg = f"Analyze the drug interaction between {item['drug1']} and {item['drug2']}."
|
||||
assistant_msg = f"""Interaction: {item['interaction_text']}
|
||||
Severity: {severity_name.upper()}
|
||||
Recommendation: {"Avoid this combination" if item['severity'] >= 3 else "Monitor patient" if item['severity'] == 2 else "Generally safe"}<end_of_turn>"""
|
||||
Recommendation: {"Avoid this combination" if item['severity'] >= 3 else "Monitor patient" if item['severity'] == 2 else "Generally safe"}"""
|
||||
|
||||
# Llama 3 format
|
||||
if 'llama' in model_name.lower():
|
||||
return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
{user_msg}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{assistant_msg}<|eot_id|>"""
|
||||
|
||||
# Mistral/Mixtral format
|
||||
elif 'mistral' in model_name.lower() or 'mixtral' in model_name.lower():
|
||||
return f"""<s>[INST] {user_msg} [/INST] {assistant_msg}</s>"""
|
||||
|
||||
# Qwen format
|
||||
elif 'qwen' in model_name.lower():
|
||||
return f"""<|im_start|>user
|
||||
{user_msg}<|im_end|>
|
||||
<|im_start|>assistant
|
||||
{assistant_msg}<|im_end|>"""
|
||||
|
||||
# Gemma format
|
||||
elif 'gemma' in model_name.lower():
|
||||
return f"""<start_of_turn>user
|
||||
{user_msg}<end_of_turn>
|
||||
<start_of_turn>model
|
||||
{assistant_msg}<end_of_turn>"""
|
||||
|
||||
# Generic fallback
|
||||
else:
|
||||
return f"""### User: {user_msg}\n### Assistant: {assistant_msg}"""
|
||||
|
||||
|
||||
def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Train Gemma 3 with QLoRA for DDI classification."""
|
||||
def train_llm_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Train LLM with QLoRA for DDI classification. Supports Llama, Mistral, Qwen, Gemma."""
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
@@ -78,8 +106,8 @@ def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
# Parameters
|
||||
model_name = job_input.get('model_name', 'google/gemma-3-12b-it')
|
||||
# Parameters - Default to Llama 3.1 8B (Bedrock-compatible)
|
||||
model_name = job_input.get('model_name', 'meta-llama/Llama-3.1-8B-Instruct')
|
||||
max_samples = job_input.get('max_samples', 10000)
|
||||
epochs = job_input.get('epochs', 1)
|
||||
learning_rate = job_input.get('learning_rate', 2e-4)
|
||||
@@ -101,8 +129,8 @@ def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
print(f"Loading DrugBank DDI data (max {max_samples})...")
|
||||
raw_data = load_drugbank_data(max_samples=max_samples, severity_filter=severity_filter)
|
||||
|
||||
# Format for Gemma
|
||||
formatted_data = [{"text": format_for_gemma(item)} for item in raw_data]
|
||||
# Format for the target LLM
|
||||
formatted_data = [{"text": format_for_llm(item, model_name)} for item in raw_data]
|
||||
dataset = Dataset.from_list(formatted_data)
|
||||
|
||||
print(f"Dataset size: {len(dataset)}")
|
||||
@@ -342,7 +370,7 @@ def handler(job):
|
||||
"""RunPod serverless handler."""
|
||||
job_input = job.get('input', {})
|
||||
|
||||
model_name = job_input.get('model_name', 'google/gemma-3-12b-it')
|
||||
model_name = job_input.get('model_name', 'meta-llama/Llama-3.1-8B-Instruct')
|
||||
use_lora = job_input.get('use_lora', True)
|
||||
|
||||
# Auto-detect: use LoRA for large models
|
||||
@@ -352,7 +380,7 @@ def handler(job):
|
||||
use_lora = False
|
||||
|
||||
if use_lora:
|
||||
return train_gemma_lora(job_input)
|
||||
return train_llm_lora(job_input)
|
||||
else:
|
||||
return train_bert_classifier(job_input)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user