

Implementation:Recommenders team Recommenders Conjugate Gradient MS

From Leeroopedia


Knowledge Sources
Domains Riemannian Optimization, Numerical Optimization, Manifold Optimization
Last Updated 2026-02-10 00:00 GMT

Overview

The ConjugateGradientMS class implements a modified nonlinear conjugate gradient solver for Riemannian optimization, extending Pymanopt's solver framework with a statistics callback mechanism.

Description

The ConjugateGradientMS class extends Pymanopt's Solver base class to perform nonlinear conjugate gradient optimization on Riemannian manifolds. The key modification (denoted by the "MS" suffix) is the addition of a compute_stats callback parameter in the solve method, enabling external monitoring of optimization progress such as tracking training and validation RMSE during RLRMC training.
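The callback mechanism can be illustrated with a toy loop. This is not Recommenders code. It assumes only the contract stated on this page: the solver owns a stats dictionary, calls compute_stats(weights, [iter, cost, gradnorm, time], stats) once per iteration, and returns the final point together with stats. Plain gradient descent on a quadratic stands in for the Riemannian CG step, and reporting elapsed wall-clock time in the time slot is an assumption. The names toy_solve and record are hypothetical.

```python
import time

import numpy as np


def toy_solve(cost, grad, x0, maxiter=50, step=0.1, compute_stats=None):
    """Hypothetical loop that mimics the compute_stats hook described above."""
    stats = {}
    x = x0
    time0 = time.time()
    for it in range(maxiter):
        g = grad(x)
        gradnorm = float(np.linalg.norm(g))
        if compute_stats is not None:
            # Documented layout: weights, [iter, cost, gradnorm, time], stats
            compute_stats(x, [it, cost(x), gradnorm, time.time() - time0], stats)
        x = x - step * g  # stand-in for the conjugate gradient update
    return x, stats


def record(weights, given_stats, stats):
    # The callback decides what to keep; the solver only owns the dict
    stats.setdefault("iteration", []).append(given_stats[0])
    stats.setdefault("cost", []).append(given_stats[1])


target = np.array([1.0, -2.0])
x, stats = toy_solve(
    cost=lambda x: 0.5 * np.sum((x - target) ** 2),
    grad=lambda x: x - target,
    x0=np.zeros(2),
    compute_stats=record,
)
```

Keeping the dictionary owned by the solver and mutated by the callback means the solver never needs to know which metrics are being tracked.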

The solver supports four beta rules for constructing conjugate search directions:

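  • Fletcher-Reeves: Classic CG rule; beta is the ratio of the squared norms of the new and previous gradients.
  • Polak-Ribiere: Uses the difference between the new gradient and the transported previous gradient, which often converges faster than Fletcher-Reeves on nonlinear problems.
  • Hestenes-Stiefel (default): Divides the inner product of the new gradient with the gradient difference by the inner product of the gradient difference with the previous search direction.
  • Hager-Zhang: A modification of Hestenes-Stiefel designed to guarantee sufficient descent, with a lower bound on beta for robustness.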
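The four rules can be compared side by side in the flat Euclidean case, where parallel transport is the identity and no preconditioner is used. This is a minimal sketch of the textbook formulas, not the solver's code; clipping Polak-Ribiere and Hestenes-Stiefel at zero follows the common "+" variants, and treating a zero Hestenes-Stiefel denominator as a restart (beta = 0) is an assumption of this sketch. The function name cg_beta and the string rule names are hypothetical.

```python
import numpy as np


def cg_beta(rule, g_new, g_old, d):
    """Euclidean sketch of the four beta rules.

    g_new: gradient at the new iterate, g_old: previous gradient,
    d: previous search direction (all already in the same space).
    """
    y = g_new - g_old  # gradient difference
    if rule == "FletcherReeves":
        return g_new @ g_new / (g_old @ g_old)
    if rule == "PolakRibiere":
        return max(0.0, g_new @ y / (g_old @ g_old))
    if rule == "HestenesStiefel":
        denom = y @ d
        return max(0.0, g_new @ y / denom) if denom != 0 else 0.0
    if rule == "HagerZhang":
        denom = y @ d
        beta = (y @ g_new - 2.0 * (y @ y) * (d @ g_new) / denom) / denom
        # Hager-Zhang lower bound on beta, using the previous gradient norm
        eta = -1.0 / (np.linalg.norm(d) * min(0.01, np.linalg.norm(g_old)))
        return max(beta, eta)
    raise ValueError(f"unknown beta rule: {rule}")


g_old = np.array([1.0, 0.0])
g_new = np.array([0.5, 1.0])
d = -g_old  # first direction is steepest descent
cg_beta("FletcherReeves", g_new, g_old, d)   # 1.25
cg_beta("PolakRibiere", g_new, g_old, d)     # 0.75
cg_beta("HestenesStiefel", g_new, g_old, d)  # 1.5
cg_beta("HagerZhang", g_new, g_old, d)       # 6.5
```

On a manifold, g_old and d would first be transported into the tangent space at the new iterate and the dot products replaced by the Riemannian inner product.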

Additional algorithmic features:

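  • Line search: Uses a Pymanopt line search object to choose the step size along the manifold; LineSearchAdaptive is the default, and another line search such as LineSearchBackTracking can be passed instead.
  • Powell's restart strategy: Resets the search direction to steepest descent when consecutive gradients are too far from orthogonal, that is, when |⟨g_{k+1}, T(g_k)⟩| / ‖g_{k+1}‖² is at least orth_value (T denotes transport to the new tangent space); the default np.inf disables the restart.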
  • Riemannian operations: Uses manifold-aware operations including parallel transport for moving vectors between tangent spaces and Riemannian inner products for gradient computations.
  • Convergence monitoring: Checks stopping criteria based on gradient norm, iteration count, maximum time, and step size.

Usage

Use this solver as the optimization backend for the RLRMC algorithm when performing matrix completion via Riemannian optimization. It is designed to work with Pymanopt problem definitions on product manifolds and provides the callback infrastructure needed to track per-iteration metrics during training.

Code Reference

Source Location

Signature

class ConjugateGradientMS(Solver):
    def __init__(
        self,
        beta_type=BetaTypes.HestenesStiefel,
        orth_value=np.inf,
        linesearch=None,
        *args,
        **kwargs
    )

    def solve(self, problem, x=None, reuselinesearch=False, compute_stats=None)

Import

from recommenders.models.rlrmc.conjugate_gradient_ms import ConjugateGradientMS

I/O Contract

Inputs

Name | Type | Required | Description
--- | --- | --- | ---
beta_type | BetaTypes enum | No | Conjugate gradient beta rule: FletcherReeves, PolakRibiere, HestenesStiefel, or HagerZhang; default HestenesStiefel
orth_value | float | No | Threshold for Powell's restart strategy; the default np.inf disables restarts
linesearch | line search object | No | Line search instance; default None, in which case LineSearchAdaptive is used
*args, **kwargs | | No | Forwarded to Pymanopt's Solver base class (for example maxtime, maxiter, mingradnorm, minstepsize, logverbosity)
problem | Problem | Yes (for solve) | Pymanopt Problem object with manifold, cost function, and gradient
x | numpy.ndarray or list of arrays | No (for solve) | Starting point on the manifold (a list of arrays for a Product manifold); if None, a random point is generated
reuselinesearch | bool | No (for solve) | Whether to reuse the line search object from the previous solve call; default False
compute_stats | callable | No (for solve) | Callback invoked once per iteration as compute_stats(weights, [iter, cost, gradnorm, time], stats), where stats is the dictionary returned by solve
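To illustrate the compute_stats contract, the following sketch builds a callback that records validation RMSE, similar in spirit to the RLRMC monitoring described in the Description. It assumes, hypothetically, that weights is the list [U, V, B] for the Stiefel × Stiefel × SPD product manifold, that the low-rank model predicts U @ B @ V.T, and that held-out ratings come as parallel arrays of row indices, column indices, and values. make_rmse_callback is a hypothetical name; the callback RLRMC actually uses may compute its metrics differently.

```python
import numpy as np


def make_rmse_callback(rows, cols, values):
    """Hypothetical factory for a compute_stats callback that records validation RMSE."""
    def callback(weights, given_stats, stats):
        iteration, cost, gradnorm, _ = given_stats
        U, V, B = weights
        # Prediction for entry (i, j) is U[i] @ B @ V[j]
        preds = np.einsum("nj,jk,nk->n", U[rows], B, V[cols])
        rmse = float(np.sqrt(np.mean((preds - values) ** 2)))
        stats.setdefault("iteration", []).append(iteration)
        stats.setdefault("cost", []).append(cost)
        stats.setdefault("gradnorm", []).append(gradnorm)
        stats.setdefault("val_rmse", []).append(rmse)
    return callback


# Tiny check with hand-made factors: U @ B @ V.T = diag(2, 3, 0)
U = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
V = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
B = np.diag([2.0, 3.0])
callback = make_rmse_callback(
    np.array([0, 1, 2]), np.array([0, 1, 0]), np.array([2.0, 3.0, 1.0])
)
stats = {}
callback([U, V, B], [0, 1.5, 0.1, 0.0], stats)
# Only the third prediction (0.0 vs 1.0) is off, so val_rmse is sqrt(1/3)
```

Evaluating only the held-out entries with einsum avoids materializing the full U @ B @ V.T matrix at every iteration.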

Outputs

Name | Type | Description
--- | --- | ---
x | numpy.ndarray or list of arrays | Local minimum on the manifold, or the point at which the algorithm terminated
stats | dict | Per-iteration statistics populated by the compute_stats callback
optlog | dict | Optimization log, returned as a third element only when logverbosity > 0

Usage Examples

Basic Usage

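import numpy as np
from pymanopt import Problem
from pymanopt.manifolds import Stiefel, SymmetricPositiveDefinite, Product
from pymanopt.solvers.linesearch import LineSearchBackTracking

from recommenders.models.rlrmc.conjugate_gradient_ms import ConjugateGradientMS

# Toy data for illustration: a partially observed num_rows x num_cols matrix
num_rows, num_cols, rank = 100, 80, 5
rng = np.random.default_rng(0)
M = rng.standard_normal((num_rows, num_cols))
mask = rng.random((num_rows, num_cols)) < 0.3  # True where an entry is observed

# Define the manifold (product of two Stiefel manifolds and an SPD manifold);
# a point is a list [U, V, B] and the low-rank model is U @ B @ V.T
manifold = Product([
    Stiefel(num_rows, rank),
    Stiefel(num_cols, rank),
    SymmetricPositiveDefinite(rank),
])

# Illustrative squared-error cost on the observed entries (not RLRMC's exact cost)
def cost_function(weights):
    U, V, B = weights
    residual = mask * (U @ B @ V.T - M)
    return 0.5 * np.sum(residual ** 2)

# Euclidean gradient, returned in the same order as the manifold factors
def gradient_function(weights):
    U, V, B = weights
    residual = mask * (U @ B @ V.T - M)
    return [residual @ V @ B, residual.T @ U @ B, U.T @ residual @ V]

# Define the optimization problem
problem = Problem(
    manifold=manifold,
    cost=cost_function,
    egrad=gradient_function,
    verbosity=1,
)

# Create the solver; maxtime and maxiter are forwarded to Pymanopt's Solver
solver = ConjugateGradientMS(
    maxtime=600,
    maxiter=100,
    linesearch=LineSearchBackTracking(),
)

# Record the cost at every iteration through the compute_stats hook
def my_stats_callback(weights, given_stats, stats):
    # given_stats is [iter, cost, gradnorm, time]
    stats.setdefault("cost", []).append(given_stats[1])

# With the default logverbosity=0, solve returns (x, stats)
Wopt, stats = solver.solve(problem, compute_stats=my_stats_callback)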
