
Ordinary Least Squares: Closed Form Solution & Gradient Descent

  • Writer: Gianluca Turcatel
  • Nov 9, 2024
  • 6 min read


The Jupyter Notebook for this article can be found HERE.

Ordinary Least Squares (OLS) is a widely used method for estimating the parameters of a linear regression model. We are given some data D:

D = \{ (X_i, y_i) \}_{i=1}^{n}

Where each Xi is a vector of length k, and yi is the value of the dependent variable for Xi. In the simplest scenario Xi has dimension 1 and is thus a scalar. The OLS method seeks to minimize the following objective function (loss):

L(W) = \frac{1}{n} \sum_{i=1}^{n} (y_i - X_i W)^2

OLS Visualized

OLS seeks to minimize the Mean Squared Error (MSE) of the model's predictions. The scatter plot below depicts the relationship between the length of the leg (X) of a species of animals and their running speed (y). The relationship between y and X is clearly linear.


Figure 1. Observed data


How do we find the best fitting line, the one that minimizes the loss? Let's start by guessing a fitting line and calculating the MSE.


Figure 2. Observed data, first guess line and MSE


Well, our first guess is pretty bad: the line does not follow the data at all and the MSE is high (16).


Next, we can propose a better fitting line based on the result of our first attempt:


Figure 3. Observed data, second guess line and MSE


We improved on the first attempt: the second fitting line represents the data much better, the MSE is significantly lower (9), and the line is closer to the data points. Note that the squared residuals are visibly smaller than before. We could continue with this approach and propose a third fitting line, a fourth, and so on, until the MSE stops changing or the change falls below a predefined threshold.
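This guess-and-check loop is easy to script. Below is a minimal sketch, assuming synthetic data drawn around the true line y = 1 + 1.5x (the actual dataset lives in the linked notebook) and two hypothetical candidate lines:

```python
import numpy as np

# Synthetic stand-in for the leg-length / running-speed data.
rng = np.random.default_rng(0)
x = rng.uniform(0, 10, 50)                  # leg length (X)
y = 1.0 + 1.5 * x + rng.normal(0, 1, 50)    # running speed (y), true line y = 1 + 1.5x

def mse(b0, b1, x, y):
    """Mean squared error of the candidate line y_hat = b0 + b1 * x."""
    return np.mean((y - (b0 + b1 * x)) ** 2)

first_guess = mse(5.0, 0.5, x, y)    # a poor first guess
second_guess = mse(3.0, 1.0, x, y)   # a better second guess
```

Every proposed line is scored the same way; a lower MSE means a better fit.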

As you may have guessed, we are not going to find the best fitting line by blindly testing all possible fitting lines. This approach is called brute force, and it is seldom the best or smartest way to go. Instead, we are going to implement two methods: the closed form solution and gradient descent optimization.


1- Closed Form Solution

We are trying to minimize the following loss function:

L(W) = \frac{1}{n} \sum_{i=1}^{n} (y_i - X_i W)^2   (1)

Let's rewrite it in matrix notation (the 1/n was dropped: multiplying by a positive constant does not change the minimizer):

L(W) = (y - XW)^T (y - XW)   (2)

Next we expand the square operation:

L(W) = (y^T - (XW)^T)(y - XW)   (3)

We can use the following identity:

(AB)^T = B^T A^T   (4)

To modify the loss function into:

L(W) = (y^T - W^T X^T)(y - XW)   (5)

Next we expand the expression:

L(W) = W^T X^T X W - W^T X^T y - y^T X W + y^T y   (6)

The figure below depicts how the value of the loss function changes with W. The loss function is convex. Let's call W* the value of W for which L(W) is the minimum (Figure 4).


Figure 4. Shape of the loss function

At W* the first derivative of the loss function is zero:

\frac{\partial L}{\partial W} = \frac{\partial (W^T X^T X W)}{\partial W} - \frac{\partial (W^T X^T y)}{\partial W} - \frac{\partial (y^T X W)}{\partial W} + \frac{\partial (y^T y)}{\partial W} = 0   (7)

Note that the last term does not contain W, thus:

\frac{\partial (y^T y)}{\partial W} = 0   (8)

The derivative of the first term is a little more complicated. In general:

\frac{\partial (W^T A W)}{\partial W} = 2AW   (9)

Where A is a symmetric matrix. Let's update the derivative of the loss function by including the derivative of the first term (9) and removing the fourth term (8).

\frac{\partial L}{\partial W} = 2 X^T X W - \frac{\partial (W^T X^T y)}{\partial W} - \frac{\partial (y^T X W)}{\partial W} = 0   (10)

The final trick is to realize that:

W^T X^T y = (W^T X^T y)^T = y^T X W   (11)

Why is the equality in (11) valid? The first two terms are equal to each other because the first term evaluates to a scalar, and transposing a scalar leaves it unchanged. To go from the second term to the third one we applied (4) again.

And because:

\frac{\partial (y^T X W)}{\partial W} = X^T y   (12)

we end up with:

\frac{\partial L}{\partial W} = 2 X^T X W - 2 X^T y = 0   (13)

We drop the factor of 2 and rearrange the formula such that:

X^T X W = X^T y   (14)

Let's multiply both sides by:

(X^T X)^{-1}

We get:

(X^T X)^{-1} X^T X W = (X^T X)^{-1} X^T y   (15)

We recognize that:

(X^T X)^{-1} X^T X = I   (16)

Hence the closed form formula to calculate W is:

W = (X^T X)^{-1} X^T y   (17)

Great! Let's use this formula to estimate the intercept and slope of the data in Figure 1. The only caveat is that we need to add a column of ones to the original X, in order to estimate the intercept B0.
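A minimal NumPy sketch of this computation, assuming synthetic data drawn around the true line y = 1 + 1.5x (the notebook has the actual dataset):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(0, 10, 50)
y = 1.0 + 1.5 * x + rng.normal(0, 1, 50)

# Prepend a column of ones so the intercept B0 is estimated alongside B1.
X = np.column_stack([np.ones_like(x), x])

# Closed form solution: W = (X^T X)^{-1} X^T y
W = np.linalg.inv(X.T @ X) @ (X.T @ y)
b0_hat, b1_hat = W
```

In practice `np.linalg.lstsq(X, y, rcond=None)` is preferred over an explicit inverse for numerical stability, but the line above mirrors the formula directly.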


The true values of B0 and B1 are 1 and 1.5 respectively.

Finally, let's overlay the fitted line with the original data and calculate the final error:


Figure 5. Closed form solution fitted line

2- Gradient descent

Gradient descent is an optimization method that estimates the parameters of a model by minimizing the loss function in an iterative manner. At each iteration, gradient descent updates the parameter values by moving in the direction of steepest descent, given by the negative gradient of the loss function with respect to the current parameter values.


Figure 6. Gradient descent intuition


Let's walk through Figure 6. We start the search for the optimal parameter, W*, with an initial guess, W0. The loss at W0 is L(W0). Next, we propose a better value of W such that the loss is smaller. In Figure 6, that means increasing the value of W0 and moving toward the right (if W0 were higher than W*, we would propose a smaller W, thus moving toward the left). In practice, though, we don't have Figure 6 in front of us, so how do we know whether to increase or decrease W? That's where gradient descent comes to the rescue.

The first step of gradient descent is calculating the derivative, a.k.a. gradient, of the loss function at W0, L'(W0). In Figure 6 the red lines represent the gradient of the loss function at different values of W. The gradient at W0 is negative because the line points downward and to the right. That tells us that the optimum W* is located to the right, and increasing W would decrease the loss. Conversely, if the gradient is positive, W is greater than W* and we would need to decrease W, moving to the left.


Great! We have developed the intuition for how gradient descent works and figured out when to increase or decrease W. The next step is formalizing that intuition and deciding by how much to increase or decrease W at each iteration. The rule to update W at each iteration is given by the formula:

W_{t+1} = W_t - lr \cdot \frac{\partial L}{\partial W}\bigg|_{W = W_t}   (18)

The value of W at the next iteration (t+1) is equal to the current value of W minus the gradient of the loss w.r.t. W multiplied by the learning rate, lr. The learning rate is a hyperparameter that controls the learning process; specifically, it dictates how fast we descend the curve. A higher lr means bigger updates of W and a faster descent. Typical values of the learning rate are on the order of 0.01 to 0.0001, and a good value is often found through trial and error.
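As a toy illustration of this update rule, consider the one-dimensional convex loss L(w) = (w − 3)², whose minimum sits at w* = 3 and whose gradient is L'(w) = 2(w − 3):

```python
lr = 0.1      # learning rate
w = 0.0       # initial guess W0
for _ in range(100):
    grad = 2 * (w - 3)   # gradient of L(w) = (w - 3)^2 at the current w
    w = w - lr * grad    # update rule: w_{t+1} = w_t - lr * L'(w_t)
# w has now converged to the minimizer w* = 3
```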


It is time to apply gradient descent to find the parameters of the line that best fits the data in Figure 1. First, however, we need the derivative of the loss function with respect to the model parameters. Our proposed linear model is:

\hat{y} = B_0 + B_1 x   (19)

Thus, we have two parameters to estimate, B0 (intercept) and B1 (slope).

The loss function is:

L = \frac{1}{n} \sum_{i=1}^{n} (\hat{y}_i - y_i)^2   (20)

Where yi_hat is the model prediction and yi is the observed data. The derivation of the loss function can be found HERE.

Let's substitute yi_hat with (19).

L = \frac{1}{n} \sum_{i=1}^{n} (B_0 + B_1 x_i - y_i)^2   (21)

The gradient of the loss w.r.t. B1 is:

\frac{\partial L}{\partial B_1} = \frac{2}{n} \sum_{i=1}^{n} (B_0 + B_1 x_i - y_i) \cdot x_i   (22)

The gradient of the loss w.r.t. B0 is:

\frac{\partial L}{\partial B_0} = \frac{2}{n} \sum_{i=1}^{n} (B_0 + B_1 x_i - y_i)   (23)

Both (22) and (23) were calculated with the chain rule: (f(g(x)))' = f'(g(x))⋅g'(x).


It is time to code everything up. Let's start by defining two helper functions: one that returns the loss (given predictions and observed data), and one that returns the gradients (given predictions, observed data, and X):
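A sketch of what those two helpers might look like in NumPy (the function names are mine, not necessarily the notebook's):

```python
import numpy as np

def compute_loss(y_hat, y):
    """MSE between predictions y_hat and observed data y."""
    return np.mean((y_hat - y) ** 2)

def compute_gradients(y_hat, y, x):
    """Batch gradients of the MSE w.r.t. B0 and B1, per (23) and (22)."""
    error = y_hat - y
    grad_b0 = 2.0 * np.mean(error)        # dL/dB0
    grad_b1 = 2.0 * np.mean(error * x)    # dL/dB1
    return grad_b0, grad_b1
```

When the predictions are perfect, both gradients are zero and the loss is zero, which is a handy sanity check.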


Next we deploy gradient descent to estimate the B0 and B1 of the best fitting line:
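Here is a self-contained sketch of that loop, again on synthetic data drawn around y = 1 + 1.5x; the learning rate and iteration count are illustrative choices:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(0, 10, 50)
y = 1.0 + 1.5 * x + rng.normal(0, 1, 50)

b0, b1 = 0.0, 0.0    # initial guesses for intercept and slope
lr = 0.01            # learning rate
losses = []

for _ in range(300):
    y_hat = b0 + b1 * x                      # predictions with current parameters
    losses.append(np.mean((y_hat - y) ** 2))
    error = y_hat - y
    b0 -= lr * 2.0 * np.mean(error)          # update B0 with its batch gradient
    b1 -= lr * 2.0 * np.mean(error * x)      # update B1 with its batch gradient
```

Every iteration uses all 50 points to compute the gradients, which is what makes this batch gradient descent.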


Note that at each iteration B0 and B1 were updated once, using the whole dataset. This variant of gradient descent is called batch gradient descent. In stochastic gradient descent, by contrast, the parameters are updated one data point at a time.

After 300 iterations, the estimated B0 and B1 match the values obtained with the closed form solution.

The real values of B0 and B1 are 1 and 1.5, respectively.


Finally, let's plot the loss during training and the final fitted line with the MSE:
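A matplotlib sketch of both plots; the synthetic data and the gradient descent loop are rerun here so the snippet stands alone, and the file name and layout are my own choices:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")          # render off-screen, no display needed
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
x = rng.uniform(0, 10, 50)
y = 1.0 + 1.5 * x + rng.normal(0, 1, 50)

# Batch gradient descent, keeping the loss at each iteration.
b0, b1, lr, losses = 0.0, 0.0, 0.01, []
for _ in range(300):
    y_hat = b0 + b1 * x
    losses.append(np.mean((y_hat - y) ** 2))
    error = y_hat - y
    b0 -= lr * 2.0 * np.mean(error)
    b1 -= lr * 2.0 * np.mean(error * x)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.plot(losses)                            # loss curve during training
ax1.set(xlabel="iteration", ylabel="MSE", title="Training loss")
xs = np.sort(x)
ax2.scatter(x, y, s=12)                     # observed data
ax2.plot(xs, b0 + b1 * xs, color="red")     # fitted line
ax2.set(xlabel="X", ylabel="y", title=f"Fitted line, MSE = {losses[-1]:.2f}")
fig.savefig("ols_gradient_descent.png")
```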


Conclusion

Both methods reached the same MSE and produced the same values of B0 and B1. While the closed form solution is obtained in a single step, performing matrix operations on large datasets, in particular inverting X^T X, can be computationally expensive. In such situations we must turn to iterative optimization algorithms like gradient descent. Gradient descent, in turn, requires tuning the learning rate hyperparameter, which can be time-consuming and challenging.



What's next

In future articles I will introduce other variations of gradient descent.


Follow me on Twitter and Facebook to stay updated.





































