Implementation:Axolotl ai cloud Axolotl SwanLab Custom Trainer Profiling
| Knowledge Sources | |
|---|---|
| Domains | Profiling, Training, Experiment_Tracking |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Example module demonstrating four patterns for adding SwanLab profiling instrumentation to custom Axolotl trainer subclasses.
Description
The custom_trainer_profiling.py example provides two complete trainer implementations (CustomTrainerWithProfiling and AdvancedProfilingTrainer) that demonstrate how to instrument custom trainers with SwanLab profiling. It covers four profiling patterns:
1. Decorator-based: @swanlab_profile for methods that should always be profiled (~2-5 microsecond overhead).
2. Context manager: swanlab_profiling_context for fine-grained profiling of specific code blocks (e.g., timing the forward pass and the backward pass separately).
3. Advanced filtering: swanlab_profiling_context_advanced with ProfilingConfig to throttle high-frequency operations via the min_duration_ms and log_interval parameters.
4. Exception-safe profiling: the duration is logged even if the profiled method raises an exception.
All profiling metrics are logged to SwanLab under the profiling/ namespace.
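The decorator and exception-safety behaviour described above can be illustrated with a small standalone sketch. This is not the axolotl implementation: `profile_sketch`, `logged_metrics`, and `DemoTrainer` are hypothetical names, and a plain list stands in for the SwanLab backend. It only shows how a `try`/`finally` block lets the duration be recorded whether or not the wrapped method raises.

```python
import functools
import time

# Hypothetical stand-in for a metrics backend; the real module logs to SwanLab.
logged_metrics = []


def profile_sketch(func):
    """Time a method and record the duration even when it raises."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        start = time.perf_counter()
        try:
            return func(self, *args, **kwargs)
        finally:
            # Runs on both normal return and exception.
            duration_ms = (time.perf_counter() - start) * 1000.0
            name = f"profiling/Time taken: {type(self).__name__}.{func.__name__}"
            logged_metrics.append((name, duration_ms))

    return wrapper


class DemoTrainer:
    """Hypothetical class used only to exercise the decorator."""

    @profile_sketch
    def ok_step(self):
        return 1

    @profile_sketch
    def failing_step(self):
        raise ValueError("simulated failure")
```

Calling `DemoTrainer().failing_step()` still appends a timing entry before the `ValueError` propagates to the caller.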
Usage
Use this module as a reference when creating custom trainers that need performance instrumentation. Copy the relevant patterns into your own trainer subclass. The decorator pattern works for critical-path methods (training_step, compute_loss), while the advanced context manager is suited for high-frequency helper methods where throttling is needed.
Code Reference
Source Location
- Repository: Axolotl
- File: examples/swanlab/custom_trainer_profiling.py
- Lines: 1-299
Signature
class CustomTrainerWithProfiling(AxolotlTrainer):
    """Custom trainer with SwanLab profiling enabled."""

    def __init__(self, *args, **kwargs): ...

    @swanlab_profile
    def training_step(self, model, inputs): ...

    @swanlab_profile
    def compute_loss(self, model, inputs, return_outputs=False): ...

    @swanlab_profile
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None): ...

    def complex_training_step(self, model, inputs): ...

    def _prepare_inputs(self, inputs): ...


class AdvancedProfilingTrainer(AxolotlTrainer):
    """Trainer with method-specific profiling configurations."""

    def __init__(self, *args, **kwargs): ...

    def training_step(self, model, inputs): ...

    def _prepare_inputs(self, inputs): ...
Import
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.integrations.swanlab.profiling import (
    ProfilingConfig,
    swanlab_profile,
    swanlab_profiling_context,
    swanlab_profiling_context_advanced,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | nn.Module | Yes | The model being trained |
| inputs | dict | Yes | Batch of input tensors |
| ProfilingConfig.min_duration_ms | float | No | Minimum duration threshold to log (default 0.0) |
| ProfilingConfig.log_interval | int | No | Log every Nth call (default 1) |
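How the two optional ProfilingConfig fields might combine can be sketched as follows. The class `ThrottleSketch` and its `should_log` method are hypothetical, and the order of the checks (count every call, keep every Nth, then apply the duration threshold) is an assumption, not something confirmed from the library source.

```python
from dataclasses import dataclass, field


@dataclass
class ThrottleSketch:
    """Hypothetical stand-in mirroring the two ProfilingConfig fields above."""

    min_duration_ms: float = 0.0
    log_interval: int = 1
    _calls: int = field(default=0, init=False)

    def should_log(self, duration_ms: float) -> bool:
        # Count every call, then keep only every Nth one (assumed semantics).
        self._calls += 1
        if self._calls % self.log_interval != 0:
            return False
        # Drop measurements below the duration threshold.
        return duration_ms >= self.min_duration_ms


throttle = ThrottleSketch(min_duration_ms=1.0, log_interval=100)
durations = [2.0] * 250  # 250 simulated calls that each took 2 ms
kept = sum(throttle.should_log(d) for d in durations)  # calls 100 and 200 pass
```

With the defaults (0.0 and 1), every call is logged, which matches the defaults listed in the table.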
Outputs
| Name | Type | Description |
|---|---|---|
| SwanLab profiling metrics | float | Timing data logged under profiling/Time taken: ClassName.method_name |
| Training outputs | varies | Standard trainer outputs (loss, logits, etc.) |
Usage Examples
Decorator Pattern
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.integrations.swanlab.profiling import swanlab_profile
class MyTrainer(AxolotlTrainer):
    @swanlab_profile
    def training_step(self, model, inputs):
        # Profiled automatically; metric logged as:
        # profiling/Time taken: MyTrainer.training_step
        return super().training_step(model, inputs)
Context Manager with Throttling
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.integrations.swanlab.profiling import (
    ProfilingConfig,
    swanlab_profiling_context_advanced,
)

# Log only calls that take at least 1 ms, and only every 100th call.
config = ProfilingConfig(enabled=True, min_duration_ms=1.0, log_interval=100)


class MyTrainer(AxolotlTrainer):
    def _prepare_inputs(self, inputs):
        with swanlab_profiling_context_advanced(self, "prepare_inputs", config=config):
            return super()._prepare_inputs(inputs)
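The plain context-manager pattern from the Description, which times separate blocks such as the forward and backward pass, can be sketched without any axolotl or SwanLab dependency. Everything here is hypothetical: `timing_block`, `timings`, and `fake_step` are illustrative names, a dict replaces the SwanLab logger, and the two loops merely stand in for real forward and backward computation.

```python
import time
from contextlib import contextmanager

# Hypothetical sink; the real context manager logs to SwanLab instead.
timings = {}


@contextmanager
def timing_block(name):
    """Record the wall-clock duration of the enclosed block in milliseconds."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[f"profiling/{name}"] = (time.perf_counter() - start) * 1000.0


def fake_step():
    # Stand-in for the forward pass.
    with timing_block("forward"):
        loss = sum(x * x for x in range(1000))
    # Stand-in for the backward pass.
    with timing_block("backward"):
        grad = [2 * x for x in range(1000)]
    return loss, grad
```

Each block gets its own metric key, so a slow forward pass and a slow backward pass show up as separate series rather than one combined step time.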