Implementation:Axolotl ai cloud Axolotl SwanLab Custom Trainer Profiling

From Leeroopedia
Revision as of 14:33, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Axolotl_ai_cloud_Axolotl_SwanLab_Custom_Trainer_Profiling.md)


Knowledge Sources
Domains Profiling, Training, Experiment_Tracking
Last Updated 2026-02-07 00:00 GMT

Overview

Example module demonstrating four patterns for adding SwanLab profiling instrumentation to custom Axolotl trainer subclasses.

Description

The custom_trainer_profiling.py example provides two complete trainer implementations, CustomTrainerWithProfiling and AdvancedProfilingTrainer, that show how to instrument custom trainers with SwanLab profiling. It covers four profiling patterns:
(1) Decorator-based profiling with @swanlab_profile, for methods that should always be profiled (about 2-5 microseconds of overhead).
(2) Context-manager profiling with swanlab_profiling_context, for fine-grained timing of specific code blocks (e.g., the forward pass and the backward pass measured separately).
(3) Advanced filtering with swanlab_profiling_context_advanced and ProfilingConfig, which throttles logging for high-frequency operations through the min_duration_ms and log_interval parameters.
(4) Exception-safe profiling, where the duration is logged even if the profiled method raises an exception.
All profiling metrics are logged to SwanLab under the profiling/ namespace.
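To illustrate how patterns (1) and (4) can combine, here is a minimal sketch of a timing decorator that records a duration even when the wrapped method raises. It is not the source of swanlab_profile; timed_method, RECORDED, and DemoTrainer are hypothetical names, and the in-memory list stands in for a SwanLab log call. Only the metric key format is taken from this page.

```python
import functools
import time

# Hypothetical in-memory sink standing in for a SwanLab logging call.
RECORDED = []


def timed_method(func):
    """Time a method and record the duration, even if the method raises."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        start = time.perf_counter()
        try:
            return func(self, *args, **kwargs)
        finally:
            # finally runs on normal return and on exception alike.
            elapsed_ms = (time.perf_counter() - start) * 1000.0
            key = f"profiling/Time taken: {type(self).__name__}.{func.__name__}"
            RECORDED.append((key, elapsed_ms))

    return wrapper


class DemoTrainer:
    """Toy class used only to exercise the decorator (not an Axolotl trainer)."""

    @timed_method
    def training_step(self, batch):
        return sum(batch)

    @timed_method
    def failing_step(self, batch):
        raise ValueError("simulated failure")


trainer = DemoTrainer()
result = trainer.training_step([1, 2, 3])
try:
    trainer.failing_step([])
except ValueError:
    pass  # The duration was recorded before the exception propagated.
```

Placing the recording step in a finally block is what makes the timing exception-safe in this sketch: the exception still propagates to the caller, but the measurement is not lost.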

Usage

Use this module as a reference when creating custom trainers that need performance instrumentation. Copy the relevant patterns into your own trainer subclass. The decorator pattern works for critical-path methods (training_step, compute_loss), while the advanced context manager is suited for high-frequency helper methods where throttling is needed.
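For code blocks rather than whole methods, the context-manager pattern can be sketched with the standard library alone. The example below is an assumption-laden stand-in, not the swanlab_profiling_context implementation: timed_block, TIMINGS, and complex_step are hypothetical, and the "forward" and "backward" blocks are toy arithmetic rather than real model passes.

```python
import time
from contextlib import contextmanager

# Hypothetical sink: block name -> list of durations in milliseconds.
TIMINGS = {}


@contextmanager
def timed_block(name):
    """Time the enclosed block and store the duration under `name`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        TIMINGS.setdefault(name, []).append(elapsed_ms)


def complex_step(values):
    # Each phase gets its own timing entry, so they can be compared.
    with timed_block("forward"):
        loss = sum(v * v for v in values)
    with timed_block("backward"):
        grads = [2 * v for v in values]
    return loss, grads


loss, grads = complex_step([1.0, 2.0, 3.0])
```

Compared with a decorator, this form costs a little more boilerplate at the call site but lets a single method report several separate timings.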

Code Reference

Source Location

Signature

class CustomTrainerWithProfiling(AxolotlTrainer):
    """Custom trainer with SwanLab profiling enabled."""

    def __init__(self, *args, **kwargs): ...

    @swanlab_profile
    def training_step(self, model, inputs): ...

    @swanlab_profile
    def compute_loss(self, model, inputs, return_outputs=False): ...

    @swanlab_profile
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None): ...

    def complex_training_step(self, model, inputs): ...
    def _prepare_inputs(self, inputs): ...

class AdvancedProfilingTrainer(AxolotlTrainer):
    """Trainer with method-specific profiling configurations."""

    def __init__(self, *args, **kwargs): ...
    def training_step(self, model, inputs): ...
    def _prepare_inputs(self, inputs): ...

Import

from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.integrations.swanlab.profiling import (
    ProfilingConfig,
    swanlab_profile,
    swanlab_profiling_context,
    swanlab_profiling_context_advanced,
)

I/O Contract

Inputs

Name | Type | Required | Description
model | nn.Module | Yes | The model being trained
inputs | dict | Yes | Batch of input tensors
ProfilingConfig.min_duration_ms | float | No | Minimum duration threshold to log (default 0.0)
ProfilingConfig.log_interval | int | No | Log every Nth call (default 1)

Outputs

Name | Type | Description
SwanLab profiling metrics | float | Timing data logged under profiling/Time taken: ClassName.method_name
Training outputs | varies | Standard trainer outputs (loss, logits, etc.)

Usage Examples

Decorator Pattern

from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.integrations.swanlab.profiling import swanlab_profile

class MyTrainer(AxolotlTrainer):
    @swanlab_profile
    def training_step(self, model, inputs):
        # Profiled automatically; metric logged as:
        # profiling/Time taken: MyTrainer.training_step
        return super().training_step(model, inputs)

Context Manager with Throttling

from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.integrations.swanlab.profiling import (
    ProfilingConfig,
    swanlab_profiling_context_advanced,
)

config = ProfilingConfig(enabled=True, min_duration_ms=1.0, log_interval=100)

class MyTrainer(AxolotlTrainer):
    def _prepare_inputs(self, inputs):
        with swanlab_profiling_context_advanced(self, "prepare_inputs", config=config):
            return super()._prepare_inputs(inputs)
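The throttling behaviour described by min_duration_ms and log_interval can be pictured with the following standalone sketch. It assumes a measurement is logged only when profiling is enabled, the duration meets the threshold, and the call count is a multiple of the interval; the actual rule inside swanlab_profiling_context_advanced may differ. ThrottleConfig and ThrottledTimer are hypothetical names that only mirror the documented ProfilingConfig fields.

```python
import time
from contextlib import contextmanager
from dataclasses import dataclass


@dataclass
class ThrottleConfig:
    """Hypothetical stand-in mirroring the documented ProfilingConfig fields."""

    enabled: bool = True
    min_duration_ms: float = 0.0
    log_interval: int = 1


class ThrottledTimer:
    """Counts calls and keeps only the measurements that pass the filters."""

    def __init__(self, config):
        self.config = config
        self.calls = 0
        self.logged = []

    def should_log(self, elapsed_ms):
        # Assumed rule: enabled, slow enough, and on an Nth call.
        if not self.config.enabled:
            return False
        if elapsed_ms < self.config.min_duration_ms:
            return False
        return self.calls % self.config.log_interval == 0

    @contextmanager
    def measure(self, name):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.calls += 1
            elapsed_ms = (time.perf_counter() - start) * 1000.0
            if self.should_log(elapsed_ms):
                self.logged.append((name, elapsed_ms))


timer = ThrottledTimer(ThrottleConfig(min_duration_ms=0.0, log_interval=100))
for _ in range(250):
    with timer.measure("prepare_inputs"):
        pass  # Stand-in for a high-frequency helper such as _prepare_inputs.
```

With log_interval=100, the 250 calls above produce two recorded entries (calls 100 and 200), which is the kind of reduction that keeps a hot helper method from flooding the experiment log.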
