Implementation:Fastai Fastbook Waterfall Ensemble



Knowledge Sources
Domains: Model Interpretation, Ensemble Methods
Last Updated: 2026-02-09 17:00 GMT

Overview

Concrete tools for visualizing per-feature prediction contributions via waterfall charts and combining random forest and neural network predictions into an ensemble, provided by waterfallcharts, treeinterpreter, fastai, and scikit-learn.

Description

This implementation covers two complementary techniques from the final sections of the fastbook Tabular Modeling chapter:

  1. Waterfall charts: The waterfall function from the waterfallcharts library (imported as waterfall_chart.plot) renders an additive decomposition of a single prediction. treeinterpreter splits the prediction into a bias term (the mean of the dependent variable, i.e. the prediction at the root of every tree) plus one contribution per feature. The chart draws each contribution as a positive (green) or negative (red) bar segment and groups contributions that fall under a threshold into a single bar. This makes it visually clear why the model predicted a particular value for a specific row.
  2. Ensemble averaging: The chapter's final technique is to average the predictions of a trained random forest and a trained neural network. Because the two models use fundamentally different algorithms, their errors are only partially correlated, so the average can outperform either model individually, as it does in the chapter's experiment. This is implemented as a single line of arithmetic after aligning the output shapes of the two models.
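The arithmetic behind a waterfall chart can be sketched in plain NumPy. This is a minimal illustration, not the library's implementation: the feature names, bias, and contribution values are made up, and group_small with its absolute cutoff is a hypothetical helper (the library's own grouping rule may differ).

```python
import numpy as np

# Hypothetical values for one row; in practice they come from treeinterpreter.
features = np.array(["YearMade", "ProductSize", "Coupler_System",
                     "fiProductClassDesc", "ModelID"])
bias = 10.10
contributions = np.array([0.25, -0.30, 0.02, -0.05, -0.01])

def group_small(names, values, cutoff):
    """Merge contributions with absolute value below `cutoff` into one 'other' entry."""
    small = np.abs(values) < cutoff
    if not small.any():
        return names, values
    kept_names = np.append(names[~small], "other")
    kept_values = np.append(values[~small], values[small].sum())
    return kept_names, kept_values

names, values = group_small(features, contributions, cutoff=0.04)

# Each bar starts where the previous one ended; the first bar starts at zero.
running = np.cumsum(values)
starts = np.concatenate([[0.0], running[:-1]])

for name, start, end in zip(names, starts, running):
    direction = "up" if end >= start else "down"
    print(f"{name:>20}: {start:+.3f} -> {end:+.3f} ({direction})")

net = running[-1]          # sum of all contributions
prediction = bias + net    # additive decomposition: bias + contributions
print(f"net contribution: {net:+.3f}, prediction: {prediction:.3f}")
```

Grouping does not change the total: the "other" entry carries the sum of the merged contributions, so bias plus net still equals the prediction.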

Usage

Use waterfall charts in production or during model review to explain individual predictions to stakeholders. Use ensemble averaging as the final step after training both a random forest and a neural network on the same dataset; it can improve accuracy with minimal additional effort, provided both sets of predictions cover the same rows in the same order.
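One way to see why averaging can help is the variance of the averaged error. As a sketch under a simplifying assumption (both models' errors have zero mean), let e_rf and e_nn be the two models' errors, with standard deviations sigma_rf and sigma_nn and correlation rho. These symbols are introduced here for illustration only.

```latex
\operatorname{Var}\!\left(\frac{e_{\mathrm{rf}} + e_{\mathrm{nn}}}{2}\right)
  = \frac{\sigma_{\mathrm{rf}}^{2} + \sigma_{\mathrm{nn}}^{2}
          + 2\rho\,\sigma_{\mathrm{rf}}\,\sigma_{\mathrm{nn}}}{4}
```

When the two standard deviations are equal to a common value sigma, this reduces to sigma^2 (1 + rho) / 2, which is below sigma^2 whenever rho < 1. If one model is much more accurate than the other, the average can be worse than the better model, so the comparison on the validation set is still needed.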

Code Reference

Source Location

  • Repository: fastbook
  • File: translations/cn/09_tabular.md (Lines 1085-1087 for waterfall; Lines 1397-1414 for ensembling)
  • Note: waterfallcharts is an external library (its importable module is waterfall_chart). Ensemble averaging uses standard NumPy/PyTorch operations.

Signature

# Waterfall chart visualization
# (keyword values shown are the ones used in the chapter's call)
from waterfall_chart import plot as waterfall
waterfall(columns, contributions, threshold=0.08, rotation_value=45,
          formatting='{:,.3f}')

# Ensemble averaging (pattern, not a library function)
rf_preds = m.predict(valid_xs)
nn_preds = to_np(preds.squeeze())
ens_preds = (nn_preds + rf_preds) / 2

Import

from waterfall_chart import plot as waterfall  # installed by the waterfallcharts package
from treeinterpreter import treeinterpreter
import numpy as np
from fastai.tabular.all import *

I/O Contract

Inputs

Name | Type | Required | Description
columns (waterfall) | list of str or pandas.Index | Yes | Feature names corresponding to each contribution value. Typically valid_xs_final.columns.
contributions (waterfall) | numpy.ndarray (n_features,) | Yes | Per-feature contribution values for a single row, obtained from treeinterpreter.predict().
threshold (waterfall) | float | No | Cutoff used to group small contributions into a single "other" bar. The chapter's call uses 0.08.
rotation_value (waterfall) | int | No | Rotation angle for x-axis labels, in degrees. The chapter's call uses 45.
formatting (waterfall) | str | No | Python format string for the value labels. The chapter's call uses '{:,.3f}'.
rf_preds (ensemble) | numpy.ndarray (n_rows,) | Yes | Random forest predictions on the validation/test set. Shape is rank-1.
nn_preds (ensemble) | numpy.ndarray (n_rows,) | Yes | Neural network predictions for the same rows, squeezed from a rank-2 tensor and converted to NumPy.
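The shape alignment matters because NumPy broadcasting does not raise an error when a rank-2 column is added to a rank-1 vector. The following sketch uses small made-up arrays in place of real model outputs to show the pitfall and the fix.

```python
import numpy as np

# Made-up predictions standing in for real model outputs.
rf_preds = np.array([10.1, 9.8, 10.4, 10.0])          # rank-1, shape (4,)
nn_raw = np.array([[10.3], [9.6], [10.2], [10.1]])    # rank-2, shape (4, 1)

# Without squeezing, broadcasting silently yields an (n_rows, n_rows) matrix.
wrong = (nn_raw + rf_preds) / 2
print(wrong.shape)

# Dropping the trailing axis aligns the shapes, giving one value per row.
nn_preds = nn_raw.squeeze(-1)
ens_preds = (nn_preds + rf_preds) / 2
print(ens_preds.shape, ens_preds)
```

Passing the axis explicitly, as in squeeze(-1), also keeps the result rank-1 when there is only a single row, whereas a bare squeeze() would collapse it to a scalar.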

Outputs

Name | Type | Description
Waterfall chart | matplotlib figure | Bar chart with one bar per feature contribution (features along the x-axis), showing how the contributions accumulate. Because only the contributions are passed to waterfall, the final bar, labeled "net", is their sum, i.e. the prediction minus the bias.
ens_preds | numpy.ndarray (n_rows,) | Ensemble predictions computed as the arithmetic mean of the random forest and neural network predictions. In the chapter's experiment this has lower RMSE than either model alone.
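The effect of averaging can be tried out on synthetic data before running it on real models. In this sketch the targets and both "models" are simulated: each model is the true target plus independent Gaussian noise. That is an idealized assumption, since the errors of real models are usually correlated and the gain is then smaller.

```python
import math
import numpy as np

def r_mse(pred, y):
    """Root mean squared error, rounded to six decimals."""
    return round(math.sqrt(((pred - y) ** 2).mean()), 6)

rng = np.random.default_rng(0)
n_rows = 10_000

# Simulated targets and two simulated predictors with independent errors.
y = rng.normal(10.0, 0.7, n_rows)
model_a = y + rng.normal(0.0, 0.23, n_rows)
model_b = y + rng.normal(0.0, 0.22, n_rows)

# Same one-line arithmetic mean as the ensemble pattern.
ens = (model_a + model_b) / 2

rmse_a, rmse_b, rmse_ens = r_mse(model_a, y), r_mse(model_b, y), r_mse(ens, y)
print(f"A: {rmse_a}  B: {rmse_b}  ensemble: {rmse_ens}")
```

With independent errors the ensemble RMSE falls clearly below both individual values. On real data the same three-way comparison shows whether averaging actually helps.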

Usage Examples

Basic Usage

import math
import numpy as np
from fastai.tabular.all import *  # provides to_np
from treeinterpreter import treeinterpreter
from waterfall_chart import plot as waterfall

# --- Prerequisites ---
# m: fitted RandomForestRegressor
# learn: trained fastai tabular Learner
# valid_xs_final: validation features DataFrame
# valid_y: validation targets

# ============================================================
# Part 1: Waterfall Chart for Individual Prediction Explanation
# ============================================================

# Select a few rows from the validation set
row = valid_xs_final.iloc[:5]

# Decompose predictions using treeinterpreter
prediction, bias, contributions = treeinterpreter.predict(m, row.values)

# Verify the decomposition for the first row
print(f"Prediction: {prediction[0]}")
print(f"Bias:       {bias[0]}")
print(f"Bias + Sum: {bias[0] + contributions[0].sum()}")
# These should be equal (within floating point precision)

# Create a waterfall chart for the first row
waterfall(valid_xs_final.columns, contributions[0],
          threshold=0.08, rotation_value=45,
          formatting='{:,.3f}')

# ============================================================
# Part 2: Ensemble Averaging of RF and NN Predictions
# ============================================================

# Get random forest predictions
rf_preds = m.predict(valid_xs_final)

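# Get neural network predictions. get_preds() defaults to the Learner's
# validation set, which must hold the same rows, in the same order, as valid_xs_final.
nn_preds_tensor, targs = learn.get_preds()
nn_preds = to_np(nn_preds_tensor.squeeze())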

# Simple arithmetic average
ens_preds = (nn_preds + rf_preds) / 2

# Helper function
def r_mse(pred, y): return round(math.sqrt(((pred - y) ** 2).mean()), 6)

# Compare all three models (values in comments are from the chapter's run)
print(f"RF RMSE:       {r_mse(rf_preds, valid_y)}")      # ~0.232
print(f"NN RMSE:       {r_mse(nn_preds, valid_y)}")      # ~0.226
print(f"Ensemble RMSE: {r_mse(ens_preds, valid_y)}")     # lowest of the three in that run

Production Waterfall for Multiple Rows

# Generate waterfall charts for multiple predictions (e.g., for a report)
import matplotlib.pyplot as plt

for i in range(3):
    plt.figure(figsize=(10, 5))
    waterfall(valid_xs_final.columns, contributions[i],
              threshold=0.06, rotation_value=45,
              formatting='{:,.3f}')
    plt.title(f"Prediction Explanation for Row {i} "
              f"(predicted={prediction[i][0]:.3f})")
    plt.tight_layout()
    plt.show()

Related Pages

Implements Principle

Requires Environment
