# Implementation: Microsoft Onnxruntime CPU LSTM GradCompute
| Knowledge Sources | |
|---|---|
| Domains | Training, CPU_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
## Overview
Concrete tool for computing the LSTM (Long Short-Term Memory) gradient backpropagation on CPU in the ONNX Runtime training framework.
## Description
This file implements the LSTMGradImpl class template, which performs the backward pass for an LSTM cell. Following the standard backpropagation-through-time (BPTT) algorithm, the gradient computation iterates in reverse through the time steps, computing gradients for the four gates: input (i), output (o), forget (f), and cell candidate (c). The implementation produces gradients for the input weights (Wi, Wo, Wf, Wc), the recurrence weights (Ri, Ro, Rf, Rc), the biases (8 in total: Wb and Rb for each gate), the peephole weights (Pi, Po, Pf), the input sequence (X), the initial hidden state (H0), and the initial cell state (C0). The iofc buffer stores the gate activations as interleaved blocks per batch per time step. GEMM operations perform the matrix multiplications, and gradients are accumulated by elementwise summation.
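The per-gate gradient formulas in one BPTT step can be sketched elementwise. The following is a hedged illustration of the textbook LSTM backward math (sigmoid gates and tanh candidate, as in the ONNX LSTM definition) for a single unit; the names and scalar form are illustrative rather than the file's actual code, and the peephole terms are omitted for brevity:

```cpp
#include <cmath>

// One elementwise BPTT step for a single LSTM unit (textbook math, not
// the exact onnxruntime kernel). Inputs are forward-pass activations:
//   i, o, f  : sigmoid gate outputs;  ct_tilde : tanh cell candidate
//   c_prev   : previous cell state;   c        : current cell state
// dh, dc_in : incoming gradients w.r.t. H_t and C_t.
struct GateGrads {
  float da_i, da_o, da_f, da_c;  // grads w.r.t. each gate's pre-activation
  float dc_prev;                 // grad flowing back to C_{t-1}
};

GateGrads lstm_cell_backward(float i, float o, float f, float ct_tilde,
                             float c_prev, float c, float dh, float dc_in) {
  float tanh_c = std::tanh(c);
  // dC_t accumulates the hidden-state path through H_t = o * tanh(C_t)
  float dc = dc_in + dh * o * (1.0f - tanh_c * tanh_c);
  GateGrads g;
  g.da_o = dh * tanh_c * o * (1.0f - o);           // sigmoid'(a_o)
  g.da_f = dc * c_prev * f * (1.0f - f);           // sigmoid'(a_f)
  g.da_i = dc * ct_tilde * i * (1.0f - i);         // sigmoid'(a_i)
  g.da_c = dc * i * (1.0f - ct_tilde * ct_tilde);  // tanh'(a_c)
  g.dc_prev = dc * f;                              // from C_t = f*C_{t-1} + i*c~
  return g;
}
```

In the actual kernel these scalar operations run over whole [batch, 4*H] gate blocks per time step, with the pre-activation gradients then fed into GEMMs for the weight and input gradients.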
## Usage
This kernel is used during the backward pass of LSTM training. It is called by the LSTMGrad operator after the forward pass has produced the intermediate gate activations (iofc), all hidden states, and all cell states.
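The description notes that GEMMs accumulate the weight gradients across time steps (dW += da_t * x_t^T at each step). A minimal sketch of that accumulation, written as plain loops where a real kernel would issue a GEMM with beta = 1; the names are illustrative, not the header's API:

```cpp
#include <cstddef>
#include <vector>

// Accumulate one time step's contribution to the weight gradient:
// dW += da_t * x_t^T, with dW stored row-major as [4*H, I].
// A production kernel would express this as a rank-1 update / GEMM
// with accumulation (beta = 1) instead of explicit loops.
void accumulate_weight_grad(const std::vector<float>& da_t,  // [4*H]
                            const std::vector<float>& x_t,   // [I]
                            std::vector<float>& dW) {        // [4*H, I]
  const std::size_t rows = da_t.size();
  const std::size_t cols = x_t.size();
  for (std::size_t r = 0; r < rows; ++r)
    for (std::size_t c = 0; c < cols; ++c)
      dW[r * cols + c] += da_t[r] * x_t[c];
}
```

Summing these rank-1 updates over all time steps yields the full grad_weights output; the recurrence-weight gradient is accumulated the same way with H_{t-1} in place of x_t.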
## Code Reference
### Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cpu/rnn/lstm_grad_compute.cc
- Lines: 1-507
Signature
template <typename T>
LSTMGradImpl<T>::LSTMGradImpl(int sequence_length, int batch_size,
int hidden_size, int input_size,
concurrency::ThreadPool* thread_pool,
AllocatorPtr allocator);
template <typename T>
void LSTMGradImpl<T>::ComputeGradient(const LSTMGradInputs<T>& inputs,
LSTMGradOutputs<T>& outputs);
### Import
```cpp
#include "orttraining/orttraining/training_ops/cpu/rnn/lstm_grad_compute.h"
```
## I/O Contract
### Inputs (LSTMGradInputs)
| Name | Type | Required | Description |
|---|---|---|---|
| weights | span<const T> | Yes | Input weights W [4*H, I] |
| recurrence_weights | span<const T> | Yes | Recurrence weights R [4*H, H] |
| input | span<const T> | Yes | Input sequence |
| iofc | span<const T> | Yes | Gate activations from forward [seq * dir * batch * 4*H] |
| all_hidden_states | span<const T> | Yes | All hidden states from forward |
| all_cell_states | span<const T> | Yes | All cell states from forward |
| initial_hidden_state | span<const T> | Yes | Initial hidden state |
| initial_cell_state | span<const T> | Yes | Initial cell state |
| grad_all_hidden_states | span<const T> | No | Gradient w.r.t. all hidden states |
| grad_final_hidden_state | span<const T> | No | Gradient w.r.t. final hidden state |
| grad_final_cell_state | span<const T> | No | Gradient w.r.t. final cell state |
### Outputs (LSTMGradOutputs)
| Name | Type | Description |
|---|---|---|
| grad_input | span<T> | Gradient w.r.t. input X |
| grad_weights | span<T> | Gradient w.r.t. weights W |
| grad_recurrence_weights | span<T> | Gradient w.r.t. recurrence weights R |
| grad_bias | span<T> | Gradient w.r.t. bias (8 * hidden_size) |
| grad_initial_hidden_state | span<T> | Gradient w.r.t. initial hidden state H0 |
| grad_initial_cell_state | span<T> | Gradient w.r.t. initial cell state C0 |
| grad_peephole_weights | span<T> | Gradient w.r.t. peephole weights (3 * hidden_size) |
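The flat shapes in the tables above imply straightforward offset arithmetic for the iofc buffer. Below is a hedged sketch, assuming the interleaved per-(time step, direction, batch) block layout described in the Description; the exact gate ordering and layout in the file may differ, so treat the names and the i/o/f/c index mapping as illustrative:

```cpp
#include <cstddef>

// Illustrative dimension bundle (not the header's API).
struct LstmDims {
  std::size_t seq_len, num_dirs, batch, hidden, input;
};

// Total element count of the iofc buffer: [seq * dir * batch * 4*H].
std::size_t iofc_size(const LstmDims& d) {
  return d.seq_len * d.num_dirs * d.batch * 4 * d.hidden;
}

// Offset of the start of gate g (0=i, 1=o, 2=f, 3=c) for a given
// (t, dir, b) entry, assuming each entry stores its four gate vectors
// contiguously as one [4*H] block.
std::size_t iofc_offset(const LstmDims& d, std::size_t t, std::size_t dir,
                        std::size_t b, std::size_t g) {
  return (((t * d.num_dirs + dir) * d.batch + b) * 4 + g) * d.hidden;
}
```

The same arithmetic extends to the other flat buffers, e.g. grad_bias is 8 * hidden_size elements (Wb then Rb per gate) and grad_peephole_weights is 3 * hidden_size.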
## Usage Examples
```cpp
// Creating and using LSTMGradImpl
lstm::LSTMGradImpl<float> lstm_cell(sequence_length, batch_size,
                                    hidden_size, input_size,
                                    thread_pool, allocator);
lstm_cell.ComputeGradient(lstmgrad_inputs, lstmgrad_outputs);
```