Implementation:NVIDIA TransformerEngine TELlamaForCausalLM

From Leeroopedia


Overview

A LlamaForCausalLM model accelerated with Transformer Engine (TE), built by monkey-patching the HuggingFace decoder layer class during model construction.

Doc Type

Wrapper Doc -- This class wraps HuggingFace's LlamaForCausalLM using a __new__-based factory pattern to inject TE decoder layers.

Description

TELlamaForCausalLM overrides Python's __new__ method so that calling the class constructs a LlamaForCausalLM inside a replace_decoder context. While that context is active, every decoder layer the model creates is a TELlamaDecoderLayer. The result is a standard HuggingFace LlamaForCausalLM object that keeps the HuggingFace API but runs TE decoder layers internally.
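The __new__-based factory behaviour can be illustrated without any ML dependencies. The sketch below is a minimal stand-in, not the TransformerEngine code: Base and Factory are hypothetical names playing the roles of LlamaForCausalLM and TELlamaForCausalLM. Because __new__ returns an object that is not an instance of Factory, Python skips Factory.__init__ and the caller receives a plain Base instance.

```python
class Base:
    """Hypothetical stand-in for LlamaForCausalLM."""

    def __init__(self, config):
        self.config = config


class Factory:
    """Hypothetical stand-in for TELlamaForCausalLM."""

    def __new__(cls, config):
        # Build and return an instance of a *different* class.
        return Base(config)

    def __init__(self, config):
        # Never runs: __new__ did not return an instance of Factory.
        raise RuntimeError("unreachable")


model = Factory({"hidden_size": 8})
print(type(model).__name__)        # Base
print(isinstance(model, Factory))  # False
```

One practical consequence is that isinstance checks against the factory class fail, so code that needs a type check should test against the class that is actually returned.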

The class provides two construction paths:

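  • __new__(cls, config): Creates a new LlamaForCausalLM with randomly initialized TE decoder layers.
  • from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **kwargs): Creates the TE model, then loads the weights of a pretrained HuggingFace checkpoint from a local directory, mapping them to TE's parameter layout. config is keyword-only, and torch_dtype must be supplied through **kwargs.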

The from_pretrained_local method supports only sharded checkpoints; if no index file is found, it raises an AssertionError. It loads the shards by:

  1. Detecting whether the checkpoint directory contains model.safetensors.index.json or pytorch_model.bin.index.json
  2. Resolving shard file paths via HuggingFace's get_checkpoint_shard_files
  3. Iterating over each shard, loading its weights and mapping TE-specific parameters via replace_params
  4. Calling the model's load_state_dict(strict=False) to load the non-TE parameters (embeddings, final layer norm, lm_head)
  5. Explicitly freeing memory after each shard with del state_dict and gc.collect()
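Steps 1 and 2 above can be sketched with the standard library alone. This is a simplified illustration under stated assumptions, not HuggingFace's get_checkpoint_shard_files: find_shard_files is a hypothetical helper, and the parameter and shard names in the demo index are made up. It relies only on the index file layout, in which a "weight_map" object maps each parameter name to the shard file that stores it.

```python
import json
import os
import tempfile

# Index file names checked in order (safetensors first, then PyTorch).
INDEX_NAMES = ("model.safetensors.index.json", "pytorch_model.bin.index.json")


def find_shard_files(ckpt_dir):
    """Return the sorted shard paths listed in the first index file found."""
    for name in INDEX_NAMES:
        index_path = os.path.join(ckpt_dir, name)
        if os.path.isfile(index_path):
            with open(index_path) as f:
                weight_map = json.load(f)["weight_map"]
            # Many parameters map to the same shard, so deduplicate.
            return [os.path.join(ckpt_dir, s) for s in sorted(set(weight_map.values()))]
    raise AssertionError("no sharded checkpoint index found in " + ckpt_dir)


# Demo with a fake index (parameter names and shard names are made up).
with tempfile.TemporaryDirectory() as d:
    index = {"weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "model.norm.weight": "model-00002-of-00002.safetensors",
        "lm_head.weight": "model-00002-of-00002.safetensors",
    }}
    with open(os.path.join(d, INDEX_NAMES[0]), "w") as f:
        json.dump(index, f)
    shards = [os.path.basename(p) for p in find_shard_files(d)]

print(shards)
```

Processing one shard at a time, as in steps 3 to 5, keeps peak host memory close to the size of a single shard rather than the whole checkpoint.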

Source

  • File: docs/examples/te_llama/te_llama.py
  • Class: TELlamaForCausalLM
  • Lines: L87-163

Signature

class TELlamaForCausalLM:
    """
    Causal LM created with `LlamaModel`. The underlying `LlamaDecoderLayer`
    class is monkey-patched with `TELlamaDecoderLayer` class before
    initializing the causal LM with `LlamaForCausalLM`.

    Args:
        config: LlamaConfig
    """

    def __new__(cls, config: LlamaConfig):
        with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):
            llama_for_causal_lm = LlamaForCausalLM(config)
        return llama_for_causal_lm

    @classmethod
    def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **kwargs):
        """
        Custom method adapted from `from_pretrained` method in HuggingFace
        Transformers repo.
        """
        torch.set_default_dtype(kwargs["torch_dtype"])

        vanilla_model = cls(config)
        subfolder = ""
        variant = None

        # Detect sharded checkpoint format (safetensors or PyTorch)
        if os.path.isfile(
            os.path.join(pretrained_model_name_or_path, subfolder,
                         _add_variant("model.safetensors.index.json", variant))
        ):
            archive_file = os.path.join(
                pretrained_model_name_or_path, subfolder,
                _add_variant("model.safetensors.index.json", variant))
            is_sharded = True
        elif os.path.isfile(
            os.path.join(pretrained_model_name_or_path, subfolder,
                         _add_variant(WEIGHTS_INDEX_NAME, variant))
        ):
            archive_file = os.path.join(
                pretrained_model_name_or_path, subfolder,
                _add_variant(WEIGHTS_INDEX_NAME, variant))
            is_sharded = True
        else:
            raise AssertionError("Only sharded PyTorch ckpt format supported at the moment")

        resolved_archive_file, _ = get_checkpoint_shard_files(
            pretrained_model_name_or_path, archive_file)

        if not is_sharded:
            resolved_archive_file = [resolved_archive_file]

        for shard_file in resolved_archive_file:
            state_dict = load_state_dict(shard_file)
            replace_params(state_dict, vanilla_model.state_dict(), config)
            vanilla_model.load_state_dict(state_dict, strict=False)
            del state_dict
            gc.collect()

        return vanilla_model

I/O

Direction | Name | Type | Description
--- | --- | --- | ---
Input | config | LlamaConfig | HuggingFace LLaMA configuration object containing model hyperparameters.
Output | return value | LlamaForCausalLM | A HuggingFace LlamaForCausalLM instance with TE decoder layers. Note: the returned type is LlamaForCausalLM, not TELlamaForCausalLM.

from_pretrained_local I/O

Direction | Name | Type | Description
--- | --- | --- | ---
Input | pretrained_model_name_or_path | str | Path to a local directory containing the sharded HuggingFace checkpoint files.
Input | config | LlamaConfig | HuggingFace LLaMA configuration (keyword-only argument).
Input | torch_dtype | torch.dtype | PyTorch dtype for model parameters. Required: it is read from **kwargs as kwargs["torch_dtype"] and passed to torch.set_default_dtype, so omitting it raises a KeyError.
Output | return value | LlamaForCausalLM | A pretrained LlamaForCausalLM instance with TE decoder layers and loaded weights.
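A call to from_pretrained_local might look like the sketch below. Several things here are assumptions rather than facts from this page: that te_llama.py is importable as the module te_llama, that the checkpoint directory also holds a config.json readable by AutoConfig, that bfloat16 is the desired dtype, and that a GPU environment with Transformer Engine is available. load_te_llama is a hypothetical wrapper, and the heavy imports sit inside it so that defining the function has no side effects.

```python
def load_te_llama(ckpt_dir):
    """Load a sharded local Llama checkpoint into a TE-backed model (sketch)."""
    import torch
    from transformers import AutoConfig

    # Assumption: te_llama.py from docs/examples/te_llama is on sys.path.
    from te_llama import TELlamaForCausalLM

    config = AutoConfig.from_pretrained(ckpt_dir)
    # config is keyword-only; torch_dtype is read from **kwargs and is required.
    model = TELlamaForCausalLM.from_pretrained_local(
        ckpt_dir,
        config=config,
        torch_dtype=torch.bfloat16,  # assumed dtype for illustration
    )
    return model
```

Because the method calls torch.set_default_dtype with the given dtype, the process-wide default dtype stays changed after the call returns, which can affect tensors created later in the same process.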

Construction Flow

The model construction proceeds through the following steps:

  1. TELlamaForCausalLM.__new__(cls, config) is called
  2. Inside __new__, the replace_decoder context manager patches LlamaDecoderLayer
  3. LlamaForCausalLM(config) is called, which invokes LlamaForCausalLM.__init__
  4. LlamaForCausalLM.__init__ creates a LlamaModel, which creates decoder layers
  5. Each decoder layer instantiation resolves to TELlamaDecoderLayer(config, layer_idx) due to the patch
  6. The context manager restores the original LlamaDecoderLayer class
  7. The LlamaForCausalLM instance (with TE layers) is returned
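The patch-and-restore flow above can be sketched with a small context manager. This is an illustration under assumptions, not the actual replace_decoder: OriginalLayer, PatchedLayer, modeling, replace_layer, and build_model are hypothetical stand-ins, with modeling playing the role of the module namespace in which the model looks up its decoder layer class.

```python
from contextlib import contextmanager
from types import SimpleNamespace


class OriginalLayer:
    """Hypothetical stand-in for LlamaDecoderLayer."""


class PatchedLayer:
    """Hypothetical stand-in for TELlamaDecoderLayer."""


# Stand-in for the module namespace in which the model looks up its layer class.
modeling = SimpleNamespace(DecoderLayer=OriginalLayer)


@contextmanager
def replace_layer(new_cls):
    original = modeling.DecoderLayer
    modeling.DecoderLayer = new_cls
    try:
        yield
    finally:
        # Restore the original class even if model construction raises.
        modeling.DecoderLayer = original


def build_model(num_layers):
    # The class is looked up at call time, so the patch takes effect here.
    return [modeling.DecoderLayer() for _ in range(num_layers)]


with replace_layer(PatchedLayer):
    layers = build_model(2)

print([type(layer).__name__ for layer in layers])  # ['PatchedLayer', 'PatchedLayer']
print(modeling.DecoderLayer.__name__)              # OriginalLayer
```

The patch only works because the layer class is resolved by name when the model is built; objects created inside the context keep their patched type after the original class is restored.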

Related

Environment:NVIDIA_TransformerEngine_CUDA_Toolkit_Requirements
