From 59c808cb3a45bba225ff0e6863d941a84bb0a5e0 Mon Sep 17 00:00:00 2001
From: Greg Hendrickson
Date: Tue, 3 Feb 2026 15:13:21 +0000
Subject: [PATCH] feat: add S3 model upload support

- Add upload_to_s3 function to handler
- Save trained BERT models to S3 when credentials provided
- Save LoRA adapters to S3 when credentials provided
- Input params: s3_bucket, s3_prefix, aws_access_key_id,
  aws_secret_access_key, aws_region
---
 components/runpod_trainer/handler.py  | 80 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 components/runpod_trainer/runpod.toml | 12 ++++
 2 files changed, 90 insertions(+), 2 deletions(-)
 create mode 100644 components/runpod_trainer/runpod.toml

diff --git a/components/runpod_trainer/handler.py b/components/runpod_trainer/handler.py
index f8527ab..f74ffdd 100644
--- a/components/runpod_trainer/handler.py
+++ b/components/runpod_trainer/handler.py
@@ -3,13 +3,49 @@ RunPod Serverless Handler for DDI Model Training
 
 Supports both BERT-style classification and LLM fine-tuning with LoRA.
 Uses 176K real DrugBank DDI samples with drug names.
+Saves trained models to S3.
 """
 
 import os
 import json
 import runpod
+import boto3
+from datetime import datetime
 from typing import Dict, Any, List, Optional
 
+def upload_to_s3(local_path: str, s3_bucket: str, s3_prefix: str, aws_credentials: Dict) -> str:
+    """Upload a directory to S3 and return the S3 URI."""
+    import tarfile
+    import tempfile
+
+    # Create tar.gz of the model directory
+    timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
+    tar_name = f"model_{timestamp}.tar.gz"
+    tar_path = os.path.join(tempfile.gettempdir(), tar_name)
+
+    print(f"Creating archive: {tar_path}")
+    with tarfile.open(tar_path, "w:gz") as tar:
+        tar.add(local_path, arcname="model")
+
+    # Upload to S3
+    s3_client = boto3.client(
+        's3',
+        aws_access_key_id=aws_credentials.get('aws_access_key_id'),
+        aws_secret_access_key=aws_credentials.get('aws_secret_access_key'),
+        region_name=aws_credentials.get('aws_region', 'us-east-1')
+    )
+
+    s3_key = f"{s3_prefix}/{tar_name}"
+    print(f"Uploading to s3://{s3_bucket}/{s3_key}")
+
+    s3_client.upload_file(tar_path, s3_bucket, s3_key)
+
+    # Cleanup
+    os.remove(tar_path)
+
+    return f"s3://{s3_bucket}/{s3_key}"
+
+
 # DDI severity labels
 DDI_SEVERITY = {
     1: "minor",
@@ -229,10 +265,31 @@ def train_llm_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
 
         print(f"Training complete! Loss: {metrics['train_loss']:.4f}")
 
+        # Save LoRA adapter to S3 if credentials provided
+        s3_uri = None
+        s3_bucket = job_input.get('s3_bucket')
+        if s3_bucket:
+            save_dir = os.path.join(work_dir, 'lora_adapter')
+            print(f"Saving LoRA adapter to {save_dir}...")
+            model.save_pretrained(save_dir)
+            tokenizer.save_pretrained(save_dir)
+
+            aws_creds = {
+                'aws_access_key_id': job_input.get('aws_access_key_id'),
+                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
+                'aws_region': job_input.get('aws_region', 'us-east-1'),
+            }
+            model_short = model_name.split('/')[-1]
+            s3_prefix = job_input.get('s3_prefix', f'ddi-models/lora-{model_short}')
+            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
+            metrics['s3_uri'] = s3_uri
+            print(f"LoRA adapter uploaded to {s3_uri}")
+
         return {
             'status': 'success',
             'metrics': metrics,
-            'message': f'Gemma 3 fine-tuned on {len(raw_data):,} real DrugBank DDI samples'
+            'model_uri': s3_uri,
+            'message': f'LLM fine-tuned on {len(raw_data):,} real DrugBank DDI samples'
         }
 
     except Exception as e:
@@ -357,7 +414,26 @@ def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
             'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
         })
 
-        return {'status': 'success', 'metrics': metrics, 'message': 'BERT classifier trained on DrugBank data'}
+        # Save model to S3 if credentials provided
+        s3_uri = None
+        s3_bucket = job_input.get('s3_bucket')
+        if s3_bucket:
+            save_dir = os.path.join(work_dir, 'saved_model')
+            print(f"Saving model to {save_dir}...")
+            trainer.save_model(save_dir)
+            tokenizer.save_pretrained(save_dir)
+
+            aws_creds = {
+                'aws_access_key_id': job_input.get('aws_access_key_id'),
+                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
+                'aws_region': job_input.get('aws_region', 'us-east-1'),
+            }
+            s3_prefix = job_input.get('s3_prefix', 'ddi-models/bert')
+            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
+            metrics['s3_uri'] = s3_uri
+            print(f"Model uploaded to {s3_uri}")
+
+        return {'status': 'success', 'metrics': metrics, 'model_uri': s3_uri, 'message': 'BERT classifier trained on DrugBank data'}
 
     except Exception as e:
         import traceback
diff --git a/components/runpod_trainer/runpod.toml b/components/runpod_trainer/runpod.toml
new file mode 100644
index 0000000..e9a48ef
--- /dev/null
+++ b/components/runpod_trainer/runpod.toml
@@ -0,0 +1,12 @@
+[project]
+name = "ddi-trainer"
+base_image = "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04"
+gpu_types = ["NVIDIA RTX A4000", "NVIDIA RTX A5000", "NVIDIA RTX A6000", "NVIDIA GeForce RTX 4090"]
+gpu_count = 1
+volume_mount_path = "/runpod-volume"
+
+[project.env_vars]
+
+[runtime]
+handler_path = "handler.py"
+requirements_path = "requirements.txt"
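
Usage note: the new keys travel inside the job's regular "input" object, the
standard RunPod serverless request shape. A hypothetical request body follows;
the s3_*/aws_* keys match the handler's job_input.get() calls, the values are
placeholders, and whatever training parameters the endpoint already accepts go
alongside them:

    {
      "input": {
        "s3_bucket": "my-ddi-models",
        "s3_prefix": "ddi-models/bert",
        "aws_access_key_id": "AKIA...placeholder",
        "aws_secret_access_key": "placeholder",
        "aws_region": "us-east-1"
      }
    }

When s3_bucket is omitted, both handlers skip the save/upload branch and
return model_uri as null, so existing callers are unaffected.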
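Downstream consumers can recover the artifact with boto3 plus tarfile. A
minimal sketch, assuming default-chain AWS credentials and an example bucket
and key (the model_<timestamp>.tar.gz name and the "model" arcname come from
upload_to_s3 above):

    import tarfile
    import boto3

    # Example values; upload_to_s3 returns (and logs) the real s3:// URI.
    bucket = "my-ddi-models"
    key = "ddi-models/bert/model_20260203_151321.tar.gz"

    s3 = boto3.client("s3")  # default credential chain, no explicit keys
    s3.download_file(bucket, key, "model.tar.gz")

    with tarfile.open("model.tar.gz", "r:gz") as tar:
        tar.extractall("unpacked")  # weights land in unpacked/model/

    # The directory is then loadable with, e.g.,
    # AutoModelForSequenceClassification.from_pretrained("unpacked/model")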