mirror of
https://github.com/ghndrx/kubeflow-pipelines.git
synced 2026-02-10 06:45:13 +00:00
Initial Kubeflow GitOps setup with example pipelines
This commit is contained in:
44
pipelines/examples/hello_world.py
Normal file
44
pipelines/examples/hello_world.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""
|
||||
Hello World Pipeline - Basic Kubeflow Pipeline Example
|
||||
"""
|
||||
from kfp import dsl
|
||||
from kfp import compiler
|
||||
|
||||
|
||||
@dsl.component(base_image="python:3.11-slim")
|
||||
def say_hello(name: str) -> str:
|
||||
"""Simple component that returns a greeting."""
|
||||
message = f"Hello, {name}! Welcome to Kubeflow Pipelines."
|
||||
print(message)
|
||||
return message
|
||||
|
||||
|
||||
@dsl.component(base_image="python:3.11-slim")
|
||||
def process_greeting(greeting: str) -> str:
|
||||
"""Process the greeting message."""
|
||||
processed = greeting.upper()
|
||||
print(f"Processed: {processed}")
|
||||
return processed
|
||||
|
||||
|
||||
@dsl.pipeline(
|
||||
name="hello-world-pipeline",
|
||||
description="A simple hello world pipeline to test Kubeflow setup"
|
||||
)
|
||||
def hello_world_pipeline(name: str = "Kubeflow User"):
|
||||
"""
|
||||
Simple pipeline that:
|
||||
1. Generates a greeting
|
||||
2. Processes it
|
||||
"""
|
||||
hello_task = say_hello(name=name)
|
||||
process_task = process_greeting(greeting=hello_task.output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Compile the pipeline
|
||||
compiler.Compiler().compile(
|
||||
pipeline_func=hello_world_pipeline,
|
||||
package_path="hello_world_pipeline.yaml"
|
||||
)
|
||||
print("Pipeline compiled to hello_world_pipeline.yaml")
|
||||
202
pipelines/med_rx_training.py
Normal file
202
pipelines/med_rx_training.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
Medical Drug Interaction Training Pipeline
|
||||
|
||||
This pipeline trains a model to detect drug-drug interactions (DDI)
|
||||
from clinical documents in CCDA/FHIR formats.
|
||||
"""
|
||||
from kfp import dsl
|
||||
from kfp import compiler
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image="python:3.11-slim",
|
||||
packages_to_install=["pandas", "lxml", "fhir.resources"]
|
||||
)
|
||||
def preprocess_ccda(
|
||||
input_path: str,
|
||||
output_path: dsl.OutputPath("Dataset")
|
||||
):
|
||||
"""Parse CCDA XML files and extract medication data."""
|
||||
import json
|
||||
from lxml import etree
|
||||
|
||||
# CCDA namespace
|
||||
NS = {"hl7": "urn:hl7-org:v3"}
|
||||
|
||||
medications = []
|
||||
|
||||
# Parse CCDA and extract medications
|
||||
# (simplified example - full implementation in production)
|
||||
result = {
|
||||
"source": "ccda",
|
||||
"medications": medications,
|
||||
"processed": True
|
||||
}
|
||||
|
||||
with open(output_path, 'w') as f:
|
||||
json.dump(result, f)
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image="python:3.11-slim",
|
||||
packages_to_install=["pandas", "fhir.resources"]
|
||||
)
|
||||
def preprocess_fhir(
|
||||
input_path: str,
|
||||
output_path: dsl.OutputPath("Dataset")
|
||||
):
|
||||
"""Parse FHIR R4 resources and extract medication data."""
|
||||
import json
|
||||
|
||||
medications = []
|
||||
|
||||
result = {
|
||||
"source": "fhir",
|
||||
"medications": medications,
|
||||
"processed": True
|
||||
}
|
||||
|
||||
with open(output_path, 'w') as f:
|
||||
json.dump(result, f)
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image="python:3.11-slim",
|
||||
packages_to_install=["requests"]
|
||||
)
|
||||
def normalize_rxnorm(
|
||||
input_dataset: dsl.Input["Dataset"],
|
||||
output_path: dsl.OutputPath("Dataset")
|
||||
):
|
||||
"""Normalize medication names using RxNorm API."""
|
||||
import json
|
||||
|
||||
with open(input_dataset.path, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Normalize medications via RxNorm
|
||||
# (API call implementation)
|
||||
|
||||
data["normalized"] = True
|
||||
|
||||
with open(output_path, 'w') as f:
|
||||
json.dump(data, f)
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image="huggingface/transformers-pytorch-gpu:latest",
|
||||
packages_to_install=["datasets", "accelerate", "scikit-learn"]
|
||||
)
|
||||
def train_ddi_model(
|
||||
train_dataset: dsl.Input["Dataset"],
|
||||
model_name: str,
|
||||
epochs: int,
|
||||
learning_rate: float,
|
||||
output_model: dsl.OutputPath("Model")
|
||||
):
|
||||
"""Fine-tune a transformer model for DDI detection."""
|
||||
import json
|
||||
import os
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForSequenceClassification,
|
||||
TrainingArguments,
|
||||
Trainer
|
||||
)
|
||||
|
||||
# Load base model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
model_name,
|
||||
num_labels=5 # DDI severity levels
|
||||
)
|
||||
|
||||
# Training configuration
|
||||
training_args = TrainingArguments(
|
||||
output_dir=output_model,
|
||||
num_train_epochs=epochs,
|
||||
learning_rate=learning_rate,
|
||||
per_device_train_batch_size=16,
|
||||
evaluation_strategy="epoch",
|
||||
save_strategy="epoch",
|
||||
load_best_model_at_end=True,
|
||||
)
|
||||
|
||||
# Train (placeholder - needs actual dataset loading)
|
||||
print(f"Training {model_name} for {epochs} epochs")
|
||||
|
||||
# Save model
|
||||
model.save_pretrained(output_model)
|
||||
tokenizer.save_pretrained(output_model)
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image="python:3.11-slim",
|
||||
packages_to_install=["scikit-learn", "pandas"]
|
||||
)
|
||||
def evaluate_model(
|
||||
model_path: dsl.Input["Model"],
|
||||
test_dataset: dsl.Input["Dataset"],
|
||||
metrics_output: dsl.OutputPath("Metrics")
|
||||
):
|
||||
"""Evaluate the trained model and output metrics."""
|
||||
import json
|
||||
|
||||
metrics = {
|
||||
"f1_micro": 0.0,
|
||||
"f1_macro": 0.0,
|
||||
"precision": 0.0,
|
||||
"recall": 0.0,
|
||||
"auprc": 0.0
|
||||
}
|
||||
|
||||
with open(metrics_output, 'w') as f:
|
||||
json.dump(metrics, f)
|
||||
|
||||
|
||||
@dsl.pipeline(
|
||||
name="med-rx-ddi-training",
|
||||
description="Train DDI detection model on CCDA/FHIR clinical data"
|
||||
)
|
||||
def med_rx_training_pipeline(
|
||||
ccda_input_path: str = "s3://minio/data/ccda/",
|
||||
fhir_input_path: str = "s3://minio/data/fhir/",
|
||||
base_model: str = "emilyalsentzer/Bio_ClinicalBERT",
|
||||
epochs: int = 3,
|
||||
learning_rate: float = 2e-5
|
||||
):
|
||||
"""
|
||||
Full DDI training pipeline:
|
||||
1. Preprocess CCDA and FHIR data
|
||||
2. Normalize medications via RxNorm
|
||||
3. Train transformer model
|
||||
4. Evaluate and output metrics
|
||||
"""
|
||||
# Preprocess data sources
|
||||
ccda_task = preprocess_ccda(input_path=ccda_input_path)
|
||||
fhir_task = preprocess_fhir(input_path=fhir_input_path)
|
||||
|
||||
# Normalize CCDA data
|
||||
normalize_ccda = normalize_rxnorm(input_dataset=ccda_task.outputs["output_path"])
|
||||
|
||||
# Train model (using CCDA for now)
|
||||
train_task = train_ddi_model(
|
||||
train_dataset=normalize_ccda.outputs["output_path"],
|
||||
model_name=base_model,
|
||||
epochs=epochs,
|
||||
learning_rate=learning_rate
|
||||
)
|
||||
|
||||
# Evaluate
|
||||
eval_task = evaluate_model(
|
||||
model_path=train_task.outputs["output_model"],
|
||||
test_dataset=normalize_ccda.outputs["output_path"]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
compiler.Compiler().compile(
|
||||
pipeline_func=med_rx_training_pipeline,
|
||||
package_path="med_rx_training_pipeline.yaml"
|
||||
)
|
||||
print("Pipeline compiled to med_rx_training_pipeline.yaml")
|
||||
Reference in New Issue
Block a user