Implementation:Microsoft Onnxruntime CPU LSTM Grad
| Knowledge Sources | |
|---|---|
| Domains | Training, CPU_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
CPU kernel that computes LSTM gradients during the backward pass in the ONNX Runtime training framework.
Description
This file implements the LSTMGrad kernel, which orchestrates the backward pass for LSTM training in three steps: it parses the gradient inputs with lstm::LSTMGradInputs, allocates the gradient outputs with lstm::LSTMGradOutputs, and delegates the gradient computation itself to lstm::LSTMGradImpl. The kernel is registered under kMSDomain at opset version 1 and supports only the float type.
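The parse, allocate, delegate structure described here can be sketched in isolation. The block below is a minimal illustration under stated assumptions: SketchGradInputs, SketchGradOutputs, SketchGradImpl, and SketchCompute are hypothetical stand-ins, not the real lstm:: classes, and the arithmetic is a placeholder rather than an LSTM gradient.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-ins for lstm::LSTMGradInputs, lstm::LSTMGradOutputs and
// lstm::LSTMGradImpl. Names and members are illustrative only and do not
// mirror the real ONNX Runtime classes.
struct SketchGradInputs {
  std::vector<float> weights;      // stands in for W
  std::vector<float> grad_hidden;  // stands in for an incoming gradient
};

struct SketchGradOutputs {
  std::vector<float> grad_weights;  // stands in for dW
  // Step 2: size every output buffer from the parsed inputs, zero-initialized.
  explicit SketchGradOutputs(const SketchGradInputs& in)
      : grad_weights(in.weights.size(), 0.0f) {}
};

struct SketchGradImpl {
  // Step 3: the numerical work lives here, not in the kernel entry point.
  // Placeholder math: add the summed incoming gradient to every output slot.
  void ComputeGradient(const SketchGradInputs& in, SketchGradOutputs& out) const {
    float total = 0.0f;
    for (float g : in.grad_hidden) total += g;
    for (float& dw : out.grad_weights) dw += total;
  }
};

// Mirrors the three-step shape of a Compute() method: parse, allocate, delegate.
SketchGradOutputs SketchCompute(std::vector<float> weights,
                                std::vector<float> grad_hidden) {
  assert(!weights.empty());
  SketchGradInputs inputs{std::move(weights), std::move(grad_hidden)};  // step 1
  SketchGradOutputs outputs(inputs);                                    // step 2
  SketchGradImpl impl;
  impl.ComputeGradient(inputs, outputs);                                // step 3
  return outputs;
}
```

Keeping the entry point this thin means the numerical code can be exercised without a kernel context, which is one plausible reason for the split.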
Usage
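This kernel is invoked during the backward pass of LSTM training. It consumes the gate activations (iofc), hidden states, and cell states produced by the LSTMTraining forward kernel, together with the incoming gradients, and computes gradients for the input, the parameters (weights, recurrence weights, bias, peepholes), and the initial states.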
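To show why the saved gate activations and cell states are needed, here is a sketch of the standard LSTM backward equations for a single hidden unit at a single time step, without peepholes. It is not taken from lstm::LSTMGradImpl; the type and function names (GateActivations, StepGradients, CellForward, StepBackward) are hypothetical, and it assumes sigmoid for the i, o, f gates and tanh for the candidate and the cell output.

```cpp
#include <cassert>
#include <cmath>

// Post-activation gate values for one hidden unit at one time step, in the
// i, o, f, c order that the name "iofc" suggests.
struct GateActivations {
  double i;  // input gate, sigmoid output
  double o;  // output gate, sigmoid output
  double f;  // forget gate, sigmoid output
  double c;  // candidate cell value, tanh output
};

// Gradients w.r.t. the gate pre-activations and the previous cell state.
struct StepGradients {
  double d_i_pre;
  double d_o_pre;
  double d_f_pre;
  double d_c_pre;
  double d_c_prev;
};

// Forward cell update, for reference: C_t = f * C_prev + i * c.
// The hidden state is then H_t = o * tanh(C_t).
inline double CellForward(const GateActivations& g, double c_prev) {
  return g.f * c_prev + g.i * g.c;
}

// Backward step for the same unit.
// d_h: gradient arriving at H_t; d_c_next: gradient arriving at C_t from step t+1.
inline StepGradients StepBackward(const GateActivations& g, double c_prev,
                                  double d_h, double d_c_next) {
  const double c_t = CellForward(g, c_prev);
  const double tanh_c = std::tanh(c_t);
  // Total gradient reaching C_t: the direct path plus the path through H_t.
  const double d_c = d_c_next + d_h * g.o * (1.0 - tanh_c * tanh_c);

  StepGradients out;
  out.d_o_pre = d_h * tanh_c * g.o * (1.0 - g.o);  // sigmoid'(x) = s * (1 - s)
  out.d_i_pre = d_c * g.c * g.i * (1.0 - g.i);
  out.d_f_pre = d_c * c_prev * g.f * (1.0 - g.f);
  out.d_c_pre = d_c * g.i * (1.0 - g.c * g.c);     // tanh'(x) = 1 - t^2
  out.d_c_prev = d_c * g.f;                        // flows on to step t-1
  return out;
}
```

Every term on the right-hand side is either a stored activation, a stored cell state, or an incoming gradient, so the backward pass never has to rerun the forward matrix multiplications.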
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cpu/rnn/lstm_grad.cc
- Lines: 1-45
Signature
template <typename T>
Status LSTMGrad<T>::Compute(OpKernelContext* context) const;
Import
#include "orttraining/training_ops/cpu/rnn/lstm_grad.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| X | Tensor(float) | Yes | Input sequence [seq_length, batch_size, input_size] |
| W | Tensor(float) | Yes | Weights [directions, 4*H, input_size] |
| R | Tensor(float) | Yes | Recurrence weights [directions, 4*H, H] |
| SL | Tensor(int) | No | Sequence lengths |
| H0 | Tensor(float) | No | Initial hidden state |
| C0 | Tensor(float) | No | Initial cell state |
| HAll | Tensor(float) | Yes | All hidden states from forward |
| CAll | Tensor(float) | Yes | All cell states from forward |
| IOFC | Tensor(float) | Yes | Gate activations from forward |
| grad_HAll | Tensor(float) | No | Gradient w.r.t. all hidden states |
| grad_Ht | Tensor(float) | No | Gradient w.r.t. final hidden state |
| grad_Ct | Tensor(float) | No | Gradient w.r.t. final cell state |
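The W and R shapes above stack four gate blocks along the 4*H axis. The sketch below shows how a flat, row-major buffer with those shapes could be indexed. It assumes the gate blocks follow the i, o, f, c order of the ONNX LSTM operator; LstmDims, Gate, WeightOffset, and RecurrenceOffset are hypothetical helpers, not ONNX Runtime APIs.

```cpp
#include <cassert>
#include <cstddef>

// Gate blocks along the 4*H axis, in the ONNX LSTM "iofc" order (assumed here).
enum class Gate : std::size_t { kInput = 0, kOutput = 1, kForget = 2, kCell = 3 };

// Dimensions taken from the shapes in the Inputs table.
struct LstmDims {
  std::size_t directions;
  std::size_t hidden_size;  // H
  std::size_t input_size;
};

// Row-major offset into W, shaped [directions, 4*H, input_size].
inline std::size_t WeightOffset(const LstmDims& d, std::size_t direction, Gate gate,
                                std::size_t hidden_unit, std::size_t input_index) {
  assert(direction < d.directions);
  assert(hidden_unit < d.hidden_size);
  assert(input_index < d.input_size);
  const std::size_t row = static_cast<std::size_t>(gate) * d.hidden_size + hidden_unit;
  return (direction * 4 * d.hidden_size + row) * d.input_size + input_index;
}

// Row-major offset into R, shaped [directions, 4*H, H].
inline std::size_t RecurrenceOffset(const LstmDims& d, std::size_t direction, Gate gate,
                                    std::size_t hidden_unit, std::size_t prev_unit) {
  assert(direction < d.directions);
  assert(hidden_unit < d.hidden_size);
  assert(prev_unit < d.hidden_size);
  const std::size_t row = static_cast<std::size_t>(gate) * d.hidden_size + hidden_unit;
  return (direction * 4 * d.hidden_size + row) * d.hidden_size + prev_unit;
}
```

For example, with directions = 2, H = 3, and input_size = 5, the output-gate block of the first direction starts at offset 15 in W, and the second direction starts at offset 60.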
Outputs
| Name | Type | Description |
|---|---|---|
| dX | Tensor(float) | Gradient w.r.t. input X |
| dW | Tensor(float) | Gradient w.r.t. weights W |
| dR | Tensor(float) | Gradient w.r.t. recurrence weights R |
| dB | Tensor(float) | Gradient w.r.t. bias |
| dH0 | Tensor(float) | Gradient w.r.t. initial hidden state |
| dC0 | Tensor(float) | Gradient w.r.t. initial cell state |
| dP | Tensor(float) | Gradient w.r.t. peephole weights |
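Each gradient has the same shape as the tensor it differentiates, so the output shapes follow from the input shapes. The sketch below derives them from X and W. The shapes of dX, dW, and dR come from the Inputs table; the shapes of dB, dH0, dC0, and dP are assumptions based on the ONNX LSTM operator's layouts for B, initial_h, initial_c, and P, which this page does not state. Shape, GradShapes, and InferGradShapes are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<std::int64_t>;

// One shape per row of the Outputs table.
struct GradShapes {
  Shape dX, dW, dR, dB, dH0, dC0, dP;
};

// x: [seq_length, batch_size, input_size], w: [directions, 4*H, input_size].
inline GradShapes InferGradShapes(const Shape& x, const Shape& w) {
  assert(x.size() == 3 && w.size() == 3);
  assert(w[1] % 4 == 0);  // four gate blocks stacked along axis 1
  assert(x[2] == w[2]);   // input_size must agree between X and W
  const std::int64_t batch = x[1];
  const std::int64_t dirs = w[0];
  const std::int64_t hidden = w[1] / 4;

  GradShapes s;
  s.dX = x;                           // same shape as X
  s.dW = w;                           // same shape as W
  s.dR = {dirs, 4 * hidden, hidden};  // same shape as R
  s.dB = {dirs, 8 * hidden};          // assumed: ONNX bias layout (input + recurrence)
  s.dH0 = {dirs, batch, hidden};      // assumed: ONNX initial_h layout
  s.dC0 = {dirs, batch, hidden};      // assumed: ONNX initial_c layout
  s.dP = {dirs, 3 * hidden};          // assumed: ONNX peephole layout
  return s;
}
```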
Usage Examples
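// Registers LSTMGrad<float> as the CPU implementation of the LSTMGrad op
// (kMSDomain, opset version 1), with type constraint T restricted to float tensors.
ONNX_OPERATOR_TYPED_KERNEL_EX(
    LSTMGrad, kMSDomain, 1, float, kCpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
    LSTMGrad<float>);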