

Implementation:Snorkel team Snorkel LabelModel Fit

From Leeroopedia
Knowledge Sources
Domains Weak_Supervision, Graphical_Models, PyTorch
Last Updated 2026-02-14 20:00 GMT

Overview

A concrete tool, provided by the Snorkel library, for training a generative label model that learns labeling function (LF) accuracy parameters from a label matrix.

Description

The LabelModel class (which extends nn.Module and BaseLabeler) implements the matrix-completion approach to learning the conditional probabilities of LF outputs given the true label. Its fit() method uses PyTorch optimization to estimate the μ parameters, where an entry of μ is P(λ_j = y' | Y = y), the probability that LF j outputs y' when the true class is y.
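To make the matrix-completion input concrete, here is a minimal numpy sketch of the statistics such a model works from: each LF vote is one-hot encoded into an augmented indicator matrix (abstains become all-zero blocks), and the overlap matrix O is its normalized Gram matrix. Our understanding is that Snorkel builds comparable quantities internally after shifting labels by +1 so that abstain becomes 0; the helpers `augmented_indicator_matrix` and `overlap_matrix` are hypothetical names, not part of Snorkel's API.

```python
import numpy as np


def augmented_indicator_matrix(L: np.ndarray, cardinality: int) -> np.ndarray:
    """Hypothetical helper: one-hot encode LF votes into an [n, m * k] matrix.

    Column j * k + y is 1 where LF j voted for class y; an abstain (-1)
    leaves LF j's whole block at zero.
    """
    n, m = L.shape
    L_ind = np.zeros((n, m * cardinality))
    for j in range(m):
        for y in range(cardinality):
            L_ind[:, j * cardinality + y] = L[:, j] == y
    return L_ind


def overlap_matrix(L_ind: np.ndarray) -> np.ndarray:
    """Hypothetical helper: empirical second moments O = L_ind^T L_ind / n."""
    return L_ind.T @ L_ind / L_ind.shape[0]


# Same toy label matrix as in the usage examples (3 LFs, 4 points)
L_train = np.array([
    [0, 0, -1],
    [-1, 0, 1],
    [1, -1, 0],
    [0, 0, 0],
])
L_ind = augmented_indicator_matrix(L_train, cardinality=2)
O = overlap_matrix(L_ind)
print(L_ind.shape, O.shape)  # (4, 6) (6, 6)
# Diagonal of O = how often each LF votes each class: [0.5, 0.25, 0.75, 0.0, 0.5, 0.25]
```

Entries of O between different LFs capture how often they agree or disagree; off-diagonal entries inside a single LF's block are always zero, because an LF casts at most one vote per point.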

Internally, fit() performs:

  • Setting up the augmented label matrix and clique tree
  • Initializing conditional probability parameters with precision priors
  • Running SGD, Adam, or Adamax optimization with optional learning-rate scheduling
  • Resolving the label-permutation symmetry of the learned parameters via the Munkres (Hungarian) algorithm
  • Logging training progress and metrics
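The permutation-alignment step is needed because the matrix-completion objective cannot tell a solution apart from one whose class columns have been relabeled (for example, every LF looking 80% accurate versus 20% accurate). The numpy sketch below picks the column permutation of μ that maximizes the total probability of LFs voting the true class. It assumes a uniform class prior (so every permutation is admissible) and swaps Munkres for a brute-force search over all k! permutations, which solves the same assignment problem and is adequate for small cardinality; `align_columns` is a hypothetical name.

```python
import itertools

import numpy as np


def align_columns(mu: np.ndarray, p: np.ndarray) -> np.ndarray:
    """Hypothetical helper: return mu with its class columns optimally permuted.

    mu: [m * k, k], mu[j * k + v, y] = P(lf_j = v | Y = y) (one k x k block per LF).
    p:  [k] class prior, assumed uniform so any column permutation is equally valid.
    """
    k = p.shape[0]
    if not np.allclose(p, p[0]):
        raise ValueError("sketch assumes a uniform class prior")
    m = mu.shape[0] // k
    P = np.diag(p)
    # S[v, y] = sum over LFs of P(lf_j = v, Y = y)
    S = sum(mu[j * k:(j + 1) * k] for j in range(m)) @ P
    best_score, best_Z = -np.inf, np.eye(k)
    # Brute force over permutations (Munkres solves this in polynomial time)
    for perm in itertools.permutations(range(k)):
        Z = np.eye(k)[:, list(perm)]  # (mu @ Z)[:, c] == mu[:, perm[c]]
        score = np.trace(S @ Z)  # summed probability that LFs vote the true class
        if score > best_score:
            best_score, best_Z = score, Z
    return mu @ best_Z


# Two LFs whose columns came out flipped: each looks only 20% accurate
flipped = np.array([[0.2, 0.8], [0.8, 0.2]])
mu = np.vstack([flipped, flipped])
aligned = align_columns(mu, p=np.array([0.5, 0.5]))
# Each LF block of `aligned` is now [[0.8, 0.2], [0.2, 0.8]]
```

Restricting the search to permutations keeps the objective value unchanged while choosing the solution that trusts the LFs most.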

Usage

Use this class when you have a label matrix produced by applying LFs to your data and want to train a generative model that combines their votes. The trained model then produces probabilistic labels via predict_proba() (or hard labels via predict()).
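Once μ and the class prior are learned, combining LF votes amounts to a naive-Bayes-style posterior. The numpy sketch below shows that computation under the assumptions that LF votes are conditionally independent given Y and that abstains carry no information. To our reading, predict_proba() does essentially this, but `posterior_from_mu` is a hypothetical helper meant as an illustration, not the library's code.

```python
import numpy as np


def posterior_from_mu(L_ind: np.ndarray, mu: np.ndarray, p: np.ndarray,
                      eps: float = 1e-8) -> np.ndarray:
    """Hypothetical helper: P(Y = y | LF votes) for each row of L_ind.

    L_ind: [n, m * k] one-hot LF votes; an abstaining LF contributes an all-zero block.
    mu:    [m * k, k], mu[j * k + v, y] = P(lf_j = v | Y = y).
    p:     [k] class prior P(Y = y).
    """
    # Clip to avoid log(0); 0 * -inf would otherwise produce NaN
    log_mu = np.log(np.clip(mu, eps, 1.0))
    log_scores = L_ind @ log_mu + np.log(p)  # [n, k] unnormalized log posteriors
    log_scores -= log_scores.max(axis=1, keepdims=True)  # numerical stability
    scores = np.exp(log_scores)
    return scores / scores.sum(axis=1, keepdims=True)


# Two LFs that are each 80% accurate when they vote, uniform prior
block = np.array([[0.8, 0.2], [0.2, 0.8]])
mu = np.vstack([block, block])
L_ind = np.array([
    [1, 0, 1, 0],  # both LFs vote 0
    [1, 0, 0, 0],  # LF 0 votes 0, LF 1 abstains
    [0, 0, 0, 0],  # both abstain
])
probs = posterior_from_mu(L_ind, mu, p=np.array([0.5, 0.5]))
# probs[:, 0] is approximately [16/17, 0.8, 0.5]
```

Working in log space and subtracting the row maximum keeps the exponentials stable even with many LFs.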

Code Reference

Source Location

  • Repository: snorkel
  • File: snorkel/labeling/model/label_model.py
  • Lines: L89-978 (class), L136-146 (__init__), L812-878 (fit)

Signature

class LabelModel(nn.Module, BaseLabeler):
    def __init__(self, cardinality: int = 2, **kwargs: Any) -> None:
        """
        Args:
            cardinality: Number of label classes (default 2).
            **kwargs: Passed to LabelModelConfig (verbose, device).
        """

    def fit(
        self,
        L_train: np.ndarray,
        Y_dev: Optional[np.ndarray] = None,
        class_balance: Optional[List[float]] = None,
        progress_bar: bool = True,
        **kwargs: Any,
    ) -> None:
        """
        Train label model to estimate mu parameters.

        Args:
            L_train: [n, m] label matrix with values in {-1, 0, ..., k-1},
                where -1 means the LF abstained.
            Y_dev: Gold labels for a dev set, used to estimate the class
                balance when class_balance is not given.
            class_balance: Prior class probabilities [p_0, ..., p_{k-1}].
            progress_bar: Display training progress bar.
            **kwargs: Training config overrides:
                n_epochs (int, default 100), lr (float, default 0.01),
                l2 (float, default 0.0), optimizer (str, default "sgd"),
                lr_scheduler (str, default "constant"),
                prec_init (float, default 0.7), seed (int),
                log_freq (int, default 10), mu_eps (Optional[float]).
        """

Import

from snorkel.labeling.model import LabelModel

I/O Contract

Inputs

Name           Type                   Required  Description
cardinality    int                    No        Number of label classes, set in __init__ (default 2)
L_train        np.ndarray             Yes       Label matrix [n, m] with values in {-1, 0, ..., k-1}; -1 = abstain
Y_dev          Optional[np.ndarray]   No        Dev set gold labels for class balance estimation
class_balance  Optional[List[float]]  No        Prior class probabilities
n_epochs       int                    No        Training epochs (default 100)
lr             float                  No        Learning rate (default 0.01)
l2             float                  No        L2 regularization strength (default 0.0)
optimizer      str                    No        Optimizer: "sgd", "adam", or "adamax" (default "sgd")
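The Y_dev and class_balance inputs both feed the class prior P(Y). A reasonable reading of the library's behavior is a simple precedence: an explicit class_balance wins, otherwise the prior is estimated from label frequencies in Y_dev, otherwise it is uniform. The sketch below encodes that precedence; `resolve_class_balance` is a hypothetical name, and raising when Y_dev lacks a class is a choice of this sketch.

```python
from typing import List, Optional

import numpy as np


def resolve_class_balance(
    cardinality: int,
    class_balance: Optional[List[float]] = None,
    Y_dev: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Hypothetical helper: pick the class prior P(Y) used during training."""
    if class_balance is not None:
        p = np.asarray(class_balance, dtype=float)  # 1) explicit prior
    elif Y_dev is not None:
        counts = np.bincount(np.asarray(Y_dev), minlength=cardinality)
        if counts.shape[0] != cardinality or (counts == 0).any():
            raise ValueError("Y_dev must contain every class in 0..cardinality-1")
        p = counts / counts.sum()  # 2) empirical frequencies in the dev set
    else:
        p = np.full(cardinality, 1.0 / cardinality)  # 3) uniform fallback
    if p.shape != (cardinality,):
        raise ValueError(f"class prior must have length {cardinality}")
    return p


print(resolve_class_balance(2, Y_dev=np.array([0, 1, 1, 0])))  # [0.5 0.5]
```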

Outputs

Name            Type        Description
Trained model   LabelModel  Model with learned mu parameters (updated in place; fit() returns None); ready for predict_proba()/predict()
get_weights()   np.ndarray  Learned LF accuracy weights, shape [m], available after training
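The per-LF weights can be read off μ and the class prior. An LF's accuracy when it votes is P(λ_j = Y | λ_j votes) = Σ_y P(λ_j = y | Y = y) P(Y = y) / coverage_j, where coverage_j is the fraction of points on which LF j does not abstain. The sketch below computes that quantity; we believe get_weights() returns something close to it, but the helper name `lf_accuracies` and the clipping bounds are assumptions.

```python
import numpy as np


def lf_accuracies(mu: np.ndarray, p: np.ndarray, coverage: np.ndarray) -> np.ndarray:
    """Hypothetical helper: P(lf_j = Y | lf_j votes) for each LF.

    mu:       [m * k, k], mu[j * k + v, y] = P(lf_j = v | Y = y).
    p:        [k] class prior P(Y = y).
    coverage: [m] fraction of points on which LF j does not abstain.
    """
    k = p.shape[0]
    m = mu.shape[0] // k
    accs = np.empty(m)
    for j in range(m):
        block = mu[j * k:(j + 1) * k]  # this LF's k x k conditional probabilities
        joint_correct = np.sum(np.diag(block) * p)  # sum_y P(lf_j = y, Y = y)
        accs[j] = joint_correct / coverage[j]
    return np.clip(accs, 1e-6, 1.0)  # keep weights strictly positive and at most 1


# LF 0 always votes and is right 80% of the time; LF 1 votes on 70% of points
mu = np.vstack([
    [[0.8, 0.2], [0.2, 0.8]],
    [[0.6, 0.1], [0.2, 0.5]],
])
weights = lf_accuracies(mu, p=np.array([0.5, 0.5]), coverage=np.array([1.0, 0.7]))
# weights is approximately [0.8, 0.7857]
```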

Usage Examples

Basic Training

import numpy as np
from snorkel.labeling.model import LabelModel

# Label matrix from applying 3 LFs to 4 data points (-1 = abstain)
L_train = np.array([
    [0, 0, -1],
    [-1, 0, 1],
    [1, -1, 0],
    [0, 0, 0],
])

# Initialize and train
label_model = LabelModel(cardinality=2, verbose=True)
label_model.fit(L_train, n_epochs=500, lr=0.01, seed=123)

# Check learned weights
weights = label_model.get_weights()
print(f"LF weights: {weights}")

Training with Dev Set

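import numpy as np
from snorkel.labeling.model import LabelModel

# Same label matrix as in the basic example (-1 = abstain)
L_train = np.array([
    [0, 0, -1],
    [-1, 0, 1],
    [1, -1, 0],
    [0, 0, 0],
])

# Gold dev labels, used only to estimate the class-balance prior
Y_dev = np.array([0, 1, 1, 0])

label_model = LabelModel(cardinality=2, verbose=False)
label_model.fit(
    L_train,
    Y_dev=Y_dev,
    n_epochs=200,
    lr=0.05,
    l2=0.1,
    optimizer="adam",
    seed=42,
)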

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
