This document offers a practical introduction to implicit stochastic gradient descent (ISGD), a numerically stable variant of the popular SGD method. SGD is widely used to fit large-scale statistical and machine learning (ML) models thanks to its computational efficiency, but it can be numerically unstable. ISGD resolves some of these issues, although its implementation can be challenging in practice.
Suppose we model the response \(Y\in\mathbb{R}\) through covariates (features) \(X\in\mathbb{R}^p\). We observe data \((x_i, y_i), i=1, \ldots, n\). Let’s use “ordinary least squares” (OLS), i.e., aim to minimize \(l(\theta) = (1/2) \sum_{j=1}^n (y_j - x_j'\theta)^2\). When \(n\) is large, this minimization can be computationally expensive, especially when it involves calculating the Hessian. SGD simplifies the task considerably by instead suggesting the iteration \[ \theta_i = \theta_{i-1} + \gamma_i (y_{I} - x_{I}'\theta_{i-1}) x_{I},~~i >0. \] Here, \(I\sim \mathrm{Unif}(1, \ldots, n)\) indexes a random datapoint drawn at each iteration \(i\), and, typically, \(\gamma_i=\gamma_1/i > 0\) is a decreasing learning rate.
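For concreteness, here is what a single step of this iteration looks like on made-up numbers (a minimal sketch; all values and variable names below are illustrative only):
# One SGD step for OLS on toy values (illustrative only).
set.seed(1)
p = 3
theta_old = rep(0, p)                 # current estimate theta_{i-1}
x_I = rnorm(p); y_I = 2               # a single (made-up) datapoint
gamma_i = 1/10                        # learning rate at, say, iteration i = 10 with gamma_1 = 1
residual = y_I - sum(x_I * theta_old) # y_I - x_I' theta_{i-1}
theta_old + gamma_i * residual * x_I  # the new estimate theta_i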
Here is one generic implementation of the above SGD procedure. We start with code that implements a general stochastic approximation routine, together with a helper for visualizing the estimates.
rm(list=ls())
# Generic stochastic approximation for model y = f(x, theta) + noise
#' @param y_arg n-length vector of outcomes.
#' @param X_arg n x p matrix of features/covariates.
#' @param update_fn Function that calculates the new value of theta_old given (y_i,x_i,theta_old,learning_rate)
#' @param theta0 p-length vector of initial estimates.
#' @param gamma1 Learning rate constant (scalar)
#' @param npass Number of passes over the data.
#' @param target (optional) p-length vector of target estimate (e.g., ground truth)
#' @param last Set TRUE if sa() should return the last theta estimate. Otherwise it returns the entire sequence.
sa = function(y_arg, X_arg, update_fn, theta0=rep(0, ncol(X_arg)), gamma1=1, npass=1, target=NULL, last=TRUE) {
  n = length(y_arg)
  theta = matrix(0, nrow=n*npass, ncol=length(theta0)) # path of estimates
  theta[1,] = theta0 # initialization
  colnames(theta) = colnames(X_arg)
  for(i in seq(2, npass*n)) {
    theta_old = theta[i-1,] # old estimate
    I = sample(1:n, size=1) # random datapoint
    x_I = X_arg[I,]; y_I = y_arg[I]
    lr = gamma1/i # learning rate gamma_i = gamma_1 / i
    theta[i,] = as.numeric(update_fn(y_I, x_I, theta_old, lr)) # new estimate
  }
  # visualize if we know target values
  if(!is.null(target)) { viz(target, theta0, theta) }
  if(last) { return(theta[nrow(theta), ]) } else { return(theta) }
}
# visualization of estimation accuracy
viz = function(target, theta0, theta) {
  est = theta[nrow(theta),] # last estimate
  plot(est, target, pch=20, cex=3, main="blue=initial, black=final",
       xlab="sgd estimate") # final estimate vs. target
  legend("bottomright", legend="y=x line", lty=3, col="red")
  points(theta0, target, pch=20, cex=2, col="blue") # initial point vs. target
  abline(0, 1, col="red", lty=3)
  arrows(theta0, target, est, target, lty=2, col="blue", code=2, length=0.15)
}
The SGD method in the OLS model is then implemented by using sa with the following update function:
sgd_update = function(y_I, x_I, theta_old, lr) {
  y_hat = sum(x_I*theta_old) # predicted value of y
  theta_old + lr*(y_I - y_hat)*x_I # SGD update for OLS
}
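Before turning to real data, here is one way to try these pieces end to end on simulated data (a sketch; the design, coefficients, and hyperparameter values below are made up for illustration, and the quality of the result depends on the hyperparameters, as we discuss next):
# Illustrative end-to-end run on simulated data (made-up design and coefficients).
set.seed(1)
n_sim = 1000
X_sim = cbind(1, rnorm(n_sim), rnorm(n_sim))   # intercept + two standardized covariates
theta_true = c(1, -2, 3)                       # "ground truth", passed as target
y_sim = as.numeric(X_sim %*% theta_true + rnorm(n_sim))
sa(y_sim, X_sim, sgd_update, gamma1=0.5, npass=10, target=theta_true)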
Despite its simplicity, SGD can be unstable, as it is very sensitive to all three of its hyperparameters: theta0, the starting point; gamma1, the learning rate; and npass, the number of passes over the data. Let’s investigate this through a data example.
set.seed(41100)
attach(read.csv("https://raw.githubusercontent.com/ptoulis/datasets/main/hubble.csv"))
X = cbind(1, distance) # covariates
y = velocity # response
Here, we model the speed (y) at which galaxies move away from Earth as a function of their distance (X). Hubble’s law suggests that these two quantities are proportional to each other, so the OLS model above is a good choice. Using lm we calculate
fit = lm(y ~ X+0)
summary(fit)
##
## Call:
## lm(formula = y ~ X + 0)
##
## Residuals:
## Min 1Q Median 3Q Max
## -735.14 -132.92 -22.57 171.74 558.61
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## X 6.696 126.557 0.053 0.958
## Xdistance 76.127 9.493 8.019 5.68e-08 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 264.7 on 22 degrees of freedom
## Multiple R-squared: 0.9419, Adjusted R-squared: 0.9367
## F-statistic: 178.5 on 2 and 22 DF, p-value: 2.528e-14
b_ols = coef(fit)
The “true” slope is 76.127 and is highly significant. Let’s see what we get with SGD if we leave all hyperparameters at their default values:
sa(y, X, sgd_update, target=b_ols)
## distance
## 1.169263e+22 1.770883e+23
We see that SGD completely diverged! Let’s see whether we can address this issue by increasing the number of passes over the data:
sa(y, X, sgd_update, npass=100, target=b_ols)
## distance
## 3.196091e+36 -1.890527e+35
That didn’t help; the learning rate is clearly too large. Let’s try lowering it:
sa(y, X, sgd_update, gamma1=0.01, target=b_ols)
## distance
## 4.775514 75.139581
Much better. We see that the SGD estimates now align with the target estimates. Increasing the learning rate again (to 0.1), however, breaks SGD:
sa(y, X, sgd_update, gamma1=0.1, npass=100, target=b_ols)
## distance
## -37976.597 2626.121
We see that SGD is extremely sensitive to its hyperparameters. This makes its application quite challenging in practice and unsettles many users.
ISGD can help!
ISGD is a variant of SGD that is numerically much more stable. The idea is similar to optimization methods that backtrack when the update is too large. The really cool aspect of ISGD is that it automatically calculates the amount of backtracking that is required.
Roughly speaking, you obtain ISGD by substituting \(\theta_{i-1}\) with \(\theta_i\) in the likelihood portion of the SGD update. In the context of OLS, this means \[ \theta_i = \theta_{i-1} + \gamma_i (y_{I} - x_{I}'\theta_{i}) x_{I},~~i >0. \] Note that \(\theta_i\) now appears on both sides, hence the update is implicit. To solve for \(\theta_i\), premultiply both sides by \(x_I'\) to obtain \(x_I'\theta_i\) in closed form, then substitute back; this yields \[ \theta_i = \theta_{i-1} + \frac{\gamma_i}{1+\gamma_i||x_I||^2} (y_{I} - x_{I}'\theta_{i-1}) x_{I},~~i >0. \] This is also known as the “normalized least mean squares” (NLMS) filter in signal processing. Here is a quick implementation of this idea in the OLS model:
isgd_update = function(y_I, x_I, theta_old, lr) {
  y_hat = sum(x_I*theta_old) # predicted value of y
  im_ftr = lr/(1 + lr*sum(x_I^2)) # implicit step-size factor, specific to ISGD
  theta_old + im_ftr*(y_I - y_hat)*x_I # ISGD update for OLS
}
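As a quick sanity check of the algebra, we can verify on arbitrary numbers that this closed-form step indeed solves the implicit equation \(\theta_i = \theta_{i-1} + \gamma_i (y_{I} - x_{I}'\theta_{i}) x_{I}\) (a sketch with made-up values; the _chk names are illustrative only):
# Check: the ISGD closed form satisfies the implicit equation (made-up values).
set.seed(1)
x_chk = rnorm(3); y_chk = 1.5; theta_chk = rnorm(3)
lr_chk = 10 # deliberately large learning rate
theta_new = isgd_update(y_chk, x_chk, theta_chk, lr_chk)
rhs = theta_chk + lr_chk*(y_chk - sum(x_chk*theta_new))*x_chk
max(abs(theta_new - rhs)) # should be numerically zero
Now let’s run ISGD on the Hubble data with the default hyperparameters: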
sa(y, X, isgd_update, target=b_ols)
## distance
## -8.097782 83.287398
We see that with the default hyperparameters ISGD does not diverge, even though standard SGD diverged with exactly the same settings. ISGD is much more stable and, here, also much more accurate!
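One way to see why: the effective step size in isgd_update is \(\gamma_i/(1+\gamma_i||x_I||^2)\), which is bounded above by \(1/||x_I||^2\) no matter how large \(\gamma_i\) is, so a single update can never overshoot wildly, whereas the plain SGD step grows linearly with \(\gamma_i\). A quick illustration on a made-up covariate vector:
# The ISGD step factor is bounded by 1/||x||^2; the plain SGD factor (lr itself) is not.
x_demo = c(1, 2) # made-up covariate vector, ||x||^2 = 5
for (lr in c(0.1, 1, 10, 1000)) {
  cat("lr =", lr,
      " sgd factor =", lr,
      " isgd factor =", lr/(1 + lr*sum(x_demo^2)), "\n")
}
1/sum(x_demo^2) # the upper bound: 0.2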
Let’s try increasing npass to improve convergence.
sa(y, X, isgd_update, npass=200, target=b_ols)
## distance
## -7.034197 76.071580
Again, we notice that ISGD is remarkably stable around the target values, and the slope estimate has improved. The intercept estimate still seems off, but there is an explanation: in the original lm() output, the intercept has a huge standard error relative to its point estimate (t-value = 0.053). In fact, the 95% confidence interval for the intercept is
confint(fit)[1,]
## 2.5 % 97.5 %
## -255.7670 269.1595
For parameters with such high standard errors, it is not surprising that the SGD estimates differ substantially from lm().
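Indeed, the ISGD intercept estimate above (about -7.03) sits comfortably inside this interval. We can check this programmatically (a quick sketch; sa samples datapoints at random, so the rerun below will give slightly different estimates):
b_isgd = sa(y, X, isgd_update, npass=200) # rerun ISGD; estimates vary slightly across runs
ci = confint(fit)[1,]
ci[1] < b_isgd[1] & b_isgd[1] < ci[2] # is the ISGD intercept inside the 95% CI?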
Let’s see specifically how the ISGD estimate for the slope changes over iterations:
out = sa(y, X, isgd_update, npass=200, last=FALSE)
plot(out[,2], type="b", pch=20, cex=0.2)
abline(h=b_ols[2], col="blue")
We see that ISGD is bouncing around the target value, suggesting that the learning rate is too high for the iterates to settle. This illustrates a key property of ISGD: it remains numerically stable, but we still need to tune the learning rate to get optimal statistical performance. What we mainly gain from ISGD is that we don’t have to worry about numerical instabilities or divergence!
Let’s try again with a smaller learning rate:
out = sa(y, X, isgd_update, gamma1=0.1, npass=200, last=FALSE)
plot(out[,2], type="b", pch=20, cex=0.2)
abline(h=b_ols[2], col="blue")
Much better! Recall that this configuration was actually quite bad for SGD.
The following questions remain: how should we choose the learning rate, and how many passes over the data do we need? These questions are quite subtle and require additional discussion. I hope to get back to them soon.