Implementation:Microsoft Onnxruntime CPU LSTM GradCompute

From Leeroopedia


Knowledge Sources
Domains: Training, CPU_Kernels
Last Updated: 2026-02-10 04:00 GMT

Overview

CPU kernel that computes the LSTM (Long Short-Term Memory) backward pass (gradient backpropagation) in the ONNX Runtime training framework.

Description

This file implements the LSTMGradImpl class template, which performs the backward pass for an LSTM cell. It follows the standard LSTM backpropagation through time (BPTT) algorithm, iterating over the time steps in reverse order and computing, at each step, the gradients for the four gates: input (i), output (o), forget (f), and cell candidate (c).

From these gate gradients it computes the gradients for the input weights (Wi, Wo, Wf, Wc), the recurrence weights (Ri, Ro, Rf, Rc), the biases (8 in total: Wb and Rb for each gate), the peephole weights (Pi, Po, Pf), the input sequence (X), the initial hidden state (H0), and the initial cell state (C0).

The iofc buffer holds the gate activations as one block per time step and batch entry: 4*H contiguous values, made up of an H-wide slice for each gate in i, o, f, c order. Matrix multiplications are done with GEMM operations, and gradients are accumulated by elementwise summation.
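To make the per-step gate arithmetic concrete, here is a minimal sketch of one BPTT step for a single batch row. It assumes the standard ONNX LSTM equations with the default activations (sigmoid for i, o, f; tanh for the cell candidate and for the cell output), gate activations stored in i, o, f, c order, and peephole weights stored as Pi, Po, Pf. StepGrads and lstm_step_backward are hypothetical names used for illustration, not ONNX Runtime APIs, and the weight and bias GEMMs are left out.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical result type: gradients produced by one backward LSTM step.
struct StepGrads {
  std::vector<float> d_gates;   // dL/d(gate pre-activation), iofc order, 4*H values
  std::vector<float> d_c_prev;  // dL/dC_{t-1}, H values
};

// Sketch of one BPTT step for a single batch row (not the ONNX Runtime code).
// Assumes default ONNX activations: i, o, f = sigmoid; c and the output = tanh.
// iofc: post-activation gates [i | o | f | c]; peephole: [Pi | Po | Pf].
// d_h: total dL/dH_t; d_c_next: dL/dC_t flowing back from step t+1.
// d_peephole (3*H values) is accumulated across calls.
StepGrads lstm_step_backward(const std::vector<float>& iofc,
                             const std::vector<float>& c_prev,
                             const std::vector<float>& c_t,
                             const std::vector<float>& peephole,
                             const std::vector<float>& d_h,
                             const std::vector<float>& d_c_next,
                             std::vector<float>& d_peephole) {
  const std::size_t H = c_t.size();
  assert(iofc.size() == 4 * H && peephole.size() == 3 * H && d_peephole.size() == 3 * H);
  assert(c_prev.size() == H && d_h.size() == H && d_c_next.size() == H);
  StepGrads g{std::vector<float>(4 * H), std::vector<float>(H)};
  for (std::size_t k = 0; k < H; ++k) {
    const float i = iofc[k], o = iofc[H + k], f = iofc[2 * H + k], c = iofc[3 * H + k];
    const float tanh_ct = std::tanh(c_t[k]);
    // H_t = o * tanh(C_t): gradient of the output gate's pre-activation.
    const float da_o = d_h[k] * tanh_ct * o * (1.f - o);
    // Total dL/dC_t: from step t+1, through H_t, and through the Po peephole.
    const float d_ct = d_c_next[k] + d_h[k] * o * (1.f - tanh_ct * tanh_ct) +
                       da_o * peephole[H + k];
    // C_t = f * C_{t-1} + i * c: gradients of the remaining pre-activations.
    const float da_i = d_ct * c * i * (1.f - i);
    const float da_f = d_ct * c_prev[k] * f * (1.f - f);
    const float da_c = d_ct * i * (1.f - c * c);
    g.d_gates[k] = da_i;
    g.d_gates[H + k] = da_o;
    g.d_gates[2 * H + k] = da_f;
    g.d_gates[3 * H + k] = da_c;
    // C_{t-1} reaches C_t directly and the i/f gates via the Pi/Pf peepholes.
    g.d_c_prev[k] = d_ct * f + da_i * peephole[k] + da_f * peephole[2 * H + k];
    d_peephole[k] += da_i * c_prev[k];          // dPi
    d_peephole[H + k] += da_o * c_t[k];         // dPo
    d_peephole[2 * H + k] += da_f * c_prev[k];  // dPf
  }
  return g;
}
```

The gate gradients in d_gates are what the GEMM stage would then multiply against X, H_{t-1}, W, and R; d_c_prev becomes the incoming cell-state gradient for step t-1.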

Usage

This kernel is used during the backward pass of LSTM training. It is called by the LSTMGrad operator after the forward pass has produced the intermediate gate activations (iofc), all hidden states, and all cell states.
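The iofc buffer produced by the forward pass is addressed by time step, direction, batch entry, gate, and hidden unit. Assuming a row-major layout with dimensions seq, dir, batch, 4*H and the gates ordered i, o, f, c inside each 4*H block, a flat offset can be computed as sketched here; Gate and iofc_offset are hypothetical helpers, not part of ONNX Runtime.

```cpp
#include <cassert>
#include <cstddef>

// Gate position inside each 4*H block (assumed i, o, f, c order).
enum class Gate : std::size_t { I = 0, O = 1, F = 2, C = 3 };

// Hypothetical helper: flat offset of one gate activation in an iofc buffer
// laid out row-major as [seq, dir, batch, 4*H].
inline std::size_t iofc_offset(std::size_t t, std::size_t d, std::size_t b, Gate gate,
                               std::size_t k, std::size_t num_directions,
                               std::size_t batch_size, std::size_t hidden_size) {
  assert(d < num_directions && b < batch_size && k < hidden_size);
  // Index of the (t, d, b) row, each row holding 4*H gate values.
  const std::size_t row = (t * num_directions + d) * batch_size + b;
  return row * 4 * hidden_size + static_cast<std::size_t>(gate) * hidden_size + k;
}
```

For example, with one direction, a batch of 2, and H = 3, the output-gate slice of the first batch entry at t = 0 starts at offset 3, and the second batch entry's row starts at offset 12.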

Code Reference

Source Location

Signature

template <typename T>
LSTMGradImpl<T>::LSTMGradImpl(int sequence_length, int batch_size,
                              int hidden_size, int input_size,
                              concurrency::ThreadPool* thread_pool,
                              AllocatorPtr allocator);

template <typename T>
void LSTMGradImpl<T>::ComputeGradient(const LSTMGradInputs<T>& inputs,
                                      LSTMGradOutputs<T>& outputs);

Import

#include "orttraining/orttraining/training_ops/cpu/rnn/lstm_grad_compute.h"

I/O Contract

Inputs (LSTMGradInputs)

Name                     Type           Required  Description
weights                  span<const T>  Yes       Input weights W [4*H, I]
recurrence_weights       span<const T>  Yes       Recurrence weights R [4*H, H]
input                    span<const T>  Yes       Input sequence X
iofc                     span<const T>  Yes       Gate activations from the forward pass [seq, dir, batch, 4*H]
all_hidden_states        span<const T>  Yes       All hidden states from the forward pass
all_cell_states          span<const T>  Yes       All cell states from the forward pass
initial_hidden_state     span<const T>  Yes       Initial hidden state H0
initial_cell_state       span<const T>  Yes       Initial cell state C0
grad_all_hidden_states   span<const T>  No        Gradient w.r.t. all hidden states
grad_final_hidden_state  span<const T>  No        Gradient w.r.t. the final hidden state
grad_final_cell_state    span<const T>  No        Gradient w.r.t. the final cell state

(H = hidden_size, I = input_size, seq = sequence_length, dir = number of directions, batch = batch_size.)
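The three optional gradient inputs seed the reverse-time recursion. The sketch below shows that bookkeeping under stated assumptions: a missing input means zero gradient, sequence lengths are ignored, the final-state gradients initialize dH and dC, grad_all_hidden_states[t] is added before each step, and whatever remains after step 0 is the gradient w.r.t. H0 and C0. reverse_time_loop and StepFn are hypothetical names, and the step callback stands in for the per-step gate and GEMM work.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical callback: consumes the gradients w.r.t. H_t and C_t (each
// batch*H values) and overwrites them with the gradients w.r.t. H_{t-1} and C_{t-1}.
using StepFn = std::function<void(std::size_t t, std::vector<float>& d_h,
                                  std::vector<float>& d_c)>;

// Reverse-time BPTT bookkeeping (sketch). Empty optional inputs are treated as
// zero gradient. On return, d_h and d_c hold the gradients w.r.t. H0 and C0.
void reverse_time_loop(std::size_t seq_len, std::size_t batch_hidden,
                       const std::vector<float>& grad_all_h,    // [seq, batch*H] or empty
                       const std::vector<float>& grad_final_h,  // [batch*H] or empty
                       const std::vector<float>& grad_final_c,  // [batch*H] or empty
                       const StepFn& step, std::vector<float>& d_h,
                       std::vector<float>& d_c) {
  assert(grad_all_h.empty() || grad_all_h.size() == seq_len * batch_hidden);
  // The final-state gradients seed the recursion at the last time step.
  d_h = grad_final_h.empty() ? std::vector<float>(batch_hidden, 0.f) : grad_final_h;
  d_c = grad_final_c.empty() ? std::vector<float>(batch_hidden, 0.f) : grad_final_c;
  for (std::size_t t = seq_len; t-- > 0;) {
    // H_t is also part of the output sequence, so its direct gradient is added here.
    if (!grad_all_h.empty())
      for (std::size_t k = 0; k < batch_hidden; ++k)
        d_h[k] += grad_all_h[t * batch_hidden + k];
    step(t, d_h, d_c);
  }
}
```

Keeping the recursion separate from the step makes the ordering explicit: the loop visits t = seq_len - 1 down to 0, which is the reverse traversal described above.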

Outputs (LSTMGradOutputs)

Name                       Type     Description
grad_input                 span<T>  Gradient w.r.t. input X
grad_weights               span<T>  Gradient w.r.t. weights W
grad_recurrence_weights    span<T>  Gradient w.r.t. recurrence weights R
grad_bias                  span<T>  Gradient w.r.t. bias (8 * hidden_size)
grad_initial_hidden_state  span<T>  Gradient w.r.t. initial hidden state H0
grad_initial_cell_state    span<T>  Gradient w.r.t. initial cell state C0
grad_peephole_weights      span<T>  Gradient w.r.t. peephole weights (3 * hidden_size)
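The weight, bias, input, and recurrent-state gradients can all be obtained from one step's gate gradients with matrix products. The sketch below uses a naive row-major gemm as a stand-in for an optimized routine and assumes d_gates holds the gradients w.r.t. the gate pre-activations for one time step, with shape [batch, 4*H] in iofc order. Because Wb and Rb are added to the same pre-activation, both halves of the 8*H bias gradient receive the same values. gemm and accumulate_step are hypothetical names, not ONNX Runtime functions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Naive row-major GEMM: C[M x N] = op(A) * op(B) + beta * C, with op(A) of size
// M x K and op(B) of size K x N. A stand-in for an optimized GEMM routine.
void gemm(bool trans_a, bool trans_b, std::size_t M, std::size_t N, std::size_t K,
          const float* A, const float* B, float beta, float* C) {
  for (std::size_t m = 0; m < M; ++m)
    for (std::size_t n = 0; n < N; ++n) {
      float acc = 0.f;
      for (std::size_t k = 0; k < K; ++k) {
        const float a = trans_a ? A[k * M + m] : A[m * K + k];
        const float b = trans_b ? B[n * K + k] : B[k * N + n];
        acc += a * b;
      }
      C[m * N + n] = (beta == 0.f) ? acc : acc + beta * C[m * N + n];
    }
}

// Hypothetical per-step accumulation. d_gates: [batch, 4H]; x_t: [batch, I];
// h_prev: [batch, H]; W: [4H, I]; R: [4H, H]; d_bias: [8H] = [Wb | Rb].
void accumulate_step(std::size_t batch, std::size_t H, std::size_t I,
                     const std::vector<float>& d_gates, const std::vector<float>& x_t,
                     const std::vector<float>& h_prev, const std::vector<float>& W,
                     const std::vector<float>& R, std::vector<float>& dW,
                     std::vector<float>& dR, std::vector<float>& d_bias,
                     std::vector<float>& dx_t, std::vector<float>& dh_prev) {
  const std::size_t G = 4 * H;
  assert(d_gates.size() == batch * G && dW.size() == G * I && dR.size() == G * H);
  assert(d_bias.size() == 2 * G && dx_t.size() == batch * I && dh_prev.size() == batch * H);
  gemm(true, false, G, I, batch, d_gates.data(), x_t.data(), 1.f, dW.data());     // dW += dA^T X_t
  gemm(true, false, G, H, batch, d_gates.data(), h_prev.data(), 1.f, dR.data());  // dR += dA^T H_{t-1}
  gemm(false, false, batch, I, G, d_gates.data(), W.data(), 0.f, dx_t.data());    // dX_t = dA W
  gemm(false, false, batch, H, G, d_gates.data(), R.data(), 0.f, dh_prev.data()); // recurrent dH_{t-1} = dA R
  for (std::size_t b = 0; b < batch; ++b)
    for (std::size_t g = 0; g < G; ++g) {
      d_bias[g] += d_gates[b * G + g];      // Wb half
      d_bias[G + g] += d_gates[b * G + g];  // Rb half: same gradient
    }
}
```

Passing beta = 1 accumulates dW and dR across time steps, while beta = 0 overwrites the per-step dX_t and recurrent dH_{t-1} buffers.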

Usage Examples

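// Creating and using LSTMGradImpl (T = float)
lstm::LSTMGradImpl<float> lstm_cell(sequence_length, batch_size,
                                     hidden_size, input_size,
                                     thread_pool, allocator);
// lstmgrad_inputs (LSTMGradInputs<float>) and lstmgrad_outputs (LSTMGradOutputs<float>)
// carry the spans listed in the I/O Contract tables.
lstm_cell.ComputeGradient(lstmgrad_inputs, lstmgrad_outputs);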
