Implementation: Kubeflow Pipelines DSL graph_component
| Sources | Kubeflow Pipelines, KFP Control Flow |
|---|---|
| Domains | ML_Pipelines, Recursion |
| Last Updated | 2026-02-13 |
Overview
Concrete tool for defining recursive sub-pipelines, provided by the KFP DSL's graph_component decorator.
Description
@dsl.graph_component decorates a function as a reusable sub-pipeline that can call itself recursively. Inside the function, tasks are defined and wired just as in a pipeline. Recursive calls are gated by dsl.Condition to ensure termination. Caching must be disabled on tasks inside the loop by setting max_cache_staleness = "P0D"; otherwise the condition task can keep returning the same cached result on every iteration and the loop never terminates.
Usage
Use to implement iterative loops within a pipeline — train-evaluate-repeat cycles, recursive data processing, or any workflow requiring convergence.
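To make the control flow concrete, here is a minimal plain-Python analogue of the pattern: a step runs, a condition is checked, and the function calls itself until the condition fails. All names (step, run_until_empty) are illustrative stand-ins, not KFP APIs.

```python
# Plain-Python sketch of the recursion pattern a graph_component encodes.
# In KFP, `step` would be a component task and the `if` would be a
# dsl.Condition gating the recursive self-call.

def step(state):
    """Stand-in for one pipeline iteration: drain up to 3 items per call."""
    return state[3:]

def run_until_empty(queue, iterations=0):
    queue = step(queue)
    iterations += 1
    if queue:  # plays the role of dsl.Condition
        return run_until_empty(queue, iterations)
    return iterations

n = run_until_empty(list(range(10)))  # 10 items, 3 per call -> 4 iterations
```

The recursion depth is bounded only by the gating condition, which is why the condition (and uncached task outputs feeding it) must eventually change between iterations.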
Code Reference
Source Location: Repository: kubeflow/pipelines, File: samples/core/recursion/recursion.py (L45-56 graph component, L59-70 main pipeline), samples/core/train_until_good/train_until_good.py (L33-64)
Signature:
@dsl.graph_component
def sub_pipeline_func(
    param1: type,
    param2: type,
):
    # Define tasks
    task = some_component(input=param1)
    # Gate the recursive call
    with dsl.Condition(task.output == 'continue'):
        sub_pipeline_func(param1=task.output, param2=param2)
Import: from kfp import dsl
I/O Contract
| Direction | Description |
|---|---|
| Inputs | Function parameters define the graph component's inputs (passed at invocation and during recursive self-call) |
| Outputs | GraphComponent — a sub-DAG invocable within a pipeline; results accessible via .outputs |
Usage Examples
Example 1: Coin flip recursion (recursion.py)
@dsl.graph_component
def flip_component(flip_result):
    print_flip = print_op(flip_result)
    flipA = flip_coin_op().after(print_flip)
    flipA.execution_options.caching_strategy.max_cache_staleness = "P0D"
    with dsl.Condition(flipA.output == 'heads'):
        flip_component(flipA.output)

@dsl.pipeline(name='recursive-loop-pipeline')
def flipcoin():
    first_flip = flip_coin_op()
    first_flip.execution_options.caching_strategy.max_cache_staleness = "P0D"
    flip_loop = flip_component(first_flip.output)
    print_op('cool, it is over.').after(flip_loop)
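The termination behaviour of the coin-flip example can be simulated in plain Python. flip_coin and keep_flipping below are hypothetical stand-ins for the KFP components, not KFP APIs; the recursion stops on the first non-heads flip, exactly as the dsl.Condition gate does.

```python
import random

def flip_coin(rng):
    """Stand-in for flip_coin_op."""
    return rng.choice(['heads', 'tails'])

def keep_flipping(rng, flips=None):
    """Recurse while the flip comes up heads, mirroring dsl.Condition."""
    flips = flips if flips is not None else []
    result = flip_coin(rng)
    flips.append(result)
    if result == 'heads':  # the dsl.Condition gate
        return keep_flipping(rng, flips)
    return flips

# Seeded for reproducibility; every flip but the last is 'heads'.
flips = keep_flipping(random.Random(7))
```

Note that in the real pipeline, without max_cache_staleness = "P0D" the flip task's cached result would be reused on every iteration, so a 'heads' outcome would recurse forever.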
Example 2: Iterative training (train_until_good.py)
@dsl.graph_component
def train_until_low_error(starting_model, training_data, true_values):
    model = xgboost_train_on_csv_op(
        training_data=training_data,
        starting_model=starting_model,
        label_column=0, objective='reg:squarederror', num_iterations=50,
    ).outputs['model']
    predictions = xgboost_predict_on_csv_op(data=training_data, model=model, label_column=0).output
    metrics_task = calculate_regression_metrics_from_csv_op(true_values=true_values, predicted_values=predictions)
    with dsl.Condition(metrics_task.outputs['mean_squared_error'] > 0.01):
        train_until_low_error(starting_model=model, training_data=training_data, true_values=true_values)
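The convergence logic of this example can be sketched in plain Python. mock_train_step and TARGET_MSE below are illustrative assumptions (a training round that halves the error), not KFP APIs; the structure mirrors train_until_low_error: train, measure, and recurse while the error exceeds the threshold.

```python
TARGET_MSE = 0.01  # same threshold as the dsl.Condition in the example

def mock_train_step(mse):
    """Stand-in for a training round: assume each round halves the error."""
    return mse / 2

def train_until_low_error_sketch(mse, rounds=0):
    """Recursive analogue of the graph_component: recurse while error is high."""
    mse = mock_train_step(mse)
    rounds += 1
    if mse > TARGET_MSE:  # the dsl.Condition gate
        return train_until_low_error_sketch(mse, rounds)
    return mse, rounds

# Starting from mse=1.0, halving reaches 1/2**7 = 0.0078125 < 0.01 after 7 rounds.
final_mse, rounds = train_until_low_error_sketch(1.0)
```

In the real pipeline, each recursive call retrains from the previous round's model, so the condition input changes each iteration, which is what guarantees termination (assuming training actually reduces the error).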