mirror of https://github.com/ghndrx/kubeflow-pipelines.git

fix: remove MinIO dependency, use inline training data
@@ -7,33 +7,7 @@ for drug-drug interaction detection.
 import os
 import json
 import runpod
-from typing import Dict, Any
+from typing import Dict, Any, List, Optional


-def download_from_minio(bucket: str, key: str, local_path: str):
-    """Download file from MinIO."""
-    import boto3
-
-    s3 = boto3.client(
-        's3',
-        endpoint_url=os.environ['MINIO_ENDPOINT'],
-        aws_access_key_id=os.environ['MINIO_ACCESS_KEY'],
-        aws_secret_access_key=os.environ['MINIO_SECRET_KEY']
-    )
-    s3.download_file(bucket, key, local_path)
-
-
-def upload_to_minio(local_path: str, bucket: str, key: str):
-    """Upload file to MinIO."""
-    import boto3
-
-    s3 = boto3.client(
-        's3',
-        endpoint_url=os.environ['MINIO_ENDPOINT'],
-        aws_access_key_id=os.environ['MINIO_ACCESS_KEY'],
-        aws_secret_access_key=os.environ['MINIO_SECRET_KEY']
-    )
-    s3.upload_file(local_path, bucket, key)
-
-
 def train_ddi_model(job_input: Dict[str, Any]) -> Dict[str, Any]:
@@ -43,11 +17,10 @@ def train_ddi_model(job_input: Dict[str, Any]) -> Dict[str, Any]:
     Expected input:
     {
         "model_name": "emilyalsentzer/Bio_ClinicalBERT",
-        "dataset_path": "datasets/ddi_train.json",
+        "training_data": [{"text": "...", "label": 0}, ...],  # Inline data
         "epochs": 3,
         "learning_rate": 2e-5,
-        "batch_size": 16,
-        "output_path": "models/ddi_model_v1"
+        "batch_size": 16
     }
     """
     import torch
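For reference, a complete request under the new contract might look like the sketch below; the keys match the docstring above and the outer "input" wrapper comes from the handler, while the example texts and labels are illustrative only:

{
    "input": {
        "model_name": "emilyalsentzer/Bio_ClinicalBERT",
        "training_data": [
            {"text": "warfarin and aspirin interaction causes bleeding risk", "label": 3},
            {"text": "metformin with lisinopril is safe combination", "label": 0}
        ],
        "epochs": 3,
        "learning_rate": 2e-5,
        "batch_size": 16
    }
}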
@@ -63,34 +36,51 @@ def train_ddi_model(job_input: Dict[str, Any]) -> Dict[str, Any]:

     # Extract parameters
     model_name = job_input.get('model_name', 'emilyalsentzer/Bio_ClinicalBERT')
-    dataset_path = job_input.get('dataset_path', 'datasets/ddi_train.json')
+    training_data = job_input.get('training_data', None)
     epochs = job_input.get('epochs', 3)
     learning_rate = job_input.get('learning_rate', 2e-5)
     batch_size = job_input.get('batch_size', 16)
-    output_path = job_input.get('output_path', 'models/ddi_model')
+
+    # Use sample data if none provided
+    if not training_data:
+        print("No training data provided, using sample DDI dataset...")
+        training_data = [
+            {"text": "warfarin and aspirin interaction causes bleeding risk", "label": 3},
+            {"text": "metformin with lisinopril is safe combination", "label": 0},
+            {"text": "fluoxetine tramadol causes serotonin syndrome", "label": 4},
+            {"text": "simvastatin amiodarone increases myopathy risk", "label": 3},
+            {"text": "omeprazole reduces clopidogrel efficacy", "label": 2},
+            {"text": "digoxin amiodarone toxicity risk elevated", "label": 3},
+            {"text": "lithium NSAIDs increases lithium levels", "label": 3},
+            {"text": "benzodiazepines opioids respiratory depression", "label": 4},
+            {"text": "metronidazole alcohol disulfiram reaction", "label": 4},
+            {"text": "ACE inhibitors potassium hyperkalemia risk", "label": 2},
+            {"text": "amlodipine atorvastatin safe combination", "label": 0},
+            {"text": "gabapentin pregabalin CNS depression additive", "label": 2},
+            {"text": "warfarin vitamin K antagonism reduced effect", "label": 2},
+            {"text": "insulin metformin hypoglycemia risk", "label": 1},
+            {"text": "aspirin ibuprofen GI bleeding increased", "label": 3},
+        ] * 10  # 150 samples

     # Create temp directory
     work_dir = tempfile.mkdtemp()
-    data_file = os.path.join(work_dir, 'train.json')
     model_dir = os.path.join(work_dir, 'model')

     try:
-        # Download training data from MinIO
-        print(f"Downloading dataset from {dataset_path}...")
-        download_from_minio('datasets', dataset_path, data_file)
+        print(f"Training samples: {len(training_data)}")
+        print(f"Model: {model_name}")
+        print(f"Epochs: {epochs}, Batch size: {batch_size}")
+        print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

         # Load dataset
-        with open(data_file, 'r') as f:
-            train_data = json.load(f)
-
-        dataset = Dataset.from_list(train_data)
+        dataset = Dataset.from_list(training_data)

         # Load model and tokenizer
         print(f"Loading model: {model_name}")
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model = AutoModelForSequenceClassification.from_pretrained(
             model_name,
-            num_labels=5  # DDI severity levels: none, minor, moderate, major, contraindicated
+            num_labels=5  # DDI severity: none, minor, moderate, major, contraindicated
         )

         # Tokenize dataset
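The sample set's integer labels line up with the five severity classes named in the num_labels comment. A hypothetical mapping, inferred from that comment's ordering rather than stated anywhere in the code:

# Assumed mapping from label index to DDI severity (inferred from the
# "none, minor, moderate, major, contraindicated" comment; not in the source)
DDI_SEVERITY = {
    0: "none",
    1: "minor",
    2: "moderate",
    3: "major",
    4: "contraindicated",
}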
@@ -99,7 +89,7 @@ def train_ddi_model(job_input: Dict[str, Any]) -> Dict[str, Any]:
             examples['text'],
             padding='max_length',
             truncation=True,
-            max_length=512
+            max_length=128
         )

     tokenized_dataset = dataset.map(tokenize_function, batched=True)
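Only the argument list is touched here; the enclosing function sits outside the diff context. A minimal sketch of what it presumably looks like after the change:

def tokenize_function(examples):
    # Batched map call: 'examples' is a dict of column name -> list of values
    return tokenizer(
        examples['text'],
        padding='max_length',
        truncation=True,
        max_length=128  # shorter sequences: the inline samples are single sentences
    )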
@@ -110,15 +100,12 @@ def train_ddi_model(job_input: Dict[str, Any]) -> Dict[str, Any]:
         num_train_epochs=epochs,
         learning_rate=learning_rate,
         per_device_train_batch_size=batch_size,
-        per_device_eval_batch_size=batch_size,
-        warmup_steps=100,
+        warmup_steps=50,
         weight_decay=0.01,
-        logging_dir=os.path.join(work_dir, 'logs'),
         logging_steps=10,
         save_strategy='epoch',
-        evaluation_strategy='epoch' if 'validation' in train_data else 'no',
-        load_best_model_at_end=True if 'validation' in train_data else False,
         fp16=torch.cuda.is_available(),
+        report_to='none',
     )

     # Initialize trainer
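The `# Initialize trainer` body itself is elided by the diff. A minimal sketch of the standard wiring these arguments imply, assuming the usual transformers `TrainingArguments`/`Trainer` pattern and an `output_dir` under the temp workspace (both assumptions, not visible in the diff):

training_args = TrainingArguments(
    output_dir=os.path.join(work_dir, 'checkpoints'),  # assumed location
    num_train_epochs=epochs,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    warmup_steps=50,
    weight_decay=0.01,
    logging_steps=10,
    save_strategy='epoch',
    fp16=torch.cuda.is_available(),
    report_to='none',
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
)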
@@ -132,41 +119,29 @@ def train_ddi_model(job_input: Dict[str, Any]) -> Dict[str, Any]:
         print("Starting training...")
         train_result = trainer.train()

-        # Save model
-        print("Saving model...")
-        trainer.save_model(model_dir)
-        tokenizer.save_pretrained(model_dir)
-
-        # Save training metrics
+        # Get metrics
         metrics = {
-            'train_loss': train_result.training_loss,
+            'train_loss': float(train_result.training_loss),
             'epochs': epochs,
             'model_name': model_name,
-            'samples': len(dataset)
+            'samples': len(training_data),
+            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'
         }

-        with open(os.path.join(model_dir, 'metrics.json'), 'w') as f:
-            json.dump(metrics, f)
-
-        # Upload model to MinIO
-        print(f"Uploading model to {output_path}...")
-        for root, dirs, files in os.walk(model_dir):
-            for file in files:
-                local_file = os.path.join(root, file)
-                relative_path = os.path.relpath(local_file, model_dir)
-                minio_key = f"{output_path}/{relative_path}"
-                upload_to_minio(local_file, 'models', minio_key)
+        print(f"Training complete! Loss: {metrics['train_loss']:.4f}")

         return {
             'status': 'success',
-            'model_path': f"s3://models/{output_path}",
-            'metrics': metrics
+            'metrics': metrics,
+            'message': 'Model trained successfully'
         }

     except Exception as e:
+        import traceback
         return {
             'status': 'error',
-            'error': str(e)
+            'error': str(e),
+            'traceback': traceback.format_exc()
         }
     finally:
         # Cleanup
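Callers can branch on the `status` field of the returned dict. The two shapes produced above look like this (all values illustrative):

# Success
{'status': 'success',
 'metrics': {'train_loss': 0.41, 'epochs': 3,
             'model_name': 'emilyalsentzer/Bio_ClinicalBERT',
             'samples': 150, 'gpu': 'CPU'},
 'message': 'Model trained successfully'}

# Failure
{'status': 'error', 'error': 'exception message', 'traceback': 'formatted stack trace'}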
@@ -175,7 +150,7 @@ def train_ddi_model(job_input: Dict[str, Any]) -> Dict[str, Any]:

 def handler(job):
     """RunPod serverless handler."""
-    job_input = job['input']
+    job_input = job.get('input', {})
     return train_ddi_model(job_input)

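Since sample data is now baked in, the handler can be smoke-tested locally with no MinIO and no payload at all. A minimal sketch; the `runpod.serverless.start` call is the standard RunPod worker entry point and is assumed to live elsewhere in this file:

if __name__ == '__main__':
    # Empty input: train_ddi_model falls back to the built-in sample dataset
    print(handler({'input': {}}))

    # In production the worker is started with the RunPod SDK, typically:
    # runpod.serverless.start({'handler': handler})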