mirror of
https://github.com/ghndrx/kubeflow-pipelines.git
synced 2026-02-10 06:45:13 +00:00
203 lines
5.1 KiB
Python
203 lines
5.1 KiB
Python
"""
Medical Drug Interaction Training Pipeline

This pipeline trains a model to detect drug-drug interactions (DDI)
from clinical documents in CCDA/FHIR formats.
"""
|
|
from kfp import dsl
|
|
from kfp import compiler
|
|
|
|
|
|
@dsl.component(
    packages_to_install=["pandas", "lxml", "fhir.resources"],
    base_image="python:3.11-slim",
)
def preprocess_ccda(
    input_path: str,
    output_path: dsl.OutputPath("Dataset"),
):
    """Emit the CCDA medication payload as JSON at ``output_path``.

    This is a simplified placeholder: ``input_path`` is accepted but not
    read, no XML is parsed, and the medication list written out is empty.

    Args:
        input_path: Location of the CCDA XML documents (currently unused).
        output_path: File path the JSON payload is written to.
    """
    import json
    from lxml import etree

    # HL7 v3 namespace map for CCDA documents; reserved for the real parser.
    ccda_namespaces = {"hl7": "urn:hl7-org:v3"}

    # Extraction is not implemented here (simplified example), so nothing
    # is ever appended to this list.
    extracted = []

    payload = dict(source="ccda", medications=extracted, processed=True)

    with open(output_path, 'w') as handle:
        handle.write(json.dumps(payload))
|
|
|
|
|
|
@dsl.component(
    packages_to_install=["pandas", "fhir.resources"],
    base_image="python:3.11-slim",
)
def preprocess_fhir(
    input_path: str,
    output_path: dsl.OutputPath("Dataset"),
):
    """Emit the FHIR medication payload as JSON at ``output_path``.

    Placeholder implementation: ``input_path`` is not read and the
    medication list in the written payload is always empty.

    Args:
        input_path: Location of the FHIR R4 resources (currently unused).
        output_path: File path the JSON payload is written to.
    """
    import json

    # No FHIR parsing happens yet, so the payload carries an empty list.
    payload = {}
    payload["source"] = "fhir"
    payload["medications"] = []
    payload["processed"] = True

    serialized = json.dumps(payload)
    with open(output_path, 'w') as handle:
        handle.write(serialized)
|
|
|
|
|
|
@dsl.component(
    base_image="python:3.11-slim",
    packages_to_install=["requests"],
)
def normalize_rxnorm(
    input_dataset: dsl.Input[dsl.Dataset],
    output_path: dsl.OutputPath("Dataset"),
):
    """Normalize medication names using RxNorm and write the result as JSON.

    The RxNorm lookup itself is not implemented yet: the input JSON is
    loaded, flagged with ``"normalized": True`` and written back out
    otherwise unchanged.

    Args:
        input_dataset: Upstream dataset artifact; its ``.path`` must point
            at a JSON file. Annotated with the ``dsl.Dataset`` artifact
            class (not the string ``"Dataset"``) so KFP injects an artifact
            object exposing ``.path``.
        output_path: File path the updated JSON document is written to.
    """
    import json

    with open(input_dataset.path, 'r') as f:
        data = json.load(f)

    # TODO: call the RxNorm API here; for now only the marker flag is set.
    data["normalized"] = True

    with open(output_path, 'w') as f:
        json.dump(data, f)
