feat: add S3 model upload support

- Add upload_to_s3 function to handler
- Save trained BERT models to S3 when credentials provided
- Save LoRA adapters to S3 when credentials provided
- Input params: s3_bucket, s3_prefix, aws_access_key_id, aws_secret_access_key, aws_region (optional, defaults to us-east-1); a sample payload is sketched below
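A minimal sketch of a job payload exercising the new parameters (bucket name and credential values are placeholders, not real resources):

    job_input = {
        "s3_bucket": "my-ddi-models",           # destination bucket
        "s3_prefix": "ddi-models/bert",         # optional; per-task defaults apply
        "aws_access_key_id": "AKIA...",         # credentials passed to boto3
        "aws_secret_access_key": "...",
        "aws_region": "us-east-1",              # optional; defaults to us-east-1
    }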

Date: 2026-02-03 15:13:21 +00:00
Parent: c6fd06369f
Commit: 59c808cb3a
2 changed files with 90 additions and 2 deletions


@@ -3,13 +3,49 @@ RunPod Serverless Handler for DDI Model Training
Supports both BERT-style classification and LLM fine-tuning with LoRA.
Uses 176K real DrugBank DDI samples with drug names.
Saves trained models to S3.
"""
import os
import json
import runpod
import boto3
from datetime import datetime
from typing import Dict, Any, List, Optional

def upload_to_s3(local_path: str, s3_bucket: str, s3_prefix: str, aws_credentials: Dict) -> str:
    """Upload a directory to S3 as a tar.gz archive and return the S3 URI."""
    import tarfile
    import tempfile

    # Create tar.gz of the model directory
    timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
    tar_name = f"model_{timestamp}.tar.gz"
    tar_path = os.path.join(tempfile.gettempdir(), tar_name)
    print(f"Creating archive: {tar_path}")
    with tarfile.open(tar_path, "w:gz") as tar:
        tar.add(local_path, arcname="model")

    # Upload to S3
    s3_client = boto3.client(
        's3',
        aws_access_key_id=aws_credentials.get('aws_access_key_id'),
        aws_secret_access_key=aws_credentials.get('aws_secret_access_key'),
        region_name=aws_credentials.get('aws_region', 'us-east-1')
    )
    s3_key = f"{s3_prefix}/{tar_name}"
    print(f"Uploading to s3://{s3_bucket}/{s3_key}")
    s3_client.upload_file(tar_path, s3_bucket, s3_key)

    # Cleanup
    os.remove(tar_path)
    return f"s3://{s3_bucket}/{s3_key}"

# DDI severity labels
DDI_SEVERITY = {
    1: "minor",
@@ -229,10 +265,31 @@ def train_llm_lora(job_input: Dict[str, Any]) -> Dict[str, Any]:
print(f"Training complete! Loss: {metrics['train_loss']:.4f}")
# Save LoRA adapter to S3 if credentials provided
s3_uri = None
s3_bucket = job_input.get('s3_bucket')
if s3_bucket:
save_dir = os.path.join(work_dir, 'lora_adapter')
print(f"Saving LoRA adapter to {save_dir}...")
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)
aws_creds = {
'aws_access_key_id': job_input.get('aws_access_key_id'),
'aws_secret_access_key': job_input.get('aws_secret_access_key'),
'aws_region': job_input.get('aws_region', 'us-east-1'),
}
model_short = model_name.split('/')[-1]
s3_prefix = job_input.get('s3_prefix', f'ddi-models/lora-{model_short}')
s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
metrics['s3_uri'] = s3_uri
print(f"LoRA adapter uploaded to {s3_uri}")
return {
'status': 'success',
'metrics': metrics,
'message': f'Gemma 3 fine-tuned on {len(raw_data):,} real DrugBank DDI samples'
'model_uri': s3_uri,
'message': f'LLM fine-tuned on {len(raw_data):,} real DrugBank DDI samples'
}
    except Exception as e:
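
For reference, a sketch of the consumer side, assuming transformers and peft are available (bucket, key, and base model id are placeholders): download the archive produced by upload_to_s3, unpack it, and attach the adapter to its base model.

    import tarfile
    import boto3
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    s3 = boto3.client('s3')
    s3.download_file('my-bucket', 'ddi-models/lora-my-model/model_20260203_151321.tar.gz',
                     '/tmp/adapter.tar.gz')
    with tarfile.open('/tmp/adapter.tar.gz', 'r:gz') as tar:
        tar.extractall('/tmp/adapter')  # upload_to_s3 stores contents under 'model/'
    base = AutoModelForCausalLM.from_pretrained('base-model-id')  # placeholder id
    model = PeftModel.from_pretrained(base, '/tmp/adapter/model')
    tokenizer = AutoTokenizer.from_pretrained('/tmp/adapter/model')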
@@ -357,7 +414,26 @@ def train_bert_classifier(job_input: Dict[str, Any]) -> Dict[str, Any]:
            'eval_f1_weighted': float(eval_result['eval_f1_weighted']),
        })

        # Save model to S3 if credentials provided
        s3_uri = None
        s3_bucket = job_input.get('s3_bucket')
        if s3_bucket:
            save_dir = os.path.join(work_dir, 'saved_model')
            print(f"Saving model to {save_dir}...")
            trainer.save_model(save_dir)
            tokenizer.save_pretrained(save_dir)

            aws_creds = {
                'aws_access_key_id': job_input.get('aws_access_key_id'),
                'aws_secret_access_key': job_input.get('aws_secret_access_key'),
                'aws_region': job_input.get('aws_region', 'us-east-1'),
            }
            s3_prefix = job_input.get('s3_prefix', 'ddi-models/bert')
            s3_uri = upload_to_s3(save_dir, s3_bucket, s3_prefix, aws_creds)
            metrics['s3_uri'] = s3_uri
            print(f"Model uploaded to {s3_uri}")

        return {'status': 'success', 'metrics': metrics, 'model_uri': s3_uri, 'message': 'BERT classifier trained on DrugBank data'}

    except Exception as e:
        import traceback
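
And the matching consumer sketch for the BERT path (placeholder bucket and key): pull the archive back down and reload the classifier for inference.

    import tarfile
    import boto3
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    s3 = boto3.client('s3')
    s3.download_file('my-bucket', 'ddi-models/bert/model_20260203_151321.tar.gz',
                     '/tmp/bert.tar.gz')
    with tarfile.open('/tmp/bert.tar.gz', 'r:gz') as tar:
        tar.extractall('/tmp/bert')  # archive root is 'model/'
    model = AutoModelForSequenceClassification.from_pretrained('/tmp/bert/model')
    tokenizer = AutoTokenizer.from_pretrained('/tmp/bert/model')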