
Implementation:Snorkel team Snorkel MultitaskClassifier Score

From Leeroopedia
Knowledge Sources
Domains Evaluation, Multi_Task_Learning, PyTorch
Last Updated 2026-02-14 20:00 GMT

Overview

A concrete tool from the Snorkel library for evaluating multi-task classifiers with per-task metrics and for generating predictions.

Description

The MultitaskClassifier.score() method evaluates the model across all tasks, datasets, and splits, returning structured metrics. MultitaskClassifier.predict() generates probability distributions and optional hard predictions per task.

Additionally, cross_entropy_with_probs() extends standard cross-entropy to support probabilistic (soft) target labels, enabling noise-aware training with labels from the weak supervision pipeline.
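Conceptually, the soft-target loss for one example reduces to -Σ_k p_k · log softmax(z)_k, where p is the target distribution and z the logits. A minimal pure-Python sketch of that formula (not the library implementation, which operates on batched torch tensors):

```python
import math

def soft_cross_entropy(logits, target_probs):
    """Cross-entropy between logits and a soft target distribution
    for a single example: -sum_k p_k * log softmax(z)_k."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    log_softmax = [z - log_z for z in logits]
    return -sum(p * ls for p, ls in zip(target_probs, log_softmax))

# A one-hot target recovers standard cross-entropy on the true class:
hard = soft_cross_entropy([2.0, 0.5, -1.0], [1.0, 0.0, 0.0])
# A noisy soft target spreads the loss across classes:
soft = soft_cross_entropy([2.0, 0.5, -1.0], [0.7, 0.2, 0.1])
```

With a one-hot target this collapses to the usual negative log-probability of the true class, which is why the same loss function works for both hard and label-model-produced soft labels.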

Usage

Call score() and predict() after training to evaluate model performance and generate predictions. Use cross_entropy_with_probs as the task loss when training on soft labels produced by the label model.

Code Reference

Source Location

  • Repository: snorkel
  • File: snorkel/classification/multitask_classifier.py (predict L317-380, score L382-456), snorkel/classification/loss.py (cross_entropy_with_probs L9-66)

Signature

class MultitaskClassifier(nn.Module):
    def score(
        self,
        dataloaders: List[DictDataLoader],
        remap_labels: Dict[str, Optional[str]] = {},
        as_dataframe: bool = False,
    ) -> Union[Dict[str, float], pd.DataFrame]:
        """
        Compute per-task metrics.

        Args:
            dataloaders: DictDataLoaders to evaluate.
            remap_labels: Dict mapping label names to task names for evaluation.
            as_dataframe: Return as DataFrame.
        Returns:
            Dict ("task/dataset/split/metric" -> float) or DataFrame.
        """

    def predict(
        self,
        dataloader: DictDataLoader,
        return_preds: bool = False,
        remap_labels: Dict[str, Optional[str]] = {},
    ) -> Dict[str, Dict[str, torch.Tensor]]:
        """
        Generate predictions.

        Args:
            dataloader: Data to predict on.
            return_preds: Include hard predictions.
            remap_labels: Label remapping.
        Returns:
            Dict per task with 'golds', 'probs', optionally 'preds'.
        """

def cross_entropy_with_probs(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    reduction: str = "mean",
) -> Tensor:
    """
    Cross-entropy supporting probabilistic (soft) targets.

    Args:
        input: [n, k] logits.
        target: [n, k] soft targets (probabilities).
        weight: Optional class weights.
        reduction: "mean", "sum", or "none".
    Returns:
        Loss tensor.
    """

Import

from snorkel.classification import MultitaskClassifier
from snorkel.classification.loss import cross_entropy_with_probs

I/O Contract

Inputs

  • dataloaders (List[DictDataLoader], required): data with gold labels
  • remap_labels (Dict[str, Optional[str]], optional): label-to-task remapping
  • as_dataframe (bool, optional, default False): return results as a DataFrame

Outputs

  • score() with as_dataframe=False: Dict[str, float] mapping "task/dataset/split/metric" to a score
  • score() with as_dataframe=True: pd.DataFrame with columns label, dataset, split, metric, score
  • predict(): Dict[str, Dict[str, torch.Tensor]] with per-task 'golds', 'probs', and optionally 'preds'
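The flat "task/dataset/split/metric" keys from score() can be regrouped per task with a simple string split. A sketch using a hypothetical results dict (the real one comes from model.score(...)):

```python
# Hypothetical score() output; key format is "task/dataset/split/metric".
results = {
    "sentiment/imdb/test/accuracy": 0.91,
    "sentiment/imdb/test/f1": 0.89,
    "topic/agnews/test/accuracy": 0.87,
}

# Group metric values by task name.
by_task = {}
for key, value in results.items():
    task, dataset, split, metric = key.split("/")
    by_task.setdefault(task, {})[metric] = value

print(by_task["sentiment"])  # {'accuracy': 0.91, 'f1': 0.89}
```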

Usage Examples

Evaluate Model

# Score across all tasks and splits
results = model.score(
    dataloaders=[test_dl],
    as_dataframe=True,
)
print(results)

# Predict
predictions = model.predict(test_dl, return_preds=True)
for task_name, task_preds in predictions.items():
    print(f"{task_name}: probs shape = {task_preds['probs'].shape}")
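
With return_preds=False only 'probs' is returned; hard predictions can be recovered by taking an argmax over the class axis. A sketch on plain nested lists with made-up probabilities (on the real torch tensors, probs.argmax(dim=1) does the same):

```python
# Hypothetical per-task probabilities, shape [n_examples, n_classes].
probs = [
    [0.1, 0.9],
    [0.8, 0.2],
    [0.4, 0.6],
]

# Argmax over the class axis yields hard class predictions.
preds = [max(range(len(row)), key=row.__getitem__) for row in probs]
print(preds)  # [1, 0, 1]
```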

Training with Soft Labels

from snorkel.classification.loss import cross_entropy_with_probs

# Use soft labels from label model as task loss
task_with_soft_labels = Task(
    name="sentiment",
    module_pool=module_pool,
    op_sequence=op_sequence,
    loss_func=cross_entropy_with_probs,  # Supports soft targets
)

Related Pages

Implements Principle

Requires Environment
