Implementation:Speechbrain Speechbrain Brain Evaluate With ErrorRateStats
| Field | Value |
|---|---|
| Implementation Name | Brain_Evaluate_With_ErrorRateStats |
| API Signature | Brain.evaluate(self, test_set, max_key=None, min_key=None, progressbar=None, test_loader_kwargs={}) and ErrorRateStats.__init__(self, merge_tokens=False, split_tokens=False, space_token="_", keep_values=True, extract_concepts_values=False, tag_in="", tag_out="", equality_comparator=_str_equals) |
| Source File | speechbrain/core.py:L1695-1754 (evaluate), speechbrain/utils/metric_stats.py:L206-378 (ErrorRateStats) |
| Import | from speechbrain.core import Brain and from speechbrain.utils.metric_stats import ErrorRateStats |
| Type | API Doc |
| Related Principle | Principle:Speechbrain_Speechbrain_ASR_Evaluation_With_WER |
Description
Brain.evaluate() performs model evaluation on a test set by loading the best checkpoint (selected by a metric key), running inference with gradients disabled, and computing aggregate metrics. ErrorRateStats is the metric accumulator class that computes Word Error Rate (WER) and Character Error Rate (CER) by tracking per-utterance edit distances between hypotheses and references.
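To make the notion of a per-utterance edit distance concrete, the following is a minimal pure-Python sketch of counting insertions, deletions, and substitutions between one reference and one hypothesis. The helper names `edit_ops` and `wer_percent` are hypothetical and are not part of SpeechBrain; the library's own implementation lives in speechbrain.utils.edit_distance and may differ in tie-breaking and in the details it records.

```python
def edit_ops(ref, hyp):
    """Count (insertions, deletions, substitutions) for a minimum-cost alignment.

    Hypothetical helper for illustration only; not SpeechBrain's API.
    """
    rows, cols = len(ref) + 1, len(hyp) + 1
    # dp[i][j] = (cost, insertions, deletions, substitutions) for ref[:i] vs hyp[:j]
    dp = [[None] * cols for _ in range(rows)]
    dp[0][0] = (0, 0, 0, 0)
    for i in range(1, rows):
        dp[i][0] = (i, 0, i, 0)  # every reference token deleted
    for j in range(1, cols):
        dp[0][j] = (j, j, 0, 0)  # every hypothesis token inserted
    for i in range(1, rows):
        for j in range(1, cols):
            if ref[i - 1] == hyp[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]  # match: no extra cost
                continue
            c, a, d, s = dp[i - 1][j - 1]
            substitution = (c + 1, a, d, s + 1)
            c, a, d, s = dp[i][j - 1]
            insertion = (c + 1, a + 1, d, s)
            c, a, d, s = dp[i - 1][j]
            deletion = (c + 1, a, d + 1, s)
            dp[i][j] = min(substitution, insertion, deletion, key=lambda t: t[0])
    _, ins, dels, subs = dp[-1][-1]
    return ins, dels, subs


def wer_percent(ref, hyp):
    """Error rate of one utterance as a percentage (assumes a non-empty reference)."""
    ins, dels, subs = edit_ops(ref, hyp)
    return 100.0 * (ins + dels + subs) / len(ref)


ops = edit_ops(["the", "cat", "sat"], ["the", "cat", "set"])
rate = wer_percent(["the", "cat", "sat"], ["the", "cat", "set"])
```

Passing lists of characters instead of lists of words gives a character-level rate with the same code.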
Brain.evaluate()
Inputs
| Parameter | Type | Default | Description |
|---|---|---|---|
| test_set | Dataset or DataLoader | (required) | Test data to evaluate on. If a DynamicItemDataset, a DataLoader is automatically created. |
| max_key | str | None | Metric key to maximize when selecting the best checkpoint. Mutually exclusive with min_key. |
| min_key | str | None | Metric key to minimize when selecting the best checkpoint. For ASR, typically "WER". |
| progressbar | bool | None | Whether to display a progress bar. If None, determined by the noprogressbar run option. |
| test_loader_kwargs | dict | {} | Keyword arguments for DataLoader creation. ckpt_prefix is automatically set to None so the test DataLoader is not added to the checkpointer. |
Outputs
Returns the average test loss (float). In the ASR recipe, the on_stage_end hook additionally produces these side effects:
- WER/CER metrics are computed and logged
- Detailed alignment statistics are written to the test WER file
- Statistics are printed via the train logger
Execution Flow
evaluate(test_set, min_key="WER")
|
+-- Create DataLoader from test_set if needed
+-- on_evaluate_start(min_key="WER")
|   +-- Load best checkpoint (lowest WER)
+-- on_stage_start(TEST, epoch=None)
|   +-- Initialize ErrorRateStats for WER and CER
+-- modules.eval()
+-- torch.no_grad():
|   +-- for each batch in test_set:
|       +-- evaluate_batch(batch, TEST)
|           +-- compute_forward(batch, TEST)
|           |   -> p_ctc, wav_lens, p_tokens (beam search)
|           +-- compute_objectives(preds, batch, TEST)
|               -> CTC loss + WER/CER accumulation
+-- on_stage_end(TEST, avg_test_loss, None)
    +-- Summarize WER/CER statistics
    +-- Log test statistics
    +-- Write detailed alignment file
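The hook order in the flow above can be sketched with a toy stand-in class. `ToyBrain` is hypothetical and is not the SpeechBrain `Brain` class: it omits DataLoader creation, checkpoint loading, `modules.eval()` and `torch.no_grad()`, and it treats each "batch" as a plain number that doubles as its loss. It only illustrates which hooks fire, in what order, and how a running average loss can be returned.

```python
class ToyBrain:
    """Hypothetical stand-in that mirrors only the hook order of evaluate()."""

    def __init__(self):
        self.calls = []

    def on_evaluate_start(self, max_key=None, min_key=None):
        # The real hook is where the best checkpoint would be loaded.
        self.calls.append("on_evaluate_start")

    def on_stage_start(self, stage, epoch=None):
        # A recipe would create fresh metric accumulators here.
        self.calls.append("on_stage_start")

    def evaluate_batch(self, batch, stage):
        # Assumption: the batch value itself plays the role of the loss.
        self.calls.append("evaluate_batch")
        return float(batch)

    def on_stage_end(self, stage, stage_loss, epoch=None):
        # A recipe would summarize, log and write metric files here.
        self.calls.append("on_stage_end")

    def evaluate(self, test_set, max_key=None, min_key=None):
        self.on_evaluate_start(max_key=max_key, min_key=min_key)
        self.on_stage_start("TEST", epoch=None)
        avg_loss = 0.0
        for step, batch in enumerate(test_set, start=1):
            loss = self.evaluate_batch(batch, "TEST")
            avg_loss += (loss - avg_loss) / step  # running mean over batches
        self.on_stage_end("TEST", avg_loss, None)
        return avg_loss


brain = ToyBrain()
avg_test_loss = brain.evaluate([1.0, 2.0, 3.0], min_key="WER")
```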
ErrorRateStats
Constructor Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| merge_tokens | bool | False | Merge successive tokens into words (e.g., character-level to word-level) |
| split_tokens | bool | False | Split tokens into characters (e.g., word-level to character-level). Used for CER computation. |
| space_token | str | "_" | Token used as word boundary. Used with merge_tokens for splitting after merge, or with split_tokens for joining before split. |
| keep_values | bool | True | Whether to keep concept values in structured output evaluation |
| extract_concepts_values | bool | False | Process predictions/targets to extract concepts and values |
| tag_in | str | "" | Start tag for concept extraction |
| tag_out | str | "" | End tag for concept extraction |
| equality_comparator | Callable | _str_equals | Function to compare two tokens for equality |
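The interaction between merge_tokens, split_tokens, and space_token can be illustrated with two small pure-Python helpers. `merge_chars` and `split_words` are hypothetical names modeled on the description above (merge, then split on the space token; join with the space token, then split into characters). They are a sketch of the idea rather than the library's merge_char / split_word functions.

```python
def merge_chars(sequences, space="_"):
    """Concatenate character tokens, then split on the space token to get words."""
    return ["".join(seq).split(space) for seq in sequences]


def split_words(sequences, space="_"):
    """Join word tokens with the space token, then split into single characters."""
    return [list(space.join(seq)) for seq in sequences]


# Character-level output turned into words (the merge_tokens=True direction).
words = merge_chars([["t", "h", "e", "_", "c", "a", "t"]])

# Word-level output turned into characters (the split_tokens=True direction).
chars = split_words([["the", "cat"]])
```

Note that in the split direction the space token itself becomes one of the characters, so word-boundary errors also count toward a character-level rate under this scheme.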
Key Methods
append(ids, predict, target, predict_len=None, target_len=None, ind2lab=None)
Adds per-utterance error statistics for a batch.
| Parameter | Type | Description |
|---|---|---|
| ids | list | List of utterance IDs for the batch |
| predict | list or torch.Tensor | Predicted word/token sequences |
| target | list or torch.Tensor | Reference word/token sequences |
| predict_len | torch.Tensor | Relative lengths for undoing prediction padding (optional) |
| target_len | torch.Tensor | Relative lengths for undoing target padding (optional) |
| ind2lab | callable | Maps from indices to labels for alignment writing (optional) |
The method:
- Undoes padding if length tensors are provided
- Applies index-to-label mapping if ind2lab is given
- Optionally merges or splits tokens
- Computes per-utterance WER details including alignments
- Stores scores for later summarization
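The first two steps above, undoing padding from relative lengths and mapping indices to labels, can be sketched without tensors. The helper `undo_padding_sketch`, the `vocab` dictionary, and `ind2lab_sketch` are hypothetical illustrations. The sketch assumes each relative length is the true length divided by the padded length of the batch, and that the mapping callable operates on a whole batch of index sequences.

```python
def undo_padding_sketch(batch, rel_lengths):
    """Trim each padded sequence back to its true length.

    Assumes rel_lengths[i] == true_length_i / padded_length.
    """
    padded_len = len(batch[0])
    return [
        seq[: int(round(rel * padded_len))]
        for seq, rel in zip(batch, rel_lengths)
    ]


# Hypothetical vocabulary; index 0 is used as padding in this example.
vocab = {1: "cat", 2: "sat", 3: "the", 4: "world", 5: "hello"}


def ind2lab_sketch(index_batch):
    """Map a batch of index sequences to label sequences."""
    return [[vocab[i] for i in seq] for seq in index_batch]


padded = [[3, 1, 2, 0], [5, 4, 0, 0]]
trimmed = undo_padding_sketch(padded, [0.75, 0.5])
labels = ind2lab_sketch(trimmed)
```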
summarize(field=None)
Aggregates all per-utterance scores into corpus-level statistics.
Returns (when field=None): a dict whose keys include:
| Key | Type | Description |
|---|---|---|
| "WER" | float | Overall Word Error Rate as a percentage |
| "error_rate" | float | Same as WER (generic alias) |
| "insertions" | int | Total insertion errors across all utterances |
| "deletions" | int | Total deletion errors across all utterances |
| "substitutions" | int | Total substitution errors across all utterances |
When field is specified (e.g., "error_rate"), returns only that specific value.
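As a rough illustration of corpus-level aggregation, the sketch below pools edit counts and reference lengths over all utterances before dividing, rather than averaging per-utterance rates. The function `summarize_sketch` and the per-utterance field names (`num_ref_tokens` and so on) are assumptions made for this example and are not guaranteed to match the library's internal records.

```python
def summarize_sketch(scores, field=None):
    """Pool per-utterance counts into corpus-level statistics (illustrative only)."""
    summary = {
        "insertions": sum(s["insertions"] for s in scores),
        "deletions": sum(s["deletions"] for s in scores),
        "substitutions": sum(s["substitutions"] for s in scores),
    }
    num_edits = (
        summary["insertions"] + summary["deletions"] + summary["substitutions"]
    )
    num_ref_tokens = sum(s["num_ref_tokens"] for s in scores)
    # Corpus-level rate: total edits over total reference tokens, as a percentage.
    summary["WER"] = 100.0 * num_edits / num_ref_tokens
    summary["error_rate"] = summary["WER"]  # generic alias
    return summary if field is None else summary[field]


scores = [
    {"key": "utt1", "num_ref_tokens": 3, "insertions": 0, "deletions": 0, "substitutions": 1},
    {"key": "utt2", "num_ref_tokens": 2, "insertions": 0, "deletions": 0, "substitutions": 0},
]
full = summarize_sketch(scores)
only_rate = summarize_sketch(scores, field="error_rate")
```

With these two utterances the pooled rate is 1 edit over 5 reference tokens, whereas averaging the two per-utterance rates (one third and zero) would give a different number, which is why pooling is the usual convention for WER.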
write_stats(filestream)
Writes detailed statistics and per-utterance alignment information to a file stream.
with open(self.hparams.test_wer_file, "w", encoding="utf-8") as w:
    self.wer_metric.write_stats(w)
Output includes a summary header followed by per-utterance alignments showing substitutions, insertions, and deletions.
YAML Configuration
The WER and CER metric computers are configured in YAML:
# WER computer (word-level)
error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
# CER computer (character-level, using split_tokens)
cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
    split_tokens: True
The !name: tag creates a callable (constructor) rather than an instance. The actual ErrorRateStats instances are created in on_stage_start() at the beginning of each validation/test stage:
def on_stage_start(self, stage, epoch):
    if stage != sb.Stage.TRAIN:
        self.cer_metric = self.hparams.cer_computer()  # Creates new ErrorRateStats
        self.wer_metric = self.hparams.error_rate_computer()  # Creates new ErrorRateStats
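The distinction between a callable and an instance can be illustrated with `functools.partial`, which behaves much like what the `!name:` tag is described as producing: a constructor with some keyword arguments already bound. `ToyStats` is a hypothetical stand-in for a metric accumulator, used here only so the example runs without SpeechBrain installed.

```python
from functools import partial


class ToyStats:
    """Hypothetical accumulator standing in for a metric class."""

    def __init__(self, split_tokens=False):
        self.split_tokens = split_tokens
        self.scores = []


# Roughly what a "!name:" entry with split_tokens: True would yield:
# a callable with the keyword argument pre-bound, not an instance.
cer_computer = partial(ToyStats, split_tokens=True)

# Each call builds a fresh accumulator, so stages never share state.
validation_metric = cer_computer()
validation_metric.scores.append({"key": "utt1"})
test_metric = cer_computer()
```

Creating the accumulator inside on_stage_start therefore guarantees that scores from a previous validation pass do not leak into the next stage.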
Usage Example
Complete Evaluation Flow
import speechbrain as sb
# After training is complete, evaluate on test set
asr_brain = ASR(
    modules=hparams["modules"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)
# Load best checkpoint and evaluate
asr_brain.evaluate(
    test_data,
    min_key="WER",
    test_loader_kwargs=hparams["test_dataloader_options"],
)
# The on_stage_end callback writes detailed WER stats:
# on_stage_end(TEST, avg_loss, None):
# with open(self.hparams.test_wer_file, "w") as w:
# self.wer_metric.write_stats(w)
Standalone ErrorRateStats Usage
from speechbrain.utils.metric_stats import ErrorRateStats
# Word Error Rate
wer_stats = ErrorRateStats()
wer_stats.append(
    ids=["utt1", "utt2"],
    predict=[["the", "cat", "set"], ["hello", "world"]],
    target=[["the", "cat", "sat"], ["hello", "world"]],
)
summary = wer_stats.summarize()
print(f"WER: {summary['WER']:.2f}%")
print(f"Substitutions: {summary['substitutions']}")
print(f"Deletions: {summary['deletions']}")
print(f"Insertions: {summary['insertions']}")
# Character Error Rate
cer_stats = ErrorRateStats(split_tokens=True)
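cer_stats.append(
    ids=["utt1"],
    predict=[["the", "cat", "set"]],
    target=[["the", "cat", "sat"]],
)
cer_summary = cer_stats.summarize()
# With split_tokens=True the rate is character-level; "error_rate" is the generic alias of "WER"
print(f"CER: {cer_summary['error_rate']:.2f}%")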
Checkpoint Selection
The min_key="WER" parameter in evaluate() instructs the checkpointer to load the checkpoint with the lowest WER value. This checkpoint was saved during training by:
self.checkpointer.save_and_keep_only(
    meta={"WER": stage_stats["WER"]},
    min_keys=["WER"],
)
The checkpointer stores WER values in checkpoint metadata and can select the optimal checkpoint at evaluation time. This ensures that the model evaluated on the test set is the one that performed best on the validation set, not simply the most recent.
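The selection logic can be sketched over a plain list of checkpoint records. Everything here is hypothetical: the `pick_checkpoint` helper, the record layout, and the fallback to the last entry when no key is given are assumptions chosen for illustration, not the SpeechBrain checkpointer's actual implementation.

```python
def pick_checkpoint(checkpoints, max_key=None, min_key=None):
    """Select one checkpoint record by a metric stored in its metadata."""
    if max_key is not None and min_key is not None:
        raise ValueError("max_key and min_key are mutually exclusive")
    if min_key is not None:
        candidates = [c for c in checkpoints if min_key in c["meta"]]
        return min(candidates, key=lambda c: c["meta"][min_key])
    if max_key is not None:
        candidates = [c for c in checkpoints if max_key in c["meta"]]
        return max(candidates, key=lambda c: c["meta"][max_key])
    # Assumption for this sketch: the list is ordered oldest to newest.
    return checkpoints[-1]


# Hypothetical records: the validation WER was saved with each checkpoint.
saved = [
    {"path": "ckpt_epoch1", "meta": {"WER": 12.4}},
    {"path": "ckpt_epoch2", "meta": {"WER": 9.8}},
    {"path": "ckpt_epoch3", "meta": {"WER": 10.6}},
]
best = pick_checkpoint(saved, min_key="WER")
latest = pick_checkpoint(saved)
```

In this example the most recent checkpoint is not the best one, which is the situation min_key="WER" is meant to handle.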
Dependencies
- speechbrain.utils.edit_distance.wer_details_for_batch -- computes per-utterance edit distance and alignment details
- speechbrain.utils.edit_distance.wer_summary -- aggregates per-utterance scores into corpus-level WER
- speechbrain.utils.edit_distance.print_wer_summary -- formats WER summary for output
- speechbrain.utils.edit_distance.print_alignments -- formats per-utterance alignments for output
- speechbrain.dataio.dataio.merge_char / split_word -- for character-level processing