# Heuristic: DeepSeek-AI Janus Bfloat16 Operation Workarounds
| Knowledge Sources | |
|---|---|
| Domains | Debugging, Deep_Learning, Optimization |
| Last Updated | 2026-02-10 09:30 GMT |
## Overview
Cast tensors to float32 before calling PyTorch operations that do not support bfloat16, such as `upsample_nearest2d` (invoked by `F.interpolate(mode="nearest")`) and `trunc_normal_` weight initialization.
## Description
Several PyTorch operations do not support bfloat16 tensors and will either raise errors or produce incorrect results. The Janus codebase contains multiple workarounds where tensors are temporarily cast to float32 before these operations, then cast back to bfloat16 afterward. These workarounds are documented with TODO comments indicating they should be removed once PyTorch fixes the underlying issues. Additionally, large batch sizes (>= 64) require a contiguous memory layout for NHWC upsampling.
## Usage
Apply this heuristic when you encounter errors like `RuntimeError` in upsample or interpolate operations while using bfloat16 tensors, or when custom model layers produce incorrect initialization values. This is relevant for the UViT decoder in JanusFlow and the VQ-VAE upsampling in Janus.
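A quick way to check whether your PyTorch build is affected is to probe the operation directly (a minimal sketch; `needs_bf16_upsample_workaround` is a hypothetical helper name, and newer PyTorch releases may already support bfloat16 here, in which case no workaround is needed):

```python
import torch
import torch.nn.functional as F

def needs_bf16_upsample_workaround() -> bool:
    """Return True if nearest-neighbour upsampling rejects bfloat16 tensors."""
    x = torch.zeros(1, 1, 2, 2, dtype=torch.bfloat16)
    try:
        F.interpolate(x, scale_factor=2.0, mode="nearest")
        return False  # op succeeded; this build handles bfloat16 directly
    except RuntimeError:
        return True   # op raised; apply the float32 cast workaround
```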
## The Insight (Rule of Thumb)
- Action 1: Before calling `F.interpolate()` with `mode="nearest"`, check if the tensor dtype is bfloat16. If so, cast to float32, perform the interpolation, then cast back.
- Action 2: Before calling `trunc_normal_()` for weight initialization, cast the tensor to float32, apply the initialization, then cast back to the original dtype.
- Action 3: For batch sizes >= 64, call `.contiguous()` on the tensor before upsampling to avoid NHWC layout failures.
- Trade-off: Temporary float32 casting uses 2x memory for the affected tensor during the operation, but the tensors are typically small feature maps, not the full model.
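Actions 1 and 3 can be folded into a single helper. This is a sketch of the pattern, not code from the Janus repository; `safe_nearest_upsample` is a hypothetical name:

```python
import torch
import torch.nn.functional as F

def safe_nearest_upsample(x: torch.Tensor, scale_factor: float = 2.0) -> torch.Tensor:
    """Nearest-neighbour upsample with the bfloat16 and NHWC workarounds applied."""
    orig_dtype = x.dtype
    # Action 1: upsample_nearest2d may not support bfloat16 -> compute in float32.
    if orig_dtype == torch.bfloat16:
        x = x.to(torch.float32)
    # Action 3: NHWC upsample can fail for batch sizes >= 64 -> force contiguous layout.
    if x.shape[0] >= 64:
        x = x.contiguous()
    x = F.interpolate(x, scale_factor=scale_factor, mode="nearest")
    # Cast back so callers see the dtype they passed in.
    return x.to(orig_dtype)
```

The cast-back at the end keeps the workaround transparent to callers: only the intermediate feature map is held in float32, matching the trade-off noted above.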
## Reasoning
These are known PyTorch bugs and limitations:
- `upsample_nearest2d`: PyTorch issue #86679 documents that the `upsample_nearest2d_out_frame` kernel does not support bfloat16. The Janus codebase includes this workaround with a TODO to remove it once PyTorch fixes the issue.
- Large-batch NHWC upsample: HuggingFace Diffusers issue #984 documents that `upsample_nearest_nhwc` fails with large batch sizes. The workaround is to make the tensor contiguous before interpolation.
- `trunc_normal_`: The timm library's weight initialization function does not handle bfloat16 tensors. The Janus codebase reimplements this function with a float32 workaround.
- VQ-VAE interpolation: The VQ model's Upsample layer in `janus/models/vq_model.py` performs a similar cast to float32 for `F.interpolate`, but hardcodes the cast-back dtype as bfloat16 (which could be an issue if the model runs in float32).
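The `trunc_normal_` pattern amounts to a thin wrapper around the float32 initializer. A sketch under the assumption that `torch.nn.init.trunc_normal_` is the underlying initializer; `trunc_normal_bf16_` is a hypothetical name, not the Janus implementation:

```python
import torch

def trunc_normal_bf16_(tensor: torch.Tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Truncated-normal init that tolerates bfloat16 (and other low-precision) tensors.

    Runs the initializer in float32, then copies the result back into the
    tensor in its original dtype, mirroring the workaround described above.
    """
    with torch.no_grad():
        buf = torch.empty_like(tensor, dtype=torch.float32)
        torch.nn.init.trunc_normal_(buf, mean=mean, std=std, a=a, b=b)
        tensor.copy_(buf.to(tensor.dtype))
    return tensor
```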
## Code Evidence
UViT upsample workaround from `janus/janusflow/models/uvit.py:336-361`:
```python
# Cast to float32 as 'upsample_nearest2d_out_frame' op does not support bfloat16
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
# https://github.com/pytorch/pytorch/issues/86679
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
    hidden_states = hidden_states.to(torch.float32)

# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
    hidden_states = hidden_states.contiguous()

# ... interpolation ...

# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
    hidden_states = hidden_states.to(dtype)
```
trunc_normal_ workaround from `janus/janusflow/models/siglip_vit.py:92-95`:
```python
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first
    convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its original dtype.
    """
```
VQ-VAE upsample workaround from `janus/models/vq_model.py:417-423`:
```python
def forward(self, x):
    if x.dtype != torch.float32:
        x = F.interpolate(x.to(torch.float), scale_factor=2.0, mode="nearest").to(
            torch.bfloat16
        )
    else:
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
```
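Note that this snippet hardcodes the cast-back dtype as bfloat16, so a float16 input would come back as bfloat16. A dtype-preserving variant would cast back to the input's own dtype instead; a sketch (standalone function for illustration, not the Janus implementation):

```python
import torch
import torch.nn.functional as F

def upsample_forward(x: torch.Tensor) -> torch.Tensor:
    """Dtype-preserving variant of the VQ-VAE Upsample forward pass."""
    orig_dtype = x.dtype
    # Interpolate in float32 to avoid the bfloat16 kernel limitation.
    if orig_dtype != torch.float32:
        x = x.to(torch.float32)
    x = F.interpolate(x, scale_factor=2.0, mode="nearest")
    # Cast back to whatever dtype the caller supplied, not a hardcoded one.
    return x.to(orig_dtype)
```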