feat: Switch to Llama 3.1 8B (Bedrock-compatible)
- Default model now meta-llama/Llama-3.1-8B-Instruct
- Added multi-model chat format support:
  - Llama 3 format
  - Mistral/Mixtral format
  - Qwen format
  - Gemma format
- Trained model can be imported to AWS Bedrock (see the sketch below)
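The Bedrock import itself is not part of this commit. As a minimal sketch, assuming the LoRA adapter has been merged into the base model and the Hugging Face weights uploaded to S3, a Bedrock Custom Model Import job could be started with boto3 (bucket, role ARN, and names below are placeholders, not values from this repo):

import boto3

# All identifiers below are placeholders for illustration only
bedrock = boto3.client('bedrock', region_name='us-east-1')

response = bedrock.create_model_import_job(
    jobName='ddi-llama31-import',                                # placeholder
    importedModelName='ddi-llama-3-1-8b-instruct',               # placeholder
    roleArn='arn:aws:iam::123456789012:role/BedrockImportRole',  # placeholder
    modelDataSource={
        's3DataSource': {
            's3Uri': 's3://your-bucket/ddi-llama31-merged/',     # merged HF weights
        }
    },
)
print(response['jobArn'])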
@@ -51,20 +51,48 @@ def get_curated_fallback() -> List[Dict]:
     return patterns * 50 # 200 samples
 
 
-def format_for_gemma(item: Dict) -> str:
-    """Format DDI item for Gemma instruction tuning."""
+def format_for_llm(item: Dict, model_name: str = "") -> str:
+    """Format DDI item for LLM instruction tuning. Auto-detects format based on model."""
     severity_name = DDI_SEVERITY.get(item['severity'], 'unknown')
 
-    return f"""<start_of_turn>user
-Analyze the drug interaction between {item['drug1']} and {item['drug2']}.<end_of_turn>
-<start_of_turn>model
-Interaction: {item['interaction_text']}
+    user_msg = f"Analyze the drug interaction between {item['drug1']} and {item['drug2']}."
+    assistant_msg = f"""Interaction: {item['interaction_text']}
 Severity: {severity_name.upper()}
-Recommendation: {"Avoid this combination" if item['severity'] >= 3 else "Monitor patient" if item['severity'] == 2 else "Generally safe"}<end_of_turn>"""
+Recommendation: {"Avoid this combination" if item['severity'] >= 3 else "Monitor patient" if item['severity'] == 2 else "Generally safe"}"""
+
+    # Llama 3 format
+    if 'llama' in model_name.lower():
+        return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
+
+{user_msg}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+{assistant_msg}<|eot_id|>"""
+
+    # Mistral/Mixtral format
+    elif 'mistral' in model_name.lower() or 'mixtral' in model_name.lower():
+        return f"""<s>[INST] {user_msg} [/INST] {assistant_msg}</s>"""
+
+    # Qwen format
+    elif 'qwen' in model_name.lower():
+        return f"""<|im_start|>user
+{user_msg}<|im_end|>
+<|im_start|>assistant
+{assistant_msg}<|im_end|>"""
+
+    # Gemma format
+    elif 'gemma' in model_name.lower():
+        return f"""<start_of_turn>user
+{user_msg}<end_of_turn>
+<start_of_turn>model
+{assistant_msg}<end_of_turn>"""
+
+    # Generic fallback
+    else:
+        return f"""### User: {user_msg}\n### Assistant: {assistant_msg}"""
 
 
-def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
-    """Train Gemma 3 with QLoRA for DDI classification."""
+def train_llm_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
+    """Train LLM with QLoRA for DDI classification. Supports Llama, Mistral, Qwen, Gemma."""
     import torch
     from transformers import (
         AutoTokenizer,
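For a quick sanity check of format_for_llm, here is what it emits for a made-up DDI record (field values are illustrative; DDI_SEVERITY is assumed to map severity 3 to 'major'):

# Illustrative record; real ones come from load_drugbank_data()
item = {
    'drug1': 'warfarin',
    'drug2': 'aspirin',
    'interaction_text': 'Aspirin may increase the anticoagulant effect of warfarin.',
    'severity': 3,  # assumed to resolve to 'major' via DDI_SEVERITY
}

print(format_for_llm(item, 'meta-llama/Llama-3.1-8B-Instruct'))
# -> <|begin_of_text|><|start_header_id|>user<|end_header_id|> ... <|eot_id|>

print(format_for_llm(item, 'mistralai/Mistral-7B-Instruct-v0.3'))
# -> <s>[INST] Analyze the drug interaction between warfarin and aspirin. [/INST] Interaction: ...</s>

Since severity is 3, the assistant turn ends with "Recommendation: Avoid this combination" in every template.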
@@ -78,8 +106,8 @@ def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
     import tempfile
     import shutil
 
-    # Parameters
-    model_name = job_input.get('model_name', 'google/gemma-3-12b-it')
+    # Parameters - Default to Llama 3.1 8B (Bedrock-compatible)
+    model_name = job_input.get('model_name', 'meta-llama/Llama-3.1-8B-Instruct')
     max_samples = job_input.get('max_samples', 10000)
     epochs = job_input.get('epochs', 1)
     learning_rate = job_input.get('learning_rate', 2e-4)
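All of these read from job_input with .get() defaults, so any key can be overridden per job; an illustrative payload that keeps the new default model but shrinks the run:

# Illustrative values; only the keys you want to override are needed
job_input = {
    'model_name': 'meta-llama/Llama-3.1-8B-Instruct',  # the new default, shown explicitly
    'max_samples': 2000,                               # down from the 10000 default
    'epochs': 1,
    'learning_rate': 2e-4,
}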
@@ -101,8 +129,8 @@ def train_gemma_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
     print(f"Loading DrugBank DDI data (max {max_samples})...")
     raw_data = load_drugbank_data(max_samples=max_samples, severity_filter=severity_filter)
 
-    # Format for Gemma
-    formatted_data = [{"text": format_for_gemma(item)} for item in raw_data]
+    # Format for the target LLM
+    formatted_data = [{"text": format_for_llm(item, model_name)} for item in raw_data]
     dataset = Dataset.from_list(formatted_data)
 
     print(f"Dataset size: {len(dataset)}")
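A self-contained sketch of this formatting step, with a hard-coded record standing in for load_drugbank_data (format_for_llm is the function from the first hunk):

from datasets import Dataset

# One illustrative record in place of the real DrugBank loader
raw_data = [{
    'drug1': 'warfarin',
    'drug2': 'aspirin',
    'interaction_text': 'Aspirin may increase the anticoagulant effect of warfarin.',
    'severity': 3,
}]

model_name = 'meta-llama/Llama-3.1-8B-Instruct'
formatted_data = [{"text": format_for_llm(item, model_name)} for item in raw_data]
dataset = Dataset.from_list(formatted_data)
print(f"Dataset size: {len(dataset)}")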
@@ -342,7 +370,7 @@ def handler(job):
     """RunPod serverless handler."""
     job_input = job.get('input', {})
 
-    model_name = job_input.get('model_name', 'google/gemma-3-12b-it')
+    model_name = job_input.get('model_name', 'meta-llama/Llama-3.1-8B-Instruct')
     use_lora = job_input.get('use_lora', True)
 
     # Auto-detect: use LoRA for large models
@@ -352,7 +380,7 @@ def handler(job):
         use_lora = False
 
     if use_lora:
-        return train_gemma_lora(job_input)
+        return train_llm_lora(job_input)
     else:
         return train_bert_classifier(job_input)
 
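End to end, the handler can be smoke-tested locally before deploying; RunPod wraps the payload under an 'input' key, which job.get('input', {}) unpacks (values below are illustrative):

# Illustrative local smoke test; a real RunPod job also carries an 'id' field
job = {
    'id': 'local-test',
    'input': {
        'model_name': 'meta-llama/Llama-3.1-8B-Instruct',
        'max_samples': 500,   # small run for a quick check
        'use_lora': True,     # routes to train_llm_lora()
    },
}

result = handler(job)
print(result)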