1. Introduction

This document offers a practical introduction to implicit stochastic gradient descent (ISGD), a stable variant of the popular SGD method. SGD is widely used to fit large-scale statistical and machine learning (ML) models thanks to its computational efficiency, but it can be numerically unstable. ISGD resolves some of these issues, but its implementation can be challenging in practice.

2. Basic idea of SGD

Suppose we model the response \(Y\in\mathbb{R}\) through covariates (features) \(X\in\mathbb{R}^p\). We observe data \((x_i, y_i), i=1, \ldots, n\). Let’s use “ordinary least squares” (OLS), i.e., aim to minimize \(l(\theta) = (1/2) \sum_{j=1}^n (y_j - x_j'\theta)^2\). When \(n\) is large, the minimization can be tedious, especially when it involves calculating the Hessian. SGD simplifies the task considerably by instead using the iteration \[ \theta_i = \theta_{i-1} + \gamma_i (y_{I} - x_{I}'\theta_{i-1}) x_{I},~~i >0. \] Here, \(I\sim \mathrm{Unif}\{1, \ldots, n\}\) is a datapoint index drawn at random at each iteration \(i\), and, typically, \(\gamma_i=\gamma_1/i\) for some constant \(\gamma_1 > 0\).
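Why is this iteration sensible? The increment \((y_I - x_I'\theta)x_I\) is the negative gradient of the single-datapoint loss \((1/2)(y_I - x_I'\theta)^2\), so averaging it over a uniformly random \(I\) gives exactly \(-(1/n)\nabla l(\theta)\). The short R sketch below checks both facts numerically; the simulated data, the seed, the evaluation point, and the names sim_X, sim_y, loss_fn, grad_fn, and increment are illustrative assumptions, not part of the method.

```r
# Sketch (simulated data, arbitrary seed): the SGD increment is a stochastic
# estimate of the negative OLS gradient.
set.seed(1)
sim_n = 20
sim_X = cbind(1, rnorm(sim_n))                          # intercept + one covariate
sim_y = as.numeric(sim_X %*% c(1, 3) + rnorm(sim_n))    # y = 1 + 3x + noise
theta_eval = c(0.5, -0.5)                               # arbitrary evaluation point

# l(theta) = (1/2) * sum_j (y_j - x_j' theta)^2 and its gradient -X'(y - X theta)
loss_fn = function(th) 0.5 * sum((sim_y - sim_X %*% th)^2)
grad_fn = function(th) -as.numeric(crossprod(sim_X, sim_y - sim_X %*% th))

# SGD increment for datapoint j, as in the iteration above
increment = function(j) (sim_y[j] - sum(sim_X[j, ] * theta_eval)) * sim_X[j, ]

# Exact average of the increment over a uniformly random index j in 1..n
avg_increment = rowMeans(sapply(seq_len(sim_n), increment))
stopifnot(isTRUE(all.equal(avg_increment, -grad_fn(theta_eval) / sim_n)))

# Central finite differences confirm the gradient formula itself
eps = 1e-6
fd_grad = sapply(1:2, function(k) {
  e = replace(numeric(2), k, eps)
  (loss_fn(theta_eval + e) - loss_fn(theta_eval - e)) / (2 * eps)
})
stopifnot(isTRUE(all.equal(fd_grad, grad_fn(theta_eval), tolerance = 1e-6)))
```

In other words, each SGD step uses one datapoint's gradient as a cheap, unbiased stand-in (up to the factor \(n\)) for the full gradient, and never touches the Hessian.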

Here is one generic implementation of the above SGD procedure. We start with code that implements a general stochastic approximation procedure, along with a routine for visualization.

rm(list=ls())
# Generic stochastic approximation for model y = f(x, theta) + noise
#' @param y_arg n-length vector of outcomes.
#' @param X_arg nxp matrix of features/covariates.
#' @param update_fn Function that returns the new estimate of theta given (y_I, x_I, theta_old, learning_rate).
#' @param theta0 p-length vector of initial estimates.
#' @param gamma1 Learning rate constant (scalar); the learning rate at iteration i is gamma1/i.
#' @param npass Number of passes over the data.
#' @param target (optional) p-length vector of target estimate (e.g., ground truth)
#' @param last Set TRUE if sa() should return the last theta estimate. Otherwise it returns the entire sequence.
sa = function(y_arg, X_arg, update_fn, theta0=rep(0, ncol(X_arg)), gamma1=1, npass=1, target=NULL, last=TRUE) {
  n = length(y_arg)
  theta = matrix(0, nrow=n*npass, ncol=length(theta0))  # path
  theta[1,] = theta0    # initialization
  colnames(theta) = colnames(X_arg)  # label columns with the feature names
  for(i in seq(2, npass*n)) {
    theta_old = theta[i-1,]     # old estimate
    I = sample(1:n, size=1)       # random datapoint
    x_I = X_arg[I,]; y_I = y_arg[I]
    lr = gamma1/i   # learning rate
    
    theta[i,] = as.numeric(update_fn(y_I, x_I, theta_old, lr)) # new estimate
  }
  
  # visualize if we know target values 
  if(!is.null(target)) { viz(target, theta0, theta) }
  
  if(last) { return(theta[nrow(theta), ]) } else { return(theta) }
}

# visualization of estimation accuracy
viz = function(target, theta0, theta) {
  est = theta[nrow(theta),] # last estimate
  plot(est, target, pch=20, cex=3, main="blue=initial, black=final",
       xlab="sgd estimate") # estimate wrt target
  legend("bottomright", legend="y=x line", lty=3, col="red")
  points(theta0, target, pch=20, cex=2, col="blue") # initial point wrt target
  abline(0, 1, col="red", lty=3)
  arrows(theta0, target, est, target, lty=2, col="blue", code=2, length=0.15)
}

The SGD method in the OLS model is then implemented by using sa with the following update:

sgd_update = function(y_I, x_I, theta_old, lr) {
  y_hat = sum(x_I*theta_old)  # predicted value of y
  theta_old + lr*(y_I - y_hat)*x_I  # ols sgd
} 

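Despite its simplicity, SGD can be unstable, because it is very sensitive to all three of its hyperparameters: the initial estimate (theta0), the learning rate constant (gamma1), and the number of passes over the data (npass).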
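A noiseless one-parameter toy problem shows where the sensitivity to the learning rate comes from. If every datapoint is the same \((x, y)\) with \(y = x\theta^*\), the SGD update gives \(\theta_i - \theta^* = (1 - \gamma_i x^2)(\theta_{i-1} - \theta^*)\), so the error grows whenever \(\gamma_i x^2 > 2\). The sketch below is a hypothetical setup chosen only for illustration (the helper toy_sgd_errors and its defaults are not from the original): with \(x = 2\) and \(\gamma_1 = 10\), \(\gamma_i x^2 = 40/i\) exceeds 2 for the first 19 iterations.

```r
# Hypothetical toy: every datapoint is (x, y) with y = x * theta_star (no noise),
# and the learning rate is gamma1 / i, as in sa().
toy_sgd_errors = function(gamma1 = 10, x = 2, theta_star = 1, theta0 = 0, n_iter = 30) {
  y = x * theta_star
  theta = theta0
  err = numeric(n_iter)
  for (i in seq_len(n_iter)) {
    lr = gamma1 / i
    theta = theta + lr * (y - x * theta) * x   # same form as sgd_update()
    err[i] = theta - theta_star               # error after iteration i
  }
  err
}

err_big = toy_sgd_errors(gamma1 = 10)     # gamma_i * x^2 = 40 / i
err_small = toy_sgd_errors(gamma1 = 0.1)  # gamma_i * x^2 = 0.4 / i, always < 2
max(abs(err_big))    # about 6.9e10: the initial error of 1 is amplified enormously
max(abs(err_small))  # never exceeds the initial error of 1
```

Because \(\gamma_i\) decays like \(1/i\), the per-step factor eventually drops below 1 in magnitude, but in this toy the error is still on the order of 1e8 after 30 iterations.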

Let’s investigate this through a data example.

3. Data example

set.seed(41100)
attach(read.csv("https://raw.githubusercontent.com/ptoulis/datasets/main/hubble.csv"))
X = cbind(1, distance) # covariates
y = velocity    # response 

Here, we model the velocity at which galaxies move away from Earth (y) as a function of their distance (X). Hubble’s law states that these two quantities are proportional to each other, so the OLS model above is a good choice. Using lm(), we compute the OLS fit:

fit = lm(y ~ X+0)
summary(fit)
## 
## Call:
## lm(formula = y ~ X + 0)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -735.14 -132.92  -22.57  171.74  558.61 
## 
## Coefficients:
##           Estimate Std. Error t value Pr(>|t|)    
## X            6.696    126.557   0.053    0.958    
## Xdistance   76.127      9.493   8.019 5.68e-08 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 264.7 on 22 degrees of freedom
## Multiple R-squared:  0.9419, Adjusted R-squared:  0.9367 
## F-statistic: 178.5 on 2 and 22 DF,  p-value: 2.528e-14
b_ols = coef(fit)
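For reference, the OLS target has the closed form \(\hat\theta = (X'X)^{-1}X'y\) (the normal equations), which is trivial to compute directly for a two-column X like ours. The sketch below checks that solving the normal equations reproduces lm() on made-up data; the data and the names ols_dist, ols_vel, and ols_X are illustrative assumptions (not the Hubble data), and the data are generated deterministically so that the random-number stream behind the sa() calls below is left untouched.

```r
# Sketch on made-up, deterministic data (no RNG calls, so later sa() results are unaffected)
ols_dist = seq(0.1, 2, length.out = 24)                     # made-up distances
ols_vel = 75 * ols_dist + 200 * sin(seq_along(ols_dist))    # made-up velocities with "noise"
ols_X = cbind(1, ols_dist)                                  # intercept column + covariate, like X

b_normal = solve(crossprod(ols_X), crossprod(ols_X, ols_vel))  # (X'X)^{-1} X'y
b_lm = coef(lm(ols_vel ~ ols_X + 0))                            # same model as y ~ X + 0
stopifnot(isTRUE(all.equal(as.numeric(b_normal), as.numeric(b_lm))))
```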

The OLS slope of 76.127 is highly significant, and we treat b_ols as the “true” target for SGD. Let’s see what SGD gives when all hyperparameters are left at their default values:

sa(y, X, sgd_update, target=b_ols)

##                  distance 
## 1.169263e+22 1.770883e+23

We see that SGD completely diverged! Let’s see whether we can address this issue by increasing the number of passes over the data:

sa(y, X, sgd_update, npass=100, target=b_ols)

##                    distance 
##  3.196091e+36 -1.890527e+35

That didn’t help. The learning rate is clearly too large, so let’s try lowering it:

sa(y, X, sgd_update, gamma1=0.01, target=b_ols)

##            distance 
##  4.775514 75.139581

Much better: the SGD estimates now align with the target estimates. However, raising the learning rate constant to 0.1, even with 100 passes over the data, breaks SGD again:

sa(y, X, sgd_update, gamma1=0.1, npass=100, target=b_ols)

##              distance 
## -37976.597   2626.121

We see that SGD is extremely sensitive to its hyperparameters, which makes it hard to apply reliably in practice.

ISGD can help!

4. Implicit SGD

ISGD is a variant of SGD that is numerically much more stable. The idea is similar to optimization methods that backtrack when the update is too large. The really cool aspect of ISGD is that it automatically calculates the amount of backtracking that is required.

Roughly speaking, you obtain ISGD by replacing \(\theta_{i-1}\) with \(\theta_i\) in the likelihood portion of the SGD update. In the context of OLS, this means \[ \theta_i = \theta_{i-1} + \gamma_i (y_{I} - x_{I}'\theta_{i}) x_{I},~~i >0. \] Note that \(\theta_i\) now appears on both sides, hence the update is implicit. Solving for \(\theta_i\), we get \[ \theta_i = \theta_{i-1} + \frac{\gamma_i}{1+\gamma_i||x_I||^2} (y_{I} - x_{I}'\theta_{i-1}) x_{I},~~i >0. \] This is also known as the “normalized least mean squares” (NLMS) filter in signal processing. Here is a quick implementation of this idea in the OLS model:

isgd_update = function(y_I, x_I, theta_old, lr) {
  y_hat = sum(x_I*theta_old)  # predicted value of y
  im_ftr = lr/ (1+lr*sum(x_I^2))  # specific to ISGD
  theta_old + im_ftr*(y_I - y_hat)*x_I  # ols isgd (implicit update)
} 
sa(y, X, isgd_update, target=b_ols)

##            distance 
## -8.097782 83.287398
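Where does the closed form in isgd_update come from? Taking the inner product of the implicit update with \(x_I\) and subtracting the result from \(y_I\) gives \[ (1+\gamma_i||x_I||^2)\,(y_I - x_I'\theta_i) = y_I - x_I'\theta_{i-1}, \] so the implicit residual is the explicit residual shrunk by the factor \(1/(1+\gamma_i||x_I||^2)\); substituting it back into the implicit update gives the closed-form formula. The short R check below confirms this numerically; the datapoint, current estimate, and learning rate (x_I, y_I, theta_prev, lr) are arbitrary illustrative values.

```r
# Arbitrary illustrative values: one datapoint (x_I, y_I), current estimate, learning rate
x_I = c(1, 1.7)          # intercept + covariate
y_I = 500
theta_prev = c(10, 60)
lr = 0.5

# Closed-form implicit update (same formula as isgd_update)
shrink = lr / (1 + lr * sum(x_I^2))
theta_new = theta_prev + shrink * (y_I - sum(x_I * theta_prev)) * x_I

# It satisfies the implicit equation theta_new = theta_prev + lr * (y_I - x_I' theta_new) * x_I
rhs = theta_prev + lr * (y_I - sum(x_I * theta_new)) * x_I
stopifnot(isTRUE(all.equal(theta_new, rhs)))

# Residual shrinkage: new residual = old residual / (1 + lr * ||x_I||^2)
stopifnot(isTRUE(all.equal(y_I - sum(x_I * theta_new),
                           (y_I - sum(x_I * theta_prev)) / (1 + lr * sum(x_I^2)))))
```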

With the default parameters, ISGD does not diverge, whereas standard SGD diverged under exactly the same settings. ISGD is much more stable and, in this run, also much closer to the target!
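A noiseless toy problem makes the contrast concrete. With every datapoint equal to \(x = 2\), \(y = 2\theta^*\) (\(\theta^* = 1\)) and learning rate \(\gamma_1/i\) with \(\gamma_1 = 10\), the explicit step multiplies the error by \(1-\gamma_i x^2\), which can be far larger than 1 in magnitude, while the implicit step multiplies it by \(1/(1+\gamma_i x^2)\), which always lies strictly between 0 and 1. The setup and the helper toy_errors are hypothetical, chosen only for illustration.

```r
# Hypothetical noiseless toy: every datapoint is (x, y) with y = x * theta_star.
toy_errors = function(method = c("sgd", "isgd"), gamma1 = 10, x = 2,
                      theta_star = 1, theta0 = 0, n_iter = 30) {
  method = match.arg(method)
  y = x * theta_star
  theta = theta0
  err = numeric(n_iter)
  for (i in seq_len(n_iter)) {
    lr = gamma1 / i
    step = if (method == "sgd") lr else lr / (1 + lr * x^2)  # ISGD shrinks the step
    theta = theta + step * (y - x * theta) * x
    err[i] = theta - theta_star
  }
  err
}

err_sgd = toy_errors("sgd")
err_isgd = toy_errors("isgd")
max(abs(err_sgd))    # about 6.9e10
max(abs(err_isgd))   # below the initial error of 1; shrinks every iteration
```

The implicit step never overshoots, no matter how large \(\gamma_i x^2\) is; this damping is the mechanism behind ISGD's stability.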

Let’s increase npass to improve convergence.

sa(y, X, isgd_update, npass=200, target=b_ols)

##            distance 
## -7.034197 76.071580

Again, ISGD is remarkably stable around the target value, and the slope estimate has improved. The intercept estimate still seems off, but there is an explanation: in the original lm() output, the intercept’s standard error (126.557) is huge relative to its point estimate (6.696), giving a t-value of 0.053. In fact, the 95% confidence interval for the intercept is

confint(fit)[1,]
##     2.5 %    97.5 % 
## -255.7670  269.1595
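These bounds follow directly from the summary(fit) output: the interval is the estimate plus or minus the 97.5% quantile of a t distribution with 22 residual degrees of freedom times the standard error, and the t-value is the estimate divided by its standard error. The sketch below reproduces both from the rounded printed values (int_est, int_se, and resid_df are names introduced here), so agreement is only up to rounding.

```r
# Values copied from the summary(fit) output above (rounded as printed)
int_est = 6.696      # intercept estimate
int_se = 126.557     # its standard error
resid_df = 22        # residual degrees of freedom

int_est / int_se                                            # t value, about 0.053
int_ci = int_est + c(-1, 1) * qt(0.975, resid_df) * int_se  # 95% confidence interval
int_ci                                                      # close to confint(fit)[1, ]
```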

For a parameter with such a large standard error, it is not surprising that the SGD estimate differs substantially from lm(); the ISGD intercept of about -7 lies well inside this interval.

Let’s see specifically how the ISGD estimate for the slope changes over iterations:

out = sa(y, X, isgd_update, npass=200, last=FALSE)  # full path of estimates
plot(out[,2], type="b", pch=20, cex=0.2, xlab="iteration", ylab="ISGD slope estimate")
abline(h=b_ols[2], col="blue")  # OLS slope (target)

ISGD keeps bouncing around the target value, which suggests that the learning rate is too high to let it settle. This is expected behavior for ISGD: it is very stable, but we still need to tune the learning rate to get optimal performance. What we mainly gain from ISGD is that we don’t have to worry about numerical instability or divergence!
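One way to see why ISGD stays stable yet still needs tuning: its effective step \(\gamma_i/(1+\gamma_i||x_I||^2)\) is never larger than either \(\gamma_i\) or \(1/||x_I||^2\), and after an ISGD step the residual on the sampled datapoint is multiplied by \(1/(1+\gamma_i||x_I||^2)\). For small \(\gamma_i\) the step is essentially the SGD step; for large \(\gamma_i\) it approaches \(1/||x_I||^2\) and each update nearly interpolates the sampled datapoint, so successive estimates chase individual points. The sketch tabulates this for one illustrative covariate vector, x_I = (1, 2); the values and variable names are assumptions for illustration.

```r
# Illustrative covariate vector (intercept + distance of 2); ||x_I||^2 = 5
x_I = c(1, 2)
x_sq = sum(x_I^2)

lr_grid = 10^seq(-3, 3)                        # learning rates from 0.001 to 1000
isgd_step = lr_grid / (1 + lr_grid * x_sq)     # effective ISGD step (im_ftr in isgd_update)
resid_kept = 1 / (1 + lr_grid * x_sq)          # fraction of the sampled residual left after the step

round(cbind(lr = lr_grid, isgd_step, cap = 1 / x_sq, resid_kept), 4)
```

Small learning rates keep ISGD close to SGD, while large ones make it fit each sampled datapoint almost exactly; this is consistent with the bouncing seen in the plot.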

Let’s try again with a smaller learning rate:

out = sa(y, X, isgd_update, gamma1=0.1, npass=200, last=FALSE)  # full path of estimates
plot(out[,2], type="b", pch=20, cex=0.2, xlab="iteration", ylab="ISGD slope estimate")
abline(h=b_ols[2], col="blue")  # OLS slope (target)

Much better! Recall that a similar configuration (gamma1=0.1 with 100 passes) broke standard SGD.

5. Summary

The following questions remain:

  1. How to tune the hyperparameters of SGD in practice? What changes for ISGD?
  2. How to implement ISGD in other models, such as logistic regression?
  3. Can we get statistical significance from SGD/ISGD procedures?

These questions are quite subtle and require additional discussion. I hope to return to them soon.

6. Some more reading