Implementation:NVIDIA TransformerEngine Replace Decoder Context Manager
Overview
Context manager that monkey-patches the HuggingFace decoder layer class for TransformerEngine substitution.
Doc Type
Pattern Doc -- This function implements the monkey-patching pattern for transparent layer class substitution at model construction time.
Description
replace_decoder(te_decoder_cls) is a context manager (decorated with @contextmanager) that temporarily replaces the LlamaDecoderLayer class reference in the transformers.models.llama.modeling_llama module namespace with the provided TE decoder class.
The function follows this sequence:
- Save the original LlamaDecoderLayer class reference.
- Replace the module-level reference with the provided te_decoder_cls.
- Yield control to the caller (within the with block).
- Restore the original class reference in the finally block, ensuring restoration even if an exception occurs.
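Because the restore runs in a finally block, the module attribute is reset on both normal and exceptional exit, so the substitution is scoped to the with block. While the block is active the patch is a process-wide module attribute, so any other code that constructs Llama decoder layers during that time (for example, in another thread) also sees te_decoder_cls.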
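The save, replace, yield, restore sequence can be sketched without transformers or TransformerEngine installed. The snippet below is a minimal illustration, not the code from te_llama.py: fake_modeling, OriginalLayer, ReplacementLayer, replace_attr, and demo are all hypothetical stand-ins introduced here.

```python
from contextlib import contextmanager
from types import ModuleType

# Hypothetical stand-in for transformers.models.llama.modeling_llama.
fake_modeling = ModuleType("fake_modeling")


class OriginalLayer:
    """Stand-in for LlamaDecoderLayer."""


class ReplacementLayer:
    """Stand-in for a TE decoder layer class."""


fake_modeling.DecoderLayer = OriginalLayer


@contextmanager
def replace_attr(module, name, replacement):
    original = getattr(module, name)      # 1. save the original reference
    setattr(module, name, replacement)    # 2. swap in the replacement
    try:
        yield                             # 3. run the body of the with block
    finally:
        setattr(module, name, original)   # 4. restore, even after an exception


def demo():
    seen = []
    with replace_attr(fake_modeling, "DecoderLayer", ReplacementLayer):
        seen.append(fake_modeling.DecoderLayer)   # inside the block
    seen.append(fake_modeling.DecoderLayer)       # after a normal exit
    try:
        with replace_attr(fake_modeling, "DecoderLayer", ReplacementLayer):
            raise RuntimeError("simulated failure during model construction")
    except RuntimeError:
        pass
    seen.append(fake_modeling.DecoderLayer)       # after an exceptional exit
    return seen
```

Calling demo() returns the class visible at each of the three checkpoints, which makes it easy to confirm that the original reference is back in place after both kinds of exit.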
Source
- File: docs/examples/te_llama/te_llama.py
- Function: replace_decoder
- Lines: L27-37
Signature
@contextmanager
def replace_decoder(te_decoder_cls):
    """
    Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`.
    """
    original_llama_decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer
    transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls
    try:
        yield
    finally:
        transformers.models.llama.modeling_llama.LlamaDecoderLayer = original_llama_decoder_cls
I/O
| Direction | Name | Type | Description |
|---|---|---|---|
| Input | te_decoder_cls | type (class) | The TE decoder layer class to substitute for LlamaDecoderLayer. Must accept the same constructor signature as LlamaDecoderLayer (i.e., (config, *args, **kwargs)). |
| Output | context manager | generator-based context manager | Yields nothing. The substitution is active for the duration of the with block. |
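The constructor requirement matters because the code that builds the model calls the substituted class with whatever arguments it would have passed to the original. The following sketch illustrates this with hypothetical stand-ins (OriginalLayer, ReplacementLayer, build_layers, and the SimpleNamespace config are not part of transformers or TransformerEngine); it assumes the caller passes a config plus one extra positional argument such as a layer index.

```python
from types import SimpleNamespace


class OriginalLayer:
    """Stand-in for the original layer with a fixed signature."""

    def __init__(self, config, layer_idx):
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx


class ReplacementLayer:
    """Stand-in replacement that tolerates any extra arguments."""

    def __init__(self, config, *args, **kwargs):
        self.hidden_size = config.hidden_size
        # Extra arguments are accepted so the caller does not need to change.
        self.extra_args = args
        self.extra_kwargs = kwargs


def build_layers(layer_cls, config):
    # Mimics a model constructor that is unaware of which class it instantiates.
    return [layer_cls(config, idx) for idx in range(config.num_hidden_layers)]


config = SimpleNamespace(hidden_size=8, num_hidden_layers=3)
original_layers = build_layers(OriginalLayer, config)
replacement_layers = build_layers(ReplacementLayer, config)
```

A replacement whose constructor accepted only config would raise a TypeError in build_layers, which is why the (config, *args, **kwargs) shape is a convenient way to stay compatible.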
Behavior
The context manager modifies and restores a single module-level attribute:
| Phase | Action | Module State |
|---|---|---|
| Entry | Saves original class, assigns te_decoder_cls | transformers.models.llama.modeling_llama.LlamaDecoderLayer points to te_decoder_cls |
| During | Caller constructs model | Any LlamaDecoderLayer() calls instantiate te_decoder_cls |
| Exit (normal) | Restores original class | transformers.models.llama.modeling_llama.LlamaDecoderLayer points to original |
| Exit (exception) | Restores original class via finally | transformers.models.llama.modeling_llama.LlamaDecoderLayer points to original |
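The "During" row works because Python resolves a module-level name each time the referencing function runs, not when that function is defined. The sketch below illustrates this with a hypothetical stand-in module (fake_modeling, build_model, ReplacementLayer, and replace_decoder_layer are assumptions made for the example, not names from transformers or te_llama.py). It also shows two consequences: a reference bound before the patch is not affected, and objects created inside the block keep their class after the block exits.

```python
from contextlib import contextmanager
from types import ModuleType

# Hypothetical stand-in module whose build_model() looks up DecoderLayer
# in its own module globals every time it is called.
fake_modeling = ModuleType("fake_modeling")
exec(
    """
class DecoderLayer:
    pass

def build_model(num_layers):
    return [DecoderLayer() for _ in range(num_layers)]
""",
    fake_modeling.__dict__,
)


class ReplacementLayer:
    """Stand-in for a TE decoder layer class."""


# Early-bound reference, comparable to `from module import DecoderLayer`.
OriginalLayer = fake_modeling.DecoderLayer


@contextmanager
def replace_decoder_layer(replacement):
    original = fake_modeling.DecoderLayer
    fake_modeling.DecoderLayer = replacement
    try:
        yield
    finally:
        fake_modeling.DecoderLayer = original


with replace_decoder_layer(ReplacementLayer):
    patched_model = fake_modeling.build_model(2)  # resolves to ReplacementLayer
    early_bound = OriginalLayer()                 # bypasses the patch

restored_model = fake_modeling.build_model(2)     # resolves to the original again
```

The early_bound case is worth keeping in mind: code that holds its own reference to the original class, obtained before the patch was applied, continues to instantiate the original regardless of the context manager.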
Example Usage
from te_llama import TELlamaDecoderLayer, replace_decoder
from transformers import LlamaForCausalLM, LlamaConfig
config = LlamaConfig(...)
# Within the context, LlamaForCausalLM will use TELlamaDecoderLayer
with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):
    model = LlamaForCausalLM(config)
# After the context, the original LlamaDecoderLayer is restored
# but the model retains its TELlamaDecoderLayer instances