Implementation:Rapidsai Cuml SGD CD Solver

From Leeroopedia


Knowledge Sources
Domains Machine_Learning, Linear_Models
Last Updated 2026-02-08 12:00 GMT

Overview

Provides GPU-accelerated Stochastic Gradient Descent (SGD) and Coordinate Descent (CD) solvers for fitting and predicting with linear models including linear regression, lasso, ridge, and elastic-net.

Description

The solver.hpp header declares two families of GPU-accelerated optimization solvers within the ML::Solver namespace:

Stochastic Gradient Descent (SGD):

  • sgdFit: Fits a linear model using mini-batch SGD. Supports configurable loss functions, penalty types (L1, L2, elastic-net), learning rate schedules, and early stopping. Available for both float and double precision.
  • sgdPredict: Predicts continuous values using the fitted SGD model.
  • sgdPredictBinaryClass: Predicts binary class labels using the fitted SGD model.

Coordinate Descent (CD):

  • cdFit: Fits a linear, lasso, or elastic-net regression model using coordinate descent. It minimizes the elastic-net objective: 1/2 * ||y - Xw||^2 + alpha/2 * (1 - l1_ratio) * ||w||^2 + alpha * l1_ratio * ||w||_1. Supports optional sample weights and returns the number of iterations. Available for both float and double.
  • cdPredict: Predicts using the fitted CD model.
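
For intuition, minimizing the objective above over a single coefficient w_j while holding the others fixed gives a closed-form update: w_j = S(x_j · r + ||x_j||^2 * w_j, alpha * l1_ratio) / (||x_j||^2 + alpha * (1 - l1_ratio)), where r = y - Xw is the current residual and S is the soft-thresholding operator. The following is a minimal host-side sketch of that loop, not the cuML implementation. The names soft_threshold and cd_fit_reference are hypothetical, the sketch omits the intercept, shuffling, and sample weights, and its stopping rule (largest coefficient change below tol) is an assumption that may differ from what cdFit does.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Soft-thresholding operator: sign(z) * max(|z| - t, 0).
double soft_threshold(double z, double t)
{
  if (z > t) return z - t;
  if (z < -t) return z + t;
  return 0.0;
}

// Hypothetical CPU reference (not part of cuML).
// X is column-major [n_rows x n_cols]. Returns the number of full passes run.
int cd_fit_reference(const std::vector<double>& X, int n_rows, int n_cols,
                     const std::vector<double>& y, std::vector<double>& w,
                     int epochs, double alpha, double l1_ratio, double tol)
{
  assert(X.size() == static_cast<std::size_t>(n_rows) * n_cols);
  assert(y.size() == static_cast<std::size_t>(n_rows));
  w.assign(n_cols, 0.0);
  std::vector<double> r(y);  // residual y - Xw, which equals y while w = 0
  for (int it = 0; it < epochs; ++it) {
    double max_change = 0.0;
    for (int j = 0; j < n_cols; ++j) {
      const double* xj = X.data() + static_cast<std::size_t>(j) * n_rows;
      double sq_norm = 0.0, dot = 0.0;
      for (int i = 0; i < n_rows; ++i) {
        sq_norm += xj[i] * xj[i];
        dot += xj[i] * r[i];
      }
      const double denom = sq_norm + alpha * (1.0 - l1_ratio);
      if (denom == 0.0) continue;  // all-zero column and no L2 term
      // Correlation of column j with the residual that excludes w[j].
      const double rho = dot + sq_norm * w[j];
      const double w_new = soft_threshold(rho, alpha * l1_ratio) / denom;
      const double delta = w_new - w[j];
      if (delta != 0.0) {
        for (int i = 0; i < n_rows; ++i) r[i] -= xj[i] * delta;
        w[j] = w_new;
      }
      max_change = std::max(max_change, std::fabs(delta));
    }
    if (max_change < tol) return it + 1;  // assumed stopping rule
  }
  return epochs;
}
```

With l1_ratio = 1 the L2 term vanishes and the update reduces to the lasso case; with l1_ratio = 0 the threshold is zero and the update is a ridge-style shrinkage.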

All array arguments (input, labels, coef, preds, sample_weight) are device pointers, with matrices stored in column-major order. Every function takes a RAFT handle for GPU resource management.
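
Host data is often stored row-major, so it has to be rearranged before being copied to the device. A minimal host-side sketch of that conversion follows; to_col_major is a hypothetical helper name and is not part of cuML.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of cuML): convert a row-major host matrix
// into column-major order. Element (i, j) moves from src[i * n_cols + j]
// to dst[j * n_rows + i].
std::vector<float> to_col_major(const std::vector<float>& src, int n_rows, int n_cols)
{
  assert(src.size() == static_cast<std::size_t>(n_rows) * n_cols);
  std::vector<float> dst(src.size());
  for (int i = 0; i < n_rows; ++i) {
    for (int j = 0; j < n_cols; ++j) {
      dst[static_cast<std::size_t>(j) * n_rows + i] =
        src[static_cast<std::size_t>(i) * n_cols + j];
    }
  }
  return dst;
}
```

The resulting buffer can then be copied to the device and passed as the input argument.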

Usage

Use SGD when you need an online or mini-batch learning algorithm that can handle large datasets with configurable loss functions and learning rate schedules. Use Coordinate Descent when fitting regularized linear models (ridge, lasso, elastic-net) where the closed-form coordinate update is efficient. Both are single-GPU solvers; for multi-GPU coordinate descent, see the cd_mg.hpp API.

Code Reference

Source Location

  • Repository: Rapidsai_Cuml
  • File: cpp/include/cuml/solvers/solver.hpp

Signature

namespace ML {
namespace Solver {

// SGD functions
void sgdFit(raft::handle_t& handle, float* input, int n_rows, int n_cols,
            float* labels, float* coef, float* intercept, bool fit_intercept,
            int batch_size, int epochs, int lr_type, float eta0, float power_t,
            int loss, int penalty, float alpha, float l1_ratio,
            bool shuffle, float tol, int n_iter_no_change);

void sgdFit(raft::handle_t& handle, double* input, int n_rows, int n_cols,
            double* labels, double* coef, double* intercept, bool fit_intercept,
            int batch_size, int epochs, int lr_type, double eta0, double power_t,
            int loss, int penalty, double alpha, double l1_ratio,
            bool shuffle, double tol, int n_iter_no_change);

void sgdPredict(raft::handle_t& handle, const float* input, int n_rows, int n_cols,
                const float* coef, float intercept, float* preds, int loss);

void sgdPredict(raft::handle_t& handle, const double* input, int n_rows, int n_cols,
                const double* coef, double intercept, double* preds, int loss);

void sgdPredictBinaryClass(raft::handle_t& handle, const float* input, int n_rows, int n_cols,
                           const float* coef, float intercept, float* preds, int loss);

void sgdPredictBinaryClass(raft::handle_t& handle, const double* input, int n_rows, int n_cols,
                           const double* coef, double intercept, double* preds, int loss);

// Coordinate Descent functions
int cdFit(raft::handle_t& handle, float* input, int n_rows, int n_cols,
          float* labels, float* coef, float* intercept, bool fit_intercept,
          int epochs, int loss, float alpha, float l1_ratio,
          bool shuffle, float tol, float* sample_weight = nullptr);

int cdFit(raft::handle_t& handle, double* input, int n_rows, int n_cols,
          double* labels, double* coef, double* intercept, bool fit_intercept,
          int epochs, int loss, double alpha, double l1_ratio,
          bool shuffle, double tol, double* sample_weight = nullptr);

void cdPredict(raft::handle_t& handle, const float* input, int n_rows, int n_cols,
               const float* coef, float intercept, float* preds, int loss);

void cdPredict(raft::handle_t& handle, const double* input, int n_rows, int n_cols,
               const double* coef, double intercept, double* preds, int loss);

} // namespace Solver
} // namespace ML

Import

#include <cuml/solvers/solver.hpp>

I/O Contract

Inputs

sgdFit

| Name | Type | Required | Description |
|---|---|---|---|
| handle | raft::handle_t& | Yes | RAFT handle for GPU resources |
| input | T* | Yes | Device pointer to feature matrix in column-major format [n_rows x n_cols] |
| n_rows | int | Yes | Number of training samples |
| n_cols | int | Yes | Number of features |
| labels | T* | Yes | Device pointer to labels [n_rows] |
| fit_intercept | bool | Yes | Whether to fit an intercept term |
| batch_size | int | Yes | Mini-batch size |
| epochs | int | Yes | Number of training epochs |
| lr_type | int | Yes | Learning rate schedule type |
| eta0 | T | Yes | Initial learning rate |
| power_t | T | Yes | Exponent for inverse scaling learning rate |
| loss | int | Yes | Loss function type |
| penalty | int | Yes | Penalty type (L1, L2, elastic-net) |
| alpha | T | Yes | Regularization strength |
| l1_ratio | T | Yes | L1 ratio in elastic-net (0 to 1) |
| shuffle | bool | Yes | Whether to shuffle data each epoch |
| tol | T | Yes | Convergence tolerance |
| n_iter_no_change | int | Yes | Number of iterations with no improvement for early stopping |
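
To make the roles of batch_size, epochs, eta0, power_t, and alpha concrete, here is a minimal host-side sketch of mini-batch SGD. It is not the cuML kernel. The name sgd_fit_reference is hypothetical, and the sketch fixes squared loss with an L2 penalty. It assumes the inverse-scaling step size eta0 / t^power_t (the scikit-learn convention, where t counts updates), and it leaves out shuffling and early stopping.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference (not part of cuML). Squared loss, L2 penalty,
// step size eta0 / t^power_t (power_t = 0 keeps the step constant).
// X is column-major [n_rows x n_cols].
void sgd_fit_reference(const std::vector<double>& X, int n_rows, int n_cols,
                       const std::vector<double>& y, std::vector<double>& coef,
                       double& intercept, bool fit_intercept, int batch_size,
                       int epochs, double eta0, double power_t, double alpha)
{
  assert(batch_size > 0);
  assert(X.size() == static_cast<std::size_t>(n_rows) * n_cols);
  assert(y.size() == static_cast<std::size_t>(n_rows));
  coef.assign(n_cols, 0.0);
  intercept = 0.0;
  std::vector<double> grad(n_cols);
  long t = 0;  // number of parameter updates so far
  for (int epoch = 0; epoch < epochs; ++epoch) {
    for (int start = 0; start < n_rows; start += batch_size) {
      const int end = std::min(start + batch_size, n_rows);
      const double inv_b = 1.0 / (end - start);
      std::fill(grad.begin(), grad.end(), 0.0);
      double grad_b = 0.0;
      // Average the squared-loss gradient over the rows of this mini-batch.
      for (int i = start; i < end; ++i) {
        double pred = intercept;
        for (int j = 0; j < n_cols; ++j)
          pred += X[static_cast<std::size_t>(j) * n_rows + i] * coef[j];
        const double err = pred - y[i];
        for (int j = 0; j < n_cols; ++j)
          grad[j] += err * X[static_cast<std::size_t>(j) * n_rows + i] * inv_b;
        grad_b += err * inv_b;
      }
      ++t;
      const double eta = eta0 / std::pow(static_cast<double>(t), power_t);
      // The L2 penalty contributes alpha * coef[j] to the gradient.
      for (int j = 0; j < n_cols; ++j)
        coef[j] -= eta * (grad[j] + alpha * coef[j]);
      if (fit_intercept) intercept -= eta * grad_b;
    }
  }
}
```

Smaller batches give noisier but cheaper updates, and a larger power_t makes the step size decay faster.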

cdFit

| Name | Type | Required | Description |
|---|---|---|---|
| handle | raft::handle_t& | Yes | RAFT handle |
| input | T* | Yes | Device pointer to feature matrix in column-major [n_rows x n_cols] |
| n_rows | int | Yes | Number of samples |
| n_cols | int | Yes | Number of features |
| labels | T* | Yes | Device pointer to labels [n_rows] |
| fit_intercept | bool | Yes | Whether to fit an intercept |
| epochs | int | Yes | Maximum number of iterations |
| loss | int | Yes | Loss function (only linear regression supported) |
| alpha | T | Yes | Regularization parameter |
| l1_ratio | T | Yes | Ratio for L1 vs L2 regularization |
| shuffle | bool | Yes | Whether to shuffle coordinates |
| tol | T | Yes | Convergence tolerance |
| sample_weight | T* | No | Optional sample weights [n_rows] (default: nullptr) |

Outputs

| Name | Type | Description |
|---|---|---|
| coef (fit) | T* | Device array of learned coefficients [n_cols] |
| intercept (fit) | T* | Pointer to learned intercept scalar |
| return value (cdFit) | int | Number of iterations the solver ran |
| preds (predict) | T* | Device array of predictions [n_rows] |
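
As a reference point for checking preds, a linear model's regression prediction is preds[i] = intercept + sum_j X(i, j) * coef[j]. The following is a minimal host-side sketch of that computation over a column-major matrix. The name predict_reference is hypothetical, and the sketch does not cover any transformation the GPU functions may apply for other loss types.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference (not part of cuML): linear prediction over a
// column-major feature matrix X of shape [n_rows x n_cols].
std::vector<double> predict_reference(const std::vector<double>& X, int n_rows, int n_cols,
                                      const std::vector<double>& coef, double intercept)
{
  assert(X.size() == static_cast<std::size_t>(n_rows) * n_cols);
  assert(coef.size() == static_cast<std::size_t>(n_cols));
  std::vector<double> preds(n_rows, intercept);
  // Walk one column at a time so memory access is contiguous.
  for (int j = 0; j < n_cols; ++j) {
    const double* col = X.data() + static_cast<std::size_t>(j) * n_rows;
    for (int i = 0; i < n_rows; ++i) preds[i] += col[i] * coef[j];
  }
  return preds;
}
```

Copying the device predictions back to the host and comparing them against this loop on a small input is a simple way to catch a row-major/column-major mix-up.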

Usage Examples

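#include <cuml/solvers/solver.hpp>
#include <raft/core/handle.hpp>  // raft::handle_t (header path as in recent RAFT releases)

raft::handle_t handle;

int n_rows = 10000;
int n_cols = 50;

// The device buffers below are assumed to be allocated (e.g. with cudaMalloc),
// and d_X / d_y filled with training data, before any solver call.
float* d_X;           // device [n_rows x n_cols], column-major
float* d_y;           // device [n_rows]
float* d_coef;        // device [n_cols]
float intercept;      // scalar written by the fit call through &intercept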

// Fit using Coordinate Descent (elastic-net)
int n_iter = ML::Solver::cdFit(handle, d_X, n_rows, n_cols, d_y,
                                d_coef, &intercept,
                                true,    // fit_intercept
                                1000,    // epochs
                                0,       // loss (linear regression)
                                1.0f,    // alpha
                                0.5f,    // l1_ratio
                                true,    // shuffle
                                1e-4f);  // tol

// Predict (d_preds is assumed to be allocated on the device beforehand)
float* d_preds;  // device [n_rows]
ML::Solver::cdPredict(handle, d_X, n_rows, n_cols,
                      d_coef, intercept, d_preds, 0);

// Or fit using SGD
ML::Solver::sgdFit(handle, d_X, n_rows, n_cols, d_y,
                   d_coef, &intercept,
                   true,    // fit_intercept
                   256,     // batch_size
                   100,     // epochs
                   0,       // lr_type
                   0.01f,   // eta0
                   0.25f,   // power_t
                   0,       // loss
                   2,       // penalty (L2)
                   0.001f,  // alpha
                   0.5f,    // l1_ratio
                   true,    // shuffle
                   1e-4f,   // tol
                   5);      // n_iter_no_change
