Implementation: Snorkel LabelModel.fit() (Snorkel team)
| Knowledge Sources | |
|---|---|
| Domains | Weak_Supervision, Graphical_Models, PyTorch |
| Last Updated | 2026-02-14 20:00 GMT |
Overview
A concrete tool from the Snorkel library for training a generative label model that learns labeling function (LF) accuracy parameters from a label matrix.
Description
The LabelModel class (extending nn.Module and BaseLabeler) implements a matrix-completion approach to learning the conditional probabilities of each labeling function. Its fit() method estimates these parameters via PyTorch optimization.
Internally, fit() performs:
- Setting up the augmented label matrix and clique tree
- Initializing conditional probability parameters with precision priors
- Running SGD/Adam/Adamax optimization with optional LR scheduling
- Aligning label permutations via the Munkres algorithm
- Logging training progress and metrics
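The first step above, building the augmented label matrix, can be sketched as a one-hot encoding of each LF's votes. This is an illustrative sketch only, assuming the {-1, 0, ..., k-1} label convention described below; Snorkel's internal helper differs in detail:

```python
import numpy as np

def augmented_label_matrix(L, k):
    """One-hot encode an [n, m] label matrix into an [n, m*k] matrix.

    Each LF gets a block of k columns; a vote for class c sets that
    column to 1, and an abstain (-1) leaves the LF's block all zeros.
    Illustrative sketch, not Snorkel's actual implementation.
    """
    n, m = L.shape
    L_aug = np.zeros((n, m * k))
    for j in range(m):
        for c in range(k):
            L_aug[:, j * k + c] = (L[:, j] == c).astype(float)
    return L_aug

# Two data points, two LFs, binary task; LF 1 abstains on the first row
L = np.array([[0, -1], [1, 0]])
print(augmented_label_matrix(L, k=2))
```

The model then learns parameters over statistics of this augmented matrix rather than over the raw votes.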
Usage
Import and use this class when you have a label matrix from LF application and want to train a generative model to combine LF votes. The trained model can then produce probabilistic labels via predict_proba().
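predict_proba() returns an [n, k] matrix of per-class probabilities; hard labels are then an argmax over classes (Snorkel's predict() additionally supports tie-break policies). A sketch of that conversion, using a hypothetical probability matrix in place of real model output:

```python
import numpy as np

# Hypothetical output of label_model.predict_proba(L): one row per
# data point, one column per class, rows summing to 1.
probs = np.array([
    [0.9, 0.1],
    [0.2, 0.8],
    [0.5, 0.5],  # a tie; predict() resolves these via tie_break_policy
])

# np.argmax breaks ties toward the lower class index
preds = probs.argmax(axis=1)
print(preds)
```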
Code Reference
Source Location
- Repository: snorkel
- File: snorkel/labeling/model/label_model.py
- Lines: L89-978 (class), L136-146 (__init__), L812-878 (fit)
Signature
class LabelModel(nn.Module, BaseLabeler):
    def __init__(self, cardinality: int = 2, **kwargs: Any) -> None:
        """
        Args:
            cardinality: Number of label classes (default 2).
            **kwargs: Passed to LabelModelConfig (verbose, device).
        """

    def fit(
        self,
        L_train: np.ndarray,
        Y_dev: Optional[np.ndarray] = None,
        class_balance: Optional[List[float]] = None,
        progress_bar: bool = True,
        **kwargs: Any,
    ) -> None:
        """
        Train label model to estimate mu parameters.

        Args:
            L_train: [n, m] label matrix with values in {-1, 0, ..., k-1}.
            Y_dev: Gold labels for dev set (for class_balance estimation).
            class_balance: Prior class probabilities [p_0, ..., p_{k-1}].
            progress_bar: Display training progress bar.
            **kwargs: Training config overrides:
                n_epochs (int, default 100), lr (float, default 0.01),
                l2 (float, default 0.0), optimizer (str, default "sgd"),
                lr_scheduler (str, default "constant"),
                prec_init (float, default 0.7), seed (int),
                log_freq (int, default 10), mu_eps (Optional[float]).
        """
Import
from snorkel.labeling.model import LabelModel
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| cardinality | int | No | Number of label classes (default 2) |
| L_train | np.ndarray | Yes | Label matrix [n, m] with values in {-1, 0, ..., k-1} |
| Y_dev | Optional[np.ndarray] | No | Dev set gold labels for class balance estimation |
| class_balance | Optional[List[float]] | No | Prior class probabilities |
| n_epochs | int | No | Training epochs (default 100) |
| lr | float | No | Learning rate (default 0.01) |
| l2 | float | No | L2 regularization strength (default 0.0) |
| optimizer | str | No | Optimizer: "sgd", "adam", or "adamax" (default "sgd") |
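When Y_dev is supplied instead of class_balance, the class prior is estimated from dev-label frequencies. A minimal sketch of that estimate (the function name here is hypothetical; Snorkel computes this internally):

```python
import numpy as np

def estimate_class_balance(Y_dev, cardinality):
    """Empirical class frequencies from gold dev labels (illustrative sketch)."""
    counts = np.bincount(Y_dev, minlength=cardinality)
    return counts / counts.sum()

# Five dev points: two of class 0, three of class 1
print(estimate_class_balance(np.array([0, 1, 1, 0, 1]), cardinality=2))
```

If neither Y_dev nor class_balance is given, a uniform prior over the k classes is assumed.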
Outputs
| Name | Type | Description |
|---|---|---|
| Trained model | LabelModel | Model with learned mu parameters (in-place); ready for predict_proba/predict |
| get_weights() | np.ndarray | Learned LF accuracy weights [m] accessible after training |
Usage Examples
Basic Training
import numpy as np
from snorkel.labeling.model import LabelModel
# Label matrix from LF application
L_train = np.array([
[0, 0, -1],
[-1, 0, 1],
[1, -1, 0],
[0, 0, 0],
])
# Initialize and train
label_model = LabelModel(cardinality=2, verbose=True)
label_model.fit(L_train, n_epochs=500, lr=0.01, seed=123)
# Check learned weights
weights = label_model.get_weights()
print(f"LF weights: {weights}")
Training with Dev Set
# With dev labels for better class balance estimation
Y_dev = np.array([0, 1, 1, 0])
label_model = LabelModel(cardinality=2, verbose=False)
label_model.fit(
L_train,
Y_dev=Y_dev,
n_epochs=200,
lr=0.05,
l2=0.1,
optimizer="adam",
seed=42,
)