Heuristic:Spotify Luigi Parameter Propagation Decorators
| Knowledge Sources | |
|---|---|
| Domains | Pipeline_Framework, Code_Reuse |
| Last Updated | 2026-02-10 07:00 GMT |
Overview
Code reuse pattern using `@inherits` and `@requires` decorators to automatically propagate parameters between dependent tasks.
Description
When building Luigi pipelines, downstream tasks often need the same parameters as upstream tasks (e.g., `date`, `environment`). Instead of manually redeclaring these parameters on every task, Luigi provides the `@inherits` and `@requires` decorators in `luigi.util`. `@inherits(TaskA)` copies every parameter of TaskA onto the decorated class, skipping any name the class already defines. `@requires(TaskA)` does the same and also generates a `requires()` method that returns an instance of TaskA with matching parameter values. That instance is built with `clone()`, which passes along only the parameters the two tasks have in common.
Usage
Use these decorators when you have a chain of tasks that share parameters. The `@requires` decorator is the most convenient for simple linear chains. Use `@inherits` when you need custom `requires()` logic but still want parameter propagation.
The Insight (Rule of Thumb)
- Action: Use `@requires(UpstreamTask)` for simple linear pipelines. Use `@inherits(UpstreamTask)` when you need custom `requires()` logic. Use `self.clone(TaskClass)` to create instances with matching parameters.
- Value: Eliminates parameter duplication across task chains. Adding a new parameter to an upstream task automatically propagates it downstream.
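- Trade-off: All parameters of the inherited task become parameters of the decorated task, including ones the decorated task never uses itself. Each of them also appears as a CLI argument of the decorated task, which can be unexpected. Carefully consider which tasks to inherit from.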
Reasoning
Without these decorators, a pipeline with 5 tasks sharing a `date` parameter would need `date = luigi.DateParameter()` declared on all 5 classes. If a new shared parameter is added, all 5 classes need updating. With `@inherits`, the parameter is declared once and propagated automatically.
`clone()` passes along only the parameters that both tasks declare. The `common_params()` utility function in `luigi/util.py` applies the same rule: it collects the values of parameters present on both the task instance and the target task class, which prevents errors from mismatched parameter sets:

    def common_params(task_instance, task_cls):
        task_instance_param_names = dict(task_instance.get_params()).keys()
        task_cls_param_names = dict(task_cls.get_params()).keys()
        common_param_names = set(task_instance_param_names).intersection(
            set(task_cls_param_names))
        # ...
        return vals

Code Evidence
The `@inherits` decorator from `luigi/util.py:247-324`:

    class inherits:
        """
        Task inheritance.
        *New after Luigi 2.7.6:* multiple arguments support.
        """
        def __call__(self, task_that_inherits):
            task_iterator = self.tasks_to_inherit or self.kw_tasks_to_inherit.values()
            for task_to_inherit in task_iterator:
                for param_name, param_obj in task_to_inherit.get_params():
                    if not hasattr(task_that_inherits, param_name):
                        setattr(task_that_inherits, param_name, param_obj)
            if self.tasks_to_inherit:
                def clone_parent(_self, **kwargs):
                    return _self.clone(cls=self.tasks_to_inherit[0], **kwargs)
                task_that_inherits.clone_parent = clone_parent

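
The `hasattr` guard in the excerpt above means a parameter the decorated class already declares is left untouched. The following Luigi-free sketch models just that copying step. `Param`, `get_params`, and `inherits_sketch` are hypothetical stand-ins written for this illustration, not Luigi APIs.

```python
class Param:
    """Hypothetical stand-in for a luigi.Parameter declaration."""

    def __init__(self, default=None):
        self.default = default


def get_params(cls):
    """Collect (name, Param) pairs declared on a class, loosely like Task.get_params()."""
    return [(name, getattr(cls, name)) for name in dir(cls)
            if isinstance(getattr(cls, name), Param)]


def inherits_sketch(*parents):
    """Simplified model of the copying loop in luigi.util.inherits (assumption: positional tasks only)."""
    def decorate(cls):
        for parent in parents:
            for name, param in get_params(parent):
                # Same guard as the excerpt: never overwrite an existing attribute.
                if not hasattr(cls, name):
                    setattr(cls, name, param)
        return cls
    return decorate


class Upstream:
    date = Param()
    environment = Param(default="dev")


@inherits_sketch(Upstream)
class Downstream:
    environment = Param(default="prod")  # local declaration is kept


inherited_names = sorted(name for name, _ in get_params(Downstream))
```

Here `Downstream` gains `date` from `Upstream` but keeps its own `environment` declaration with the `"prod"` default.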
The `@requires` decorator from `luigi/util.py:327-351`:

    class requires:
        """
        Same as inherits, but also auto-defines the requires method.
        """
        def __call__(self, task_that_requires):
            task_that_requires = inherits(
                *self.tasks_to_require,
                **self.kw_tasks_to_require
            )(task_that_requires)
            def requires(_self):
                return (_self.clone_parent()
                        if len(self.tasks_to_require) == 1
                        else _self.clone_parents())
            task_that_requires.requires = requires
            return task_that_requires

The `common_params` utility from `luigi/util.py:230-244`:

    def common_params(task_instance, task_cls):
        """Grab all the values in task_instance that are found in task_cls."""
        task_instance_param_names = dict(task_instance.get_params()).keys()
        # Keep the dict itself: it is needed below to look up parameter objects.
        task_cls_params_dict = dict(task_cls.get_params())
        task_cls_param_names = task_cls_params_dict.keys()
        common_param_names = set(task_instance_param_names).intersection(
            set(task_cls_param_names))
        common_param_vals = [(key, task_cls_params_dict[key])
                             for key in common_param_names]
        common_kwargs = dict((key, task_instance.param_kwargs[key])
                             for key in common_param_names)
        vals = dict(task_instance.get_param_values(
            common_param_vals, [], common_kwargs))
        return vals
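
The core of `common_params` is a set intersection over parameter names. A minimal sketch of that step, using plain dictionaries and lists in place of Luigi tasks; the helper `common_param_values` and the `report_format` parameter are hypothetical names chosen for the example:

```python
def common_param_values(instance_values, target_param_names):
    """Keep only the values whose names the target class also declares."""
    common_names = set(instance_values).intersection(target_param_names)
    return {name: instance_values[name] for name in common_names}


# Values held by a downstream task instance; `report_format` exists only downstream.
downstream_values = {
    "date": "2026-02-10",
    "environment": "prod",
    "report_format": "csv",
}
# Parameter names declared by the upstream task class.
upstream_param_names = ["date", "environment"]

shared = common_param_values(downstream_values, upstream_param_names)
```

Only `date` and `environment` survive, so the upstream task is never handed an argument it does not declare.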