Implementation:Speechbrain Speechbrain Brain Init
| Field | Value |
|---|---|
| Implementation Name | Brain_Init |
| API Signature | Brain.__init__(self, modules=None, opt_class=None, hparams=None, run_opts=None, checkpointer=None) |
| Source File | speechbrain/core.py:L488 (class Brain), L612-637 (__init__) |
| Import | from speechbrain.core import Brain |
| Type | API Doc |
| Related Principle | Principle:Speechbrain_Speechbrain_Brain_Subclass_Initialization |
Description
Brain.__init__() initializes the SpeechBrain experiment manager by setting up module registration, optimizer configuration, hyperparameter access, runtime options (device, precision, distributed training), and checkpoint management. This is the foundation upon which all SpeechBrain training recipes are built. Users subclass Brain and pass their modules and configuration to this constructor.
Inputs
| Parameter | Type | Default | Description |
|---|---|---|---|
| modules | dict | None | Dictionary mapping string names to torch.nn.Module instances. These are stored in a torch.nn.ModuleDict, moved to the specified device, and passed to the optimizer. Accessible via self.modules.name. |
| opt_class | callable | None | Optimizer constructor that accepts a parameter list. Typically a lambda or functools.partial wrapping a PyTorch optimizer class, e.g. lambda params: torch.optim.Adam(params, lr=0.001). Can be None if optimizers are configured in init_optimizers(). |
| hparams | dict | None | Hyperparameters dictionary, typically the output of load_hyperpyyaml(). Stored as a SimpleNamespace accessible via self.hparams with dot notation (e.g., self.hparams.lr). |
| run_opts | dict | None | Runtime options controlling device, precision, distributed training, debug mode, and other execution parameters. See the run options table below. |
| checkpointer | Checkpointer | None | A speechbrain.utils.checkpoints.Checkpointer instance for saving/loading model states, optimizer states, and other recoverables. |
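As a minimal sketch of the hparams behavior described above: Brain stores the hparams dict in a types.SimpleNamespace so it can be read with dot notation. The keys below ("lr", "blank_index") are illustrative example names, not required ones:

```python
from types import SimpleNamespace

# Illustrative only: mimics how Brain.__init__ exposes hparams.
hparams_dict = {"lr": 0.001, "blank_index": 0}
hparams = SimpleNamespace(**hparams_dict)

# Dot-notation access, as in self.hparams.lr inside a Brain subclass:
print(hparams.lr)
print(hparams.blank_index)
```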
Run Options
Run options are resolved with priority: command-line args > hparams > defaults.
| Option | Type | Default | Description |
|---|---|---|---|
| device | str | "cpu" | Computation device (e.g., "cuda:0", "cpu") |
| precision | str | "fp32" | Training precision: "fp32", "fp16", or "bf16" |
| debug | bool | False | Enable debug mode (limited batches/epochs) |
| debug_batches | int | 2 | Number of batches per epoch in debug mode |
| debug_epochs | int | 2 | Number of epochs in debug mode |
| max_grad_norm | float | 5.0 | Maximum gradient norm for gradient clipping |
| nonfinite_patience | int | 3 | Number of non-finite losses tolerated before stopping |
| noprogressbar | bool | False | Disable progress bar display |
| ckpt_interval_minutes | float | 15.0 | Minutes between intra-epoch checkpoint saves |
| compile | bool | False | Enable torch.compile() for modules |
| compile_module_keys | list | None | Specific modules to compile |
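The resolution priority stated above (command-line args > hparams > defaults) can be sketched in plain Python. This is an illustration, not SpeechBrain's actual code; resolve_run_opts and the three-key defaults dict are invented names:

```python
# Minimal sketch of the documented priority: command line > hparams > defaults.
defaults = {"device": "cpu", "precision": "fp32", "debug": False}

def resolve_run_opts(run_opts=None, hparams=None):
    run_opts = run_opts or {}
    hparams = hparams or {}
    resolved = {}
    for key, default in defaults.items():
        if key in run_opts:        # command-line args win
            resolved[key] = run_opts[key]
        elif key in hparams:       # then hparams
            resolved[key] = hparams[key]
        else:                      # then built-in defaults
            resolved[key] = default
    return resolved

opts = resolve_run_opts(run_opts={"device": "cuda:0"}, hparams={"precision": "fp16"})
```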
Outputs
Returns an initialized Brain instance with the following key attributes:
| Attribute | Type | Description |
|---|---|---|
| self.modules | ModuleDict | All registered modules, moved to the specified device |
| self.hparams | SimpleNamespace | Hyperparameters accessible via dot notation |
| self.device | str | The computation device |
| self.checkpointer | Checkpointer | The checkpoint manager |
| self.opt_class | callable | The optimizer constructor |
| self.optimizers_dict | dict | Dictionary of active optimizers (populated during fit) |
| self.step | int | Current step counter within an epoch |
| self.optimizer_step | int | Global optimizer step counter across all epochs |
Initialization Sequence
The __init__ method performs the following steps in order:
- Store optimizer class and checkpointer references
- Resolve run options -- iterate through all default run options, checking command-line args first, then hparams, then defaults
- Python version check -- verify Python version compatibility
- Set up hparams -- convert the hparams dict to a SimpleNamespace for dot-notation access
- Set up modules -- convert the modules dict to a torch.nn.ModuleDict
- Device placement -- move all modules to the specified device
- JIT/compile -- optionally apply torch.jit.script or torch.compile to specified modules
- Mixed precision setup -- configure automatic mixed precision based on the precision setting
- DDP wrapping -- if in distributed mode, wrap modules with DistributedDataParallel
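The steps above can be condensed into an illustrative sketch. This is NOT SpeechBrain's actual code: the version check, JIT/compile, mixed precision, and DDP wrapping are omitted, only the "device" run option is resolved, and init_sketch is an invented name:

```python
import torch
from types import SimpleNamespace

def init_sketch(modules=None, opt_class=None, hparams=None, run_opts=None,
                checkpointer=None):
    state = SimpleNamespace()
    state.opt_class = opt_class                           # store references
    state.checkpointer = checkpointer
    state.device = (run_opts or {}).get("device", "cpu")  # resolve run options
    state.hparams = SimpleNamespace(**(hparams or {}))    # dot-notation hparams
    state.modules = torch.nn.ModuleDict(modules or {})    # register modules
    state.modules.to(state.device)                        # device placement
    state.step = 0                                        # counters start at zero
    state.optimizer_step = 0
    return state

state = init_sketch(modules={"model": torch.nn.Linear(4, 2)}, hparams={"lr": 0.1})
```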
Usage Example
Basic Brain Subclass
```python
import torch
from speechbrain.core import Brain


class SimpleBrain(Brain):
    def compute_forward(self, batch, stage):
        return self.modules.model(batch[0])

    def compute_objectives(self, predictions, batch, stage):
        return torch.nn.functional.l1_loss(predictions, batch[0])


model = torch.nn.Linear(in_features=10, out_features=10)
brain = SimpleBrain(
    modules={"model": model},
    opt_class=lambda x: torch.optim.SGD(x, 0.1),
)
brain.fit(range(1), ([torch.rand(10, 10), torch.rand(10, 10)],))
```
CTC ASR Subclass (from train_with_wav2vec.py)
```python
import speechbrain as sb
from speechbrain.utils.data_utils import undo_padding


class ASR(sb.core.Brain):
    def compute_forward(self, batch, stage):
        """Forward pass: wav2vec2 -> encoder DNN -> CTC linear -> log-softmax."""
        batch = batch.to(self.device)
        wavs, wav_lens = batch.sig

        # Optional augmentation (training only)
        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)

        # Model forward pass
        feats = self.modules.wav2vec2(wavs, wav_lens)
        x = self.modules.enc(feats)
        logits = self.modules.ctc_lin(x)
        p_ctc = self.hparams.log_softmax(logits)

        # Decoding (validation/test only)
        p_tokens = None
        if stage == sb.Stage.VALID:
            p_tokens = sb.decoders.ctc_greedy_decode(
                p_ctc, wav_lens, blank_id=self.hparams.blank_index
            )
        elif stage == sb.Stage.TEST:
            # test_searcher is built in the recipe's main script
            p_tokens = test_searcher(p_ctc, wav_lens)
        return p_ctc, wav_lens, p_tokens

    def compute_objectives(self, predictions, batch, stage):
        """Compute CTC loss and accumulate WER/CER metrics."""
        p_ctc, wav_lens, p_tokens = predictions
        tokens, tokens_lens = batch.tokens
        loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
        if stage != sb.Stage.TRAIN:
            predicted_words = self.tokenizer(p_tokens, task="decode_from_list")
            target_words = undo_padding(tokens, tokens_lens)
            target_words = self.tokenizer(target_words, task="decode_from_list")
            self.wer_metric.append(batch.id, predicted_words, target_words)
            self.cer_metric.append(batch.id, predicted_words, target_words)
        return loss

    def init_optimizers(self):
        """Set up separate optimizers for wav2vec2 and model."""
        if not self.hparams.wav2vec2.freeze:
            self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
                self.modules.wav2vec2.parameters()
            )
        self.model_optimizer = self.hparams.model_opt_class(
            self.hparams.model.parameters()
        )
```
Instantiation from YAML Configuration
```python
# In the main training script:
asr_brain = ASR(
    modules=hparams["modules"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)
asr_brain.tokenizer = tokenizer
```
YAML Configuration for Modules
The modules dict is typically defined in the YAML configuration:
```yaml
modules:
    wav2vec2: !ref <wav2vec2>
    enc: !ref <enc>
    ctc_lin: !ref <ctc_lin>
```
After load_hyperpyyaml(), hparams["modules"] contains a dict of instantiated nn.Module objects ready to be passed to Brain.__init__().
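For illustration, here is a rough plain-Python stand-in for what load_hyperpyyaml() produces from the snippet above. The module classes and dimensions below are made up (real recipes instantiate SpeechBrain lobes from the YAML); only the dict-of-nn.Module shape matters:

```python
import torch

# Invented stand-ins for the YAML-referenced modules:
wav2vec2 = torch.nn.Identity()      # stands in for the wav2vec2 lobe
enc = torch.nn.Linear(768, 256)
ctc_lin = torch.nn.Linear(256, 30)

hparams = {"modules": {"wav2vec2": wav2vec2, "enc": enc, "ctc_lin": ctc_lin}}
# hparams["modules"] can be passed directly as Brain's `modules` argument.
```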
Dependencies
- torch.nn.ModuleDict -- for module registration and parameter tracking
- speechbrain.utils.checkpoints.Checkpointer -- for checkpoint management
- speechbrain.utils.distributed -- for DDP setup in multi-GPU environments
- torch.compile (optional) -- for module compilation optimization