Implementation:NVIDIA TransformerEngine TELlamaForCausalLM
Overview
TE-accelerated LlamaForCausalLM model using monkey-patching for layer replacement.
Doc Type
Wrapper Doc -- This class wraps HuggingFace's LlamaForCausalLM using a __new__-based factory pattern to inject TE decoder layers.
Description
TELlamaForCausalLM uses Python's __new__ method to create a LlamaForCausalLM instance, then initializes it within a replace_decoder context so that all decoder layers become TELlamaDecoderLayer instances. The result is a standard HuggingFace LlamaForCausalLM object with full API compatibility but TE-powered internals.
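The factory behavior described above can be illustrated with a minimal, dependency-free sketch. The names Base and Factory are hypothetical stand-ins for LlamaForCausalLM and TELlamaForCausalLM; they are not part of TransformerEngine or Transformers, and the sketch only demonstrates the Python mechanism, not the real model construction.

```python
class Base:
    """Hypothetical stand-in for LlamaForCausalLM."""

    def __init__(self, config):
        self.config = config


class Factory:
    """Hypothetical stand-in for TELlamaForCausalLM."""

    def __new__(cls, config):
        # __new__ returns an object that is not an instance of cls, so
        # Python skips Factory.__init__ and the caller receives a Base.
        return Base(config)


model = Factory({"hidden_size": 8})
print(type(model).__name__)        # Base
print(isinstance(model, Factory))  # False
```

One consequence of this design is that isinstance checks against the wrapper class fail, so downstream code should test against the wrapped class instead.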
The class provides two construction paths:
- __new__(cls, config): Creates a new LlamaForCausalLM with randomly initialized TE decoder layers.
- from_pretrained_local(cls, path, config=config): Creates the TE model and loads pretrained HuggingFace checkpoint weights, mapping them to TE's parameter layout.
The from_pretrained_local method handles sharded checkpoints by:
- Detecting whether the checkpoint uses model.safetensors.index.json or pytorch_model.bin.index.json
- Resolving shard file paths via HuggingFace's get_checkpoint_shard_files
- Iterating over each shard, loading weights and mapping TE-specific parameters via replace_params
- Using load_state_dict(strict=False) for non-TE parameters (embeddings, final layer norm, lm_head)
- Explicitly freeing memory after each shard with del state_dict and gc.collect()
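The detection and shard-resolution steps can be sketched with the standard library alone. This is a simplified illustration, not the code path used by get_checkpoint_shard_files: find_shard_files is a hypothetical helper, and it relies on the assumption that the index file holds a weight_map object mapping each parameter name to the shard file that stores it.

```python
import json
import os
import tempfile

# Index file names checked in the same order as the list above.
INDEX_CANDIDATES = ("model.safetensors.index.json", "pytorch_model.bin.index.json")


def find_shard_files(checkpoint_dir):
    """Return the shard paths named by the first index file found (hypothetical helper)."""
    for index_name in INDEX_CANDIDATES:
        index_path = os.path.join(checkpoint_dir, index_name)
        if os.path.isfile(index_path):
            with open(index_path) as f:
                weight_map = json.load(f)["weight_map"]
            # Many parameters share one shard, so deduplicate and sort the names.
            return [os.path.join(checkpoint_dir, name) for name in sorted(set(weight_map.values()))]
    raise FileNotFoundError("no sharded checkpoint index found in " + checkpoint_dir)


# Demo with a made-up index describing three parameters stored in two shards.
with tempfile.TemporaryDirectory() as tmp:
    index = {
        "weight_map": {
            "a.weight": "model-00001-of-00002.safetensors",
            "b.weight": "model-00002-of-00002.safetensors",
            "c.weight": "model-00001-of-00002.safetensors",
        }
    }
    with open(os.path.join(tmp, INDEX_CANDIDATES[0]), "w") as f:
        json.dump(index, f)
    shard_names = [os.path.basename(p) for p in find_shard_files(tmp)]

print(shard_names)
```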
Source
- File: docs/examples/te_llama/te_llama.py
- Class: TELlamaForCausalLM
- Lines: L87-163
Signature
class TELlamaForCausalLM:
    """
    Causal LM created with `LlamaModel`. The underlying `LlamaDecoderLayer`
    class is monkey-patched with `TELlamaDecoderLayer` class before
    initializing the causal LM with `LlamaForCausalLM`.

    Args:
        config: LlamaConfig
    """

    def __new__(cls, config: LlamaConfig):
        with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):
            llama_for_causal_lm = LlamaForCausalLM(config)
        return llama_for_causal_lm

    @classmethod
    def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **kwargs):
        """
        Custom method adapted from `from_pretrained` method in HuggingFace
        Transformers repo.
        """
        torch.set_default_dtype(kwargs["torch_dtype"])
        vanilla_model = cls(config)
        subfolder = ""
        variant = None
        # Detect sharded checkpoint format (safetensors or PyTorch)
        if os.path.isfile(
            os.path.join(pretrained_model_name_or_path, subfolder,
                         _add_variant("model.safetensors.index.json", variant))
        ):
            archive_file = os.path.join(
                pretrained_model_name_or_path, subfolder,
                _add_variant("model.safetensors.index.json", variant))
            is_sharded = True
        elif os.path.isfile(
            os.path.join(pretrained_model_name_or_path, subfolder,
                         _add_variant(WEIGHTS_INDEX_NAME, variant))
        ):
            archive_file = os.path.join(
                pretrained_model_name_or_path, subfolder,
                _add_variant(WEIGHTS_INDEX_NAME, variant))
            is_sharded = True
        else:
            raise AssertionError("Only sharded PyTorch ckpt format supported at the moment")

        resolved_archive_file, _ = get_checkpoint_shard_files(
            pretrained_model_name_or_path, archive_file)
        if not is_sharded:
            resolved_archive_file = [resolved_archive_file]

        for shard_file in resolved_archive_file:
            state_dict = load_state_dict(shard_file)
            replace_params(state_dict, vanilla_model.state_dict(), config)
            vanilla_model.load_state_dict(state_dict, strict=False)
            del state_dict
            gc.collect()

        return vanilla_model
I/O
| Direction | Name | Type | Description |
|---|---|---|---|
| Input | config | LlamaConfig | HuggingFace LLaMA configuration object containing model hyperparameters. |
| Output | return value | LlamaForCausalLM | A HuggingFace LlamaForCausalLM instance with TE decoder layers. Note: the returned type is LlamaForCausalLM, not TELlamaForCausalLM. |
from_pretrained_local I/O
| Direction | Name | Type | Description |
|---|---|---|---|
| Input | pretrained_model_name_or_path | str | Path to a local directory containing the sharded HuggingFace checkpoint files. |
| Input | config | LlamaConfig | HuggingFace LLaMA configuration (keyword-only argument). |
| Input | torch_dtype | torch.dtype | Desired PyTorch dtype for model parameters (passed via **kwargs). |
| Output | return value | LlamaForCausalLM | A pretrained LlamaForCausalLM instance with TE decoder layers and loaded weights. |
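A call to from_pretrained_local might look like the following sketch. The wrapper function load_te_llama is hypothetical, and the sketch assumes that te_llama.py from docs/examples/te_llama is importable as the module te_llama and that torch, transformers, and TransformerEngine are installed. The imports sit inside the function so that defining it does not require those libraries.

```python
def load_te_llama(checkpoint_dir):
    """Load a local sharded LLaMA checkpoint into a TE-backed model (sketch)."""
    # Local imports: only needed when the function is actually called.
    import torch
    from transformers import AutoConfig

    # Assumption: te_llama.py from docs/examples/te_llama is on the import path.
    from te_llama import TELlamaForCausalLM

    # Read the LlamaConfig stored next to the checkpoint shards.
    config = AutoConfig.from_pretrained(checkpoint_dir)

    # config is keyword-only; torch_dtype is read as kwargs["torch_dtype"],
    # so leaving it out raises a KeyError.
    return TELlamaForCausalLM.from_pretrained_local(
        checkpoint_dir, config=config, torch_dtype=torch.bfloat16
    )
```

Because the method calls torch.set_default_dtype, the process-wide default dtype remains set to the requested value after the call returns.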
Construction Flow
The model construction proceeds through the following steps:
- TELlamaForCausalLM.__new__(cls, config) is called
- Inside __new__, the replace_decoder context manager patches LlamaDecoderLayer
- LlamaForCausalLM(config) is called, which invokes LlamaForCausalLM.__init__
- LlamaForCausalLM.__init__ creates a LlamaModel, which creates decoder layers
- Each decoder layer instantiation resolves to TELlamaDecoderLayer(config, layer_idx) due to the patch
- The context manager restores the original LlamaDecoderLayer class
- The LlamaForCausalLM instance (with TE layers) is returned
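The patch-and-restore portion of these steps can be sketched without any deep learning libraries. Everything below is a hypothetical stand-in: modeling plays the role of the module that defines LlamaDecoderLayer, replace_layer plays the role of replace_decoder, and build_layers plays the role of the LlamaModel constructor. The try/finally restore is an assumption about how such a context manager would typically be written.

```python
import contextlib
import types

# Hypothetical stand-in for the module that defines the decoder layer class.
modeling = types.SimpleNamespace()


class OriginalLayer:
    pass


class PatchedLayer:
    pass


modeling.DecoderLayer = OriginalLayer


@contextlib.contextmanager
def replace_layer(patched_cls):
    """Temporarily swap modeling.DecoderLayer and restore it on exit."""
    original = modeling.DecoderLayer
    modeling.DecoderLayer = patched_cls
    try:
        yield
    finally:
        # Restore even if model construction raises.
        modeling.DecoderLayer = original


def build_layers(n):
    # The class is looked up at call time, so an active patch takes effect here.
    return [modeling.DecoderLayer() for _ in range(n)]


with replace_layer(PatchedLayer):
    layers = build_layers(2)

print([type(layer).__name__ for layer in layers])  # ['PatchedLayer', 'PatchedLayer']
print(modeling.DecoderLayer.__name__)              # OriginalLayer
```

The layers built inside the context keep the patched class after the context exits, while any model built afterwards uses the original class again.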
Related
Environment:NVIDIA_TransformerEngine_CUDA_Toolkit_Requirements