Implementation:Recommenders team Recommenders Conjugate Gradient MS
| Knowledge Sources | |
|---|---|
| Domains | Riemannian Optimization, Numerical Optimization, Manifold Optimization |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
The ConjugateGradientMS class implements a modified Conjugate Gradient solver for Riemannian optimization on manifolds, extending Pymanopt's solver with a statistics callback mechanism.
Description
The ConjugateGradientMS class extends Pymanopt's Solver base class to perform nonlinear conjugate gradient optimization on Riemannian manifolds. The key modification, denoted by the "MS" suffix, is a compute_stats callback parameter on the solve method. The callback lets external code monitor optimization progress, for example by tracking training and validation RMSE during RLRMC (Riemannian Low-rank Matrix Completion) training.
The solver supports four beta rules for constructing conjugate search directions:
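- Fletcher-Reeves: Classic CG beta given by the ratio of the squared norms of the new and previous gradients.
- Polak-Ribiere: Uses the difference between successive gradients in the numerator, which tends to improve convergence on nonlinear problems.
- Hestenes-Stiefel (default): Computes beta from the gradient difference, normalized by its inner product with the previous search direction.
- Hager-Zhang: A Hestenes-Stiefel-style rule with an extra correction term designed to keep the search direction a descent direction even when the line search is inexact.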
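The four rules can be compared with a minimal sketch in the flat Euclidean case, where the Riemannian inner product reduces to a dot product and no transport between tangent spaces is needed. The helper `cg_beta` and its string rule names are hypothetical and are not part of ConjugateGradientMS. The formulas are the textbook ones, and clamping Polak-Ribiere and Hestenes-Stiefel at zero is an assumption (a common safeguard), not something this page confirms about the class.

```python
import numpy as np


def cg_beta(rule, grad_old, grad_new, dir_old):
    """Textbook CG beta in Euclidean space (hypothetical helper, not the class's code)."""
    diff = grad_new - grad_old  # gradient difference between successive iterates
    if rule == "FletcherReeves":
        return float(grad_new @ grad_new) / float(grad_old @ grad_old)
    if rule == "PolakRibiere":
        beta = float(grad_new @ diff) / float(grad_old @ grad_old)
        return max(0.0, beta)  # assumed safeguard: never use a negative beta
    if rule == "HestenesStiefel":
        beta = float(grad_new @ diff) / float(diff @ dir_old)
        return max(0.0, beta)  # assumed safeguard: never use a negative beta
    if rule == "HagerZhang":
        denom = float(diff @ dir_old)
        # Hestenes-Stiefel numerator minus the Hager-Zhang correction term.
        numer = float(diff @ grad_new) - 2.0 * float(diff @ diff) * float(dir_old @ grad_new) / denom
        return numer / denom
    raise ValueError(f"unknown beta rule: {rule}")


# Illustrative vectors only; the first CG direction is steepest descent.
grad_old = np.array([1.0, 0.0])
grad_new = np.array([0.2, 0.6])
dir_old = -grad_old

rules = ("FletcherReeves", "PolakRibiere", "HestenesStiefel", "HagerZhang")
betas = {rule: cg_beta(rule, grad_old, grad_new, dir_old) for rule in rules}
for rule, beta in betas.items():
    print(f"{rule}: {beta:.4f}")
```

On a manifold, the previous gradient and search direction would first have to be moved into the tangent space at the new point before any of these inner products make sense.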
Additional algorithmic features:
- Adaptive line search: Uses LineSearchBackTracking or LineSearchAdaptive to find appropriate step sizes along the manifold.
- Powell's restart strategy: Resets to steepest descent when consecutive gradients become insufficiently orthogonal, controlled by the orth_value parameter.
- Riemannian operations: Uses manifold-aware operations, including vector transport (of which parallel transport is a special case) for moving vectors between tangent spaces and Riemannian inner products for gradient computations.
- Convergence monitoring: Checks stopping criteria based on gradient norm, iteration count, maximum time, and step size.
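Powell's restart strategy from the list above can be sketched as follows, again in the Euclidean case. The function names are hypothetical, and the exact expression used inside ConjugateGradientMS is assumed here to be the standard Powell ratio: the absolute inner product of successive gradients divided by the squared norm of the new gradient.

```python
import numpy as np


def should_restart(grad_old, grad_new, orth_value):
    """Powell-style test: True when successive gradients are far from orthogonal.

    Assumes grad_new is nonzero (a solver would already have stopped otherwise).
    """
    ratio = abs(float(grad_old @ grad_new)) / float(grad_new @ grad_new)
    return ratio >= orth_value


def next_direction(grad_old, grad_new, dir_old, beta, orth_value=np.inf):
    """Return the next search direction, falling back to steepest descent on restart."""
    if should_restart(grad_old, grad_new, orth_value):
        return -grad_new  # restart: discard the conjugate history
    return -grad_new + beta * dir_old


# Illustrative vectors only: the two gradients are nearly parallel.
grad_old = np.array([1.0, 0.0])
grad_new = np.array([0.9, 0.1])
dir_old = -grad_old

restarted = next_direction(grad_old, grad_new, dir_old, beta=0.5, orth_value=0.5)
kept = next_direction(grad_old, grad_new, dir_old, beta=0.5, orth_value=np.inf)
```

With orth_value set to np.inf the comparison can never succeed for finite gradients, which matches the statement that an infinite value disables restarts.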
Usage
Use this solver as the optimization backend for the RLRMC algorithm when performing matrix completion via Riemannian optimization. It is designed to work with Pymanopt problem definitions on product manifolds and provides the callback infrastructure needed to track per-iteration metrics during training.
Code Reference
Source Location
- Repository: Recommenders
- File: recommenders/models/rlrmc/conjugate_gradient_ms.py
- Lines: 1-255
Signature
class ConjugateGradientMS(Solver):
    def __init__(
        self,
        beta_type=BetaTypes.HestenesStiefel,
        orth_value=np.inf,
        linesearch=None,
        *args,
        **kwargs
    )
    def solve(self, problem, x=None, reuselinesearch=False, compute_stats=None)
Import
from recommenders.models.rlrmc.conjugate_gradient_ms import ConjugateGradientMS
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| beta_type | BetaTypes enum | No | Conjugate gradient beta rule: FletcherReeves, PolakRibiere, HestenesStiefel, or HagerZhang; default HestenesStiefel |
| orth_value | float | No | Threshold for Powell's restart strategy; infinite value disables restart; default np.inf |
| linesearch | object | No | Line search method instance; default LineSearchAdaptive |
| problem | Problem | Yes (for solve) | Pymanopt Problem object with manifold, cost function, and gradient |
| x | numpy.ndarray | No (for solve) | Starting point on the manifold; if None, a random point is generated |
| reuselinesearch | bool | No (for solve) | Whether to reuse the previous line search object across solve calls; default False |
| compute_stats | callable | No (for solve) | Callback function receiving (weights, [iter, cost, gradnorm, time], stats_dict) per iteration |
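A callback matching the compute_stats contract in the table might look like the sketch below. The name `record_stats` and the dictionary keys are hypothetical. The loop at the bottom is only a stand-in for the solver, fed with dummy values to show the call shape; it does not reproduce how ConjugateGradientMS drives the callback internally.

```python
def record_stats(weights, given_stats, stats):
    """Append the per-iteration values passed by the solver to lists in `stats`."""
    iteration, cost, gradnorm, elapsed = given_stats
    stats.setdefault("iteration", []).append(iteration)
    stats.setdefault("cost", []).append(cost)
    stats.setdefault("gradnorm", []).append(gradnorm)
    stats.setdefault("time", []).append(elapsed)


# Stand-in for the solver loop, using dummy values (not real optimization output).
stats = {}
dummy_iterations = [(0, 10.0, 1.0, 0.01), (1, 4.0, 0.5, 0.02), (2, 1.0, 0.1, 0.03)]
for it, cost, gradnorm, elapsed in dummy_iterations:
    record_stats(weights=None, given_stats=[it, cost, gradnorm, elapsed], stats=stats)

print(stats["cost"])
```

A callback that tracks RMSE would additionally use the `weights` argument to evaluate predictions on training and validation data.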
Outputs
| Name | Type | Description |
|---|---|---|
| x | numpy.ndarray | Local minimum on the manifold, or the point at which the algorithm terminated |
| stats | dict | Dictionary of per-iteration statistics populated by the compute_stats callback |
| optlog | dict | Optional optimization log (returned only if logverbosity > 0) |
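Because optlog is returned only when logverbosity > 0, calling code may receive either two or three values. One way to handle both shapes is a small wrapper such as the hypothetical `unpack_solve_result` below; it is not part of the library, and the tuples passed to it here are placeholders rather than real solver output.

```python
def unpack_solve_result(result):
    """Return (x, stats, optlog); optlog is None when only two values were returned."""
    if len(result) == 3:
        x, stats, optlog = result
    else:
        x, stats = result
        optlog = None
    return x, stats, optlog


# Placeholder tuples standing in for the two possible return shapes of solve().
without_log = unpack_solve_result(("x_placeholder", {"cost": [1.0]}))
with_log = unpack_solve_result(("x_placeholder", {"cost": [1.0]}, {"note": "placeholder"}))
```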
Usage Examples
Basic Usage
from recommenders.models.rlrmc.conjugate_gradient_ms import ConjugateGradientMS
from pymanopt.solvers.linesearch import LineSearchBackTracking
from pymanopt.manifolds import Stiefel, SymmetricPositiveDefinite, Product
from pymanopt import Problem
# num_rows, num_cols, rank, cost_function and gradient_function are
# placeholders that the caller must define before running this example.
# Define the manifold (product of Stiefel and SPD manifolds)
manifold = Product([
    Stiefel(num_rows, rank),
    Stiefel(num_cols, rank),
    SymmetricPositiveDefinite(rank),
])
# Define the optimization problem
problem = Problem(
    manifold=manifold,
    cost=cost_function,
    egrad=gradient_function,
    verbosity=1,
)
# Create the solver with line search
solver = ConjugateGradientMS(
    maxtime=600,
    maxiter=100,
    linesearch=LineSearchBackTracking(),
)
# Solve with statistics tracking
def my_stats_callback(weights, given_stats, stats):
    # given_stats is [iter, cost, gradnorm, time]; index 1 is the cost
    stats.setdefault("cost", []).append(given_stats[1])
Wopt, stats = solver.solve(problem, compute_stats=my_stats_callback)