|
|
|
|
|
|
@dsl.component(
    base_image="huggingface/transformers-pytorch-gpu:latest",
    packages_to_install=["datasets", "accelerate", "scikit-learn"],
)
def train_ddi_model(
    train_dataset: dsl.Input[dsl.Dataset],
    model_name: str,
    epochs: int,
    learning_rate: float,
    output_model: dsl.OutputPath("Model"),
):
    """Fine-tune a transformer model for DDI detection.

    Current state is a placeholder: the base model and tokenizer are loaded
    and saved to ``output_model`` without any training step, and
    ``train_dataset`` is not read.

    Args:
        train_dataset: Training dataset artifact (currently unused).
            Annotated with the ``dsl.Dataset`` artifact class rather than
            the string ``"Dataset"`` so KFP recognises it as an artifact
            input.
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the sequence-classification model.
        epochs: Number of training epochs forwarded to ``TrainingArguments``.
        learning_rate: Learning rate forwarded to ``TrainingArguments``.
        output_model: Path used as the training output directory and as the
            destination for the saved model and tokenizer.
    """
    from transformers import (
        AutoTokenizer,
        AutoModelForSequenceClassification,
        TrainingArguments,
        Trainer,
    )

    # Load base model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=5,  # DDI severity levels
    )

    # Training configuration. Built but not yet consumed, because the
    # Trainer is not wired up until dataset loading exists.
    # NOTE(review): newer transformers releases spell this keyword
    # `eval_strategy`; confirm `evaluation_strategy` is still accepted by
    # the version shipped in the `:latest` base image.
    training_args = TrainingArguments(
        output_dir=output_model,
        num_train_epochs=epochs,
        learning_rate=learning_rate,
        per_device_train_batch_size=16,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
    )

    # Train (placeholder - needs actual dataset loading)
    print(f"Training {model_name} for {epochs} epochs")

    # Save model
    model.save_pretrained(output_model)
    tokenizer.save_pretrained(output_model)
|
|
|
|
|
|
@dsl.component(
    base_image="python:3.11-slim",
    packages_to_install=["scikit-learn", "pandas"],
)
def evaluate_model(
    model_path: dsl.Input[dsl.Model],
    test_dataset: dsl.Input[dsl.Dataset],
    metrics_output: dsl.OutputPath("Metrics"),
):
    """Evaluate the trained model and output metrics.

    Placeholder implementation: neither ``model_path`` nor ``test_dataset``
    is read, and every metric is written as ``0.0``.

    Args:
        model_path: Trained model artifact (currently unused). Annotated
            with the ``dsl.Model`` artifact class rather than a string so
            KFP recognises it as an artifact input.
        test_dataset: Evaluation dataset artifact (currently unused),
            annotated with ``dsl.Dataset`` for the same reason.
        metrics_output: File path the metrics JSON is written to.
    """
    import json

    # Zeroed placeholders until real evaluation is implemented.
    metrics = {
        "f1_micro": 0.0,
        "f1_macro": 0.0,
        "precision": 0.0,
        "recall": 0.0,
        "auprc": 0.0,
    }

    with open(metrics_output, 'w') as f:
        json.dump(metrics, f)
|
|
|
|
|
|
@dsl.pipeline(
    description="Train DDI detection model on CCDA/FHIR clinical data",
    name="med-rx-ddi-training",
)
def med_rx_training_pipeline(
    ccda_input_path: str = "s3://minio/data/ccda/",
    fhir_input_path: str = "s3://minio/data/fhir/",
    base_model: str = "emilyalsentzer/Bio_ClinicalBERT",
    epochs: int = 3,
    learning_rate: float = 2e-5,
):
    """Wire the DDI training steps into one pipeline.

    Steps, in order:
        1. Preprocess the CCDA and FHIR sources.
        2. Normalize the CCDA medications via RxNorm.
        3. Train the transformer model on the normalized CCDA data.
        4. Evaluate the trained model and emit metrics.

    The FHIR preprocessing step runs, but its output is not consumed by any
    later step yet. Training and evaluation both read the same normalized
    CCDA dataset.
    """
    # Stage 1: preprocessing for both sources.
    ccda_step = preprocess_ccda(input_path=ccda_input_path)
    preprocess_fhir(input_path=fhir_input_path)

    # Stage 2: RxNorm normalization of the CCDA output.
    normalize_step = normalize_rxnorm(
        input_dataset=ccda_step.outputs["output_path"]
    )
    normalized_data = normalize_step.outputs["output_path"]

    # Stage 3: training (using CCDA for now).
    training_step = train_ddi_model(
        train_dataset=normalized_data,
        model_name=base_model,
        epochs=epochs,
        learning_rate=learning_rate,
    )

    # Stage 4: evaluation against the same normalized dataset.
    evaluate_model(
        model_path=training_step.outputs["output_model"],
        test_dataset=normalized_data,
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Running this file as a script compiles the pipeline definition to a
    # YAML package in the current working directory.
    pipeline_compiler = compiler.Compiler()
    pipeline_compiler.compile(
        package_path="med_rx_training_pipeline.yaml",
        pipeline_func=med_rx_training_pipeline,
    )
    print("Pipeline compiled to med_rx_training_pipeline.yaml")
|