""" Medical Drug Interaction Training Pipeline This pipeline trains a model to detect drug-drug interactions (DDI) from clinical documents in CCDA/FHIR formats. """ from kfp import dsl from kfp import compiler @dsl.component( base_image="python:3.11-slim", packages_to_install=["pandas", "lxml", "fhir.resources"] ) def preprocess_ccda( input_path: str, output_path: dsl.OutputPath("Dataset") ): """Parse CCDA XML files and extract medication data.""" import json from lxml import etree # CCDA namespace NS = {"hl7": "urn:hl7-org:v3"} medications = [] # Parse CCDA and extract medications # (simplified example - full implementation in production) result = { "source": "ccda", "medications": medications, "processed": True } with open(output_path, 'w') as f: json.dump(result, f) @dsl.component( base_image="python:3.11-slim", packages_to_install=["pandas", "fhir.resources"] ) def preprocess_fhir( input_path: str, output_path: dsl.OutputPath("Dataset") ): """Parse FHIR R4 resources and extract medication data.""" import json medications = [] result = { "source": "fhir", "medications": medications, "processed": True } with open(output_path, 'w') as f: json.dump(result, f) @dsl.component( base_image="python:3.11-slim", packages_to_install=["requests"] ) def normalize_rxnorm( input_dataset: dsl.Input["Dataset"], output_path: dsl.OutputPath("Dataset") ): """Normalize medication names using RxNorm API.""" import json with open(input_dataset.path, 'r') as f: data = json.load(f) # Normalize medications via RxNorm # (API call implementation) data["normalized"] = True with open(output_path, 'w') as f: json.dump(data, f) @dsl.component( base_image="huggingface/transformers-pytorch-gpu:latest", packages_to_install=["datasets", "accelerate", "scikit-learn"] ) def train_ddi_model( train_dataset: dsl.Input["Dataset"], model_name: str, epochs: int, learning_rate: float, output_model: dsl.OutputPath("Model") ): """Fine-tune a transformer model for DDI detection.""" import json import os from transformers import ( AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer ) # Load base model tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForSequenceClassification.from_pretrained( model_name, num_labels=5 # DDI severity levels ) # Training configuration training_args = TrainingArguments( output_dir=output_model, num_train_epochs=epochs, learning_rate=learning_rate, per_device_train_batch_size=16, evaluation_strategy="epoch", save_strategy="epoch", load_best_model_at_end=True, ) # Train (placeholder - needs actual dataset loading) print(f"Training {model_name} for {epochs} epochs") # Save model model.save_pretrained(output_model) tokenizer.save_pretrained(output_model) @dsl.component( base_image="python:3.11-slim", packages_to_install=["scikit-learn", "pandas"] ) def evaluate_model( model_path: dsl.Input["Model"], test_dataset: dsl.Input["Dataset"], metrics_output: dsl.OutputPath("Metrics") ): """Evaluate the trained model and output metrics.""" import json metrics = { "f1_micro": 0.0, "f1_macro": 0.0, "precision": 0.0, "recall": 0.0, "auprc": 0.0 } with open(metrics_output, 'w') as f: json.dump(metrics, f) @dsl.pipeline( name="med-rx-ddi-training", description="Train DDI detection model on CCDA/FHIR clinical data" ) def med_rx_training_pipeline( ccda_input_path: str = "s3://minio/data/ccda/", fhir_input_path: str = "s3://minio/data/fhir/", base_model: str = "emilyalsentzer/Bio_ClinicalBERT", epochs: int = 3, learning_rate: float = 2e-5 ): """ Full DDI training pipeline: 1. Preprocess CCDA and FHIR data 2. Normalize medications via RxNorm 3. Train transformer model 4. Evaluate and output metrics """ # Preprocess data sources ccda_task = preprocess_ccda(input_path=ccda_input_path) fhir_task = preprocess_fhir(input_path=fhir_input_path) # Normalize CCDA data normalize_ccda = normalize_rxnorm(input_dataset=ccda_task.outputs["output_path"]) # Train model (using CCDA for now) train_task = train_ddi_model( train_dataset=normalize_ccda.outputs["output_path"], model_name=base_model, epochs=epochs, learning_rate=learning_rate ) # Evaluate eval_task = evaluate_model( model_path=train_task.outputs["output_model"], test_dataset=normalize_ccda.outputs["output_path"] ) if __name__ == "__main__": compiler.Compiler().compile( pipeline_func=med_rx_training_pipeline, package_path="med_rx_training_pipeline.yaml" ) print("Pipeline compiled to med_rx_training_pipeline.yaml")