
Implementation:Microsoft Onnxruntime CPU GRU GradCompute

From Leeroopedia


Knowledge Sources
Domains: Training, CPU_Kernels
Last updated: 2026-02-10 04:00 GMT

Overview

A concrete CPU kernel that computes the backward pass (gradient backpropagation) of a GRU (Gated Recurrent Unit) layer in the ONNX Runtime training framework.

Description

This file implements the GRUGradImpl class template which performs the backward pass for a GRU cell. The gradient computation iterates in reverse through the time steps, computing gradients for the update gate (z), reset gate (r), and hidden state candidate (h) using the chain rule. It supports both the standard GRU formulation and the linear_before_reset variant. The implementation computes gradients for all learnable parameters: input weights (W), recurrence weights (R), biases (Wb, Rb), the input sequence (X), and the initial hidden state (H0). Matrix operations are performed using GEMM calls, and gradients are accumulated across batch elements and time steps using elementwise summation helpers from the deepcpu namespace.

Usage

This kernel is used during the backward pass of GRU training. It is called by the GRUGrad operator to compute parameter gradients after the forward pass has produced the intermediate gate activations (z, r, h) and all hidden states.

Code Reference

Source Location

orttraining/orttraining/training_ops/cpu/rnn/gru_grad_compute.h

Signature

template <typename T>
GRUGradImpl<T>::GRUGradImpl(int sequence_length, int batch_size,
                            int hidden_size, int input_size,
                            bool linear_before_reset,
                            concurrency::ThreadPool* thread_pool,
                            AllocatorPtr allocator);

template <typename T>
void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs,
                                     GRUGradOutputs<T>& outputs);

Import

#include "orttraining/orttraining/training_ops/cpu/rnn/gru_grad_compute.h"

I/O Contract

Inputs (GRUGradInputs)

Name Type Required Description
weights span<const T> Yes Input weights W, shape [3*H, I]
recurrence_weights span<const T> Yes Recurrence weights R, shape [3*H, H]
bias span<const T> No Concatenated biases [Wb; Rb], shape [6*H]
input span<const T> Yes Input sequence X, shape [seq_len, batch, I]
zrh span<const T> Yes Gate activations (z, r, h) from the forward pass, shape [seq_len, batch, 3*H]
all_hidden_states span<const T> Yes All hidden states from the forward pass, shape [seq_len, batch, H]
initial_hidden_state span<const T> Yes Initial hidden state H0, shape [batch, H]
grad_all_hidden_states span<const T> No Gradient w.r.t. all hidden states, shape [seq_len, batch, H]
grad_final_hidden_state span<const T> No Gradient w.r.t. the final hidden state, shape [batch, H]

Outputs (GRUGradOutputs)

Name Type Description
grad_input span<T> Gradient w.r.t. input X, shape [seq_len, batch, I]
grad_weights span<T> Gradient w.r.t. weights W, shape [3*H, I]
grad_recurrence_weights span<T> Gradient w.r.t. recurrence weights R, shape [3*H, H]
grad_bias span<T> Gradient w.r.t. bias, shape [6*H]
grad_initial_hidden_state span<T> Gradient w.r.t. initial hidden state H0, shape [batch, H]

Usage Examples

// Construct the gradient kernel for the given GRU dimensions and variant.
gru::GRUGradImpl<float> gru_cell(sequence_length, batch_size,
                                 hidden_size, input_size,
                                 linear_before_reset,
                                 thread_pool, allocator);
// Run the backward pass: reads the forward activations from grugrad_inputs
// and writes the parameter and input gradients into grugrad_outputs.
gru_cell.ComputeGradient(grugrad_inputs, grugrad_outputs);
