
Implementation:Speechbrain Speechbrain Brain Init

Implementation Name: Brain_Init
API Signature: Brain.__init__(self, modules=None, opt_class=None, hparams=None, run_opts=None, checkpointer=None)
Source File: speechbrain/core.py:L488 (class Brain), L612-637 (__init__)
Import: from speechbrain.core import Brain
Type: API Doc
Related Principle: Principle:Speechbrain_Speechbrain_Brain_Subclass_Initialization

Description

Brain.__init__() initializes the SpeechBrain experiment manager by setting up module registration, optimizer configuration, hyperparameter access, runtime options (device, precision, distributed training), and checkpoint management. This is the foundation upon which all SpeechBrain training recipes are built. Users subclass Brain and pass their modules and configuration to this constructor.

Inputs

modules (dict, default None): Dictionary mapping string names to torch.nn.Module instances. These are stored in a torch.nn.ModuleDict, moved to the specified device, and passed to the optimizer. Accessible via self.modules.<name>.
opt_class (callable, default None): Optimizer constructor that accepts a parameter list. Typically a lambda or functools.partial wrapping a PyTorch optimizer class, e.g., lambda params: torch.optim.Adam(params, lr=0.001). Can be None if optimizers are configured in init_optimizers().
hparams (dict, default None): Hyperparameters dictionary, typically the output of load_hyperpyyaml(). Stored as a SimpleNamespace accessible via self.hparams with dot notation (e.g., self.hparams.lr).
run_opts (dict, default None): Runtime options controlling device, precision, distributed training, debug mode, and other execution parameters. See the Run Options section below.
checkpointer (Checkpointer, default None): A speechbrain.utils.checkpoints.Checkpointer instance for saving/loading model states, optimizer states, and other recoverables.
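
Both opt_class forms mentioned above are interchangeable; a minimal illustration:

import functools
import torch

# Both defer optimizer construction until Brain.__init__ can supply
# the module parameters; only the lr keyword is fixed up front.
opt_class = lambda params: torch.optim.Adam(params, lr=0.001)
opt_class = functools.partial(torch.optim.Adam, lr=0.001)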

Run Options

Run options are resolved with priority: command-line args > hparams > defaults.

device (str, default "cpu"): Computation device (e.g., "cuda:0", "cpu")
precision (str, default "fp32"): Training precision: "fp32", "fp16", or "bf16"
debug (bool, default False): Enable debug mode (limited batches/epochs)
debug_batches (int, default 2): Number of batches per epoch in debug mode
debug_epochs (int, default 2): Number of epochs in debug mode
max_grad_norm (float, default 5.0): Maximum gradient norm for gradient clipping
nonfinite_patience (int, default 3): Number of non-finite losses tolerated before stopping
noprogressbar (bool, default False): Disable progress bar display
ckpt_interval_minutes (float, default 15.0): Minutes between intra-epoch checkpoint saves
compile (bool, default False): Enable torch.compile() for modules
compile_module_keys (list, default None): Specific modules to compile
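
These options are passed to the constructor as a plain dict; a minimal sketch, reusing the SimpleBrain class and model defined in the Usage Example below:

run_opts = {
    "device": "cuda:0",
    "precision": "fp16",
    "debug": True,  # limits training to debug_batches / debug_epochs
}
brain = SimpleBrain(
    modules={"model": model},
    opt_class=lambda params: torch.optim.SGD(params, 0.1),
    run_opts=run_opts,
)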

Outputs

Returns an initialized Brain instance with the following key attributes:

self.modules (ModuleDict): All registered modules, moved to the specified device
self.hparams (SimpleNamespace): Hyperparameters accessible via dot notation
self.device (str): The computation device
self.checkpointer (Checkpointer): The checkpoint manager
self.opt_class (callable): The optimizer constructor
self.optimizers_dict (dict): Dictionary of active optimizers (populated during fit())
self.step (int): Current step counter within an epoch
self.optimizer_step (int): Global optimizer step counter across all epochs
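
These attributes are available immediately after construction; for example, with the SimpleBrain class from the Usage Example below:

brain = SimpleBrain(
    modules={"model": torch.nn.Linear(10, 10)},
    opt_class=lambda params: torch.optim.SGD(params, 0.1),
    hparams={"lr": 0.1},
)
print(type(brain.modules).__name__)      # ModuleDict
print(brain.hparams.lr)                  # 0.1, via dot notation
print(brain.device)                      # "cpu" by default
print(brain.step, brain.optimizer_step)  # 0 0 before training starts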

Initialization Sequence

The __init__ method performs the following steps in order:

  1. Store optimizer class and checkpointer references
  2. Resolve run options -- iterate through all default run options, checking command-line args first, then hparams, then defaults (see the sketch after this list)
  3. Python version check -- verify Python version compatibility
  4. Set up hparams -- convert the hparams dict to a SimpleNamespace for dot-notation access
  5. Set up modules -- convert the modules dict to a torch.nn.ModuleDict
  6. Device placement -- move all modules to the specified device
  7. JIT/compile -- optionally apply torch.jit.script or torch.compile to specified modules
  8. Mixed precision setup -- configure automatic mixed precision based on the precision setting
  9. DDP wrapping -- if in distributed mode, wrap modules with DistributedDataParallel
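
A minimal sketch of step 2's resolution order (illustrative only, not SpeechBrain's actual code; DEFAULT_RUN_OPTS and resolve_run_opts are hypothetical names):

# Hypothetical illustration of the documented priority:
# command-line args > hparams > defaults
DEFAULT_RUN_OPTS = {"device": "cpu", "precision": "fp32", "debug": False}

def resolve_run_opts(run_opts, hparams):
    resolved = {}
    for name, default in DEFAULT_RUN_OPTS.items():
        if run_opts is not None and name in run_opts:
            resolved[name] = run_opts[name]  # command-line args win
        elif hparams is not None and name in hparams:
            resolved[name] = hparams[name]   # then hparams
        else:
            resolved[name] = default         # then the built-in default
    return resolved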

Usage Example

Basic Brain Subclass

import torch
from speechbrain.core import Brain

class SimpleBrain(Brain):
    def compute_forward(self, batch, stage):
        return self.modules.model(batch[0])

    def compute_objectives(self, predictions, batch, stage):
        return torch.nn.functional.l1_loss(predictions, batch[0])

model = torch.nn.Linear(in_features=10, out_features=10)
brain = SimpleBrain(
    modules={"model": model},
    opt_class=lambda x: torch.optim.SGD(x, 0.1),
)
brain.fit(range(1), ([torch.rand(10, 10), torch.rand(10, 10)],))
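
In this minimal example, range(1) serves as the epoch counter and the single (input, target) pair as a one-batch training set; fit() then drives compute_forward(), compute_objectives(), and the optimizer built by opt_class.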

CTC ASR Subclass (from train_with_wav2vec.py)

import speechbrain as sb
from speechbrain.utils.data_utils import undo_padding

class ASR(sb.core.Brain):
    def compute_forward(self, batch, stage):
        """Forward pass: wav2vec2 -> encoder DNN -> CTC linear -> log-softmax."""
        batch = batch.to(self.device)
        wavs, wav_lens = batch.sig

        # Optional augmentation (training only)
        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)

        # Model forward pass
        feats = self.modules.wav2vec2(wavs, wav_lens)
        x = self.modules.enc(feats)
        logits = self.modules.ctc_lin(x)
        p_ctc = self.hparams.log_softmax(logits)

        # Decoding (validation/test only)
        p_tokens = None
        if stage == sb.Stage.VALID:
            p_tokens = sb.decoders.ctc_greedy_decode(
                p_ctc, wav_lens, blank_id=self.hparams.blank_index
            )
        elif stage == sb.Stage.TEST:
            # test_searcher is defined outside this class (e.g., a beam
            # searcher built in the recipe's main script)
            p_tokens = test_searcher(p_ctc, wav_lens)

        return p_ctc, wav_lens, p_tokens

    def compute_objectives(self, predictions, batch, stage):
        """Compute CTC loss and accumulate WER/CER metrics."""
        p_ctc, wav_lens, p_tokens = predictions
        tokens, tokens_lens = batch.tokens

        loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)

        if stage != sb.Stage.TRAIN:
            predicted_words = self.tokenizer(p_tokens, task="decode_from_list")
            target_words = undo_padding(tokens, tokens_lens)
            target_words = self.tokenizer(target_words, task="decode_from_list")
            self.wer_metric.append(batch.id, predicted_words, target_words)
            self.cer_metric.append(batch.id, predicted_words, target_words)

        return loss

    def init_optimizers(self):
        """Set up separate optimizers for wav2vec2 and model."""
        if not self.hparams.wav2vec2.freeze:
            self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
                self.modules.wav2vec2.parameters()
            )
        self.model_optimizer = self.hparams.model_opt_class(
            self.hparams.model.parameters()
        )

Instantiation from YAML Configuration

# In the main training script:
asr_brain = ASR(
    modules=hparams["modules"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)
asr_brain.tokenizer = tokenizer
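
Attributes attached after construction, such as tokenizer here, are available as self.tokenizer inside the Brain's methods; compute_objectives() above relies on this for decoding.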

YAML Configuration for Modules

The modules dict is typically defined in the YAML configuration:

modules:
    wav2vec2: !ref <wav2vec2>
    enc: !ref <enc>
    ctc_lin: !ref <ctc_lin>

After load_hyperpyyaml(), hparams["modules"] contains a dict of instantiated nn.Module objects ready to be passed to Brain.__init__().
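
A typical entry-point pattern for producing hparams and run_opts (a sketch of the standard SpeechBrain recipe boilerplate; the YAML path comes from the command line):

import sys
from hyperpyyaml import load_hyperpyyaml
import speechbrain as sb

# Splits argv into the YAML path, run options (device, precision, ...),
# and YAML overrides
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])

with open(hparams_file) as fin:
    hparams = load_hyperpyyaml(fin, overrides)  # instantiates !ref objects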

Dependencies

  • torch.nn.ModuleDict -- for module registration and parameter tracking
  • speechbrain.utils.checkpoints.Checkpointer -- for checkpoint management
  • speechbrain.utils.distributed -- for DDP setup in multi-GPU environments
  • torch.compile (optional) -- for module compilation optimization
