How to check the assumptions of a linear model
Description
If you plan to use a linear model to describe some data, it’s important to check if it satisfies the assumptions for linear regression. How can we do that?
Using NumPy, SciPy, sklearn, Matplotlib and Seaborn, in Python
When performing a linear regression, the following assumptions should be checked.
1. We have two or more columns of numerical data of the same length.
The solution below uses an example dataset about car design and fuel consumption from a 1974 Motor Trend magazine. (See how to quickly load some sample data.) We can see that our columns all have the same length.
from rdatasets import data
df = data('mtcars')
df = df[['mpg','cyl','wt']] # Select the 3 variables we're interested in
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 32 entries, 0 to 31
Data columns (total 3 columns):
 #   Column  Non-Null Count  Dtype
---  ------  --------------  -----
 0   mpg     32 non-null     float64
 1   cyl     32 non-null     int64
 2   wt      32 non-null     float64
dtypes: float64(2), int64(1)
memory usage: 900.0 bytes
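The same check can be done programmatically rather than by reading the `info()` output. The following is a minimal sketch, assuming a small stand-in DataFrame built from the first five mtcars rows (used here only so the example runs without the `rdatasets` package); the variable names `all_numeric` and `no_missing` are illustrative.

```python
import pandas as pd

# Stand-in for the mtcars subset: the first five rows of mpg, cyl, and wt.
df = pd.DataFrame({
    "mpg": [21.0, 21.0, 22.8, 21.4, 18.7],
    "cyl": [6, 6, 4, 6, 8],
    "wt": [2.620, 2.875, 2.320, 3.215, 3.440],
})

# Every column should have a numeric dtype.
all_numeric = all(pd.api.types.is_numeric_dtype(df[col]) for col in df.columns)

# Columns of a DataFrame always share one length, so in practice
# "same length" means no column has missing values.
no_missing = bool(df.notna().all().all())

print(f"all numeric: {all_numeric}, no missing values: {no_missing}")
```

If either flag is false, the offending columns need cleaning before a model is fit.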
2. Scatter plots we’ve made suggest a linear relationship.
Scatterplots are covered in how to create basic plots, but after making the model, we can also examine the residuals.
So let’s make the model. Our predictors will be the number of cylinders and the weight of the car and the response will be miles per gallon. (See also how to fit a linear model to two columns of data.)
from sklearn.linear_model import LinearRegression
model = LinearRegression()
predictors = df[['cyl','wt']]
response = df['mpg']
model.fit( X=predictors, y=response )
predictions = model.predict(predictors)
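With its default settings, `LinearRegression` fits an ordinary least squares model with an intercept. To make that concrete, here is a rough sketch of the same calculation in plain NumPy. It uses synthetic data rather than mtcars: the "true" coefficients 40, -1.5, and -3 are made up for illustration and are not estimates from the real dataset.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 50

# Synthetic predictors loosely shaped like cylinders and weight.
cyl = rng.choice([4, 6, 8], size=n).astype(float)
wt = rng.uniform(1.5, 5.5, size=n)

# Synthetic response with made-up coefficients plus noise.
mpg = 40.0 - 1.5 * cyl - 3.0 * wt + rng.normal(0, 1.0, size=n)

# Design matrix: a column of ones for the intercept, then the predictors.
X = np.column_stack([np.ones(n), cyl, wt])

# Ordinary least squares: minimize the sum of squared residuals.
coef, *_ = np.linalg.lstsq(X, mpg, rcond=None)
intercept, b_cyl, b_wt = coef

predictions = X @ coef
residuals = mpg - predictions

print(f"intercept={intercept:.2f}, cyl={b_cyl:.2f}, wt={b_wt:.2f}")
```

Because the model includes an intercept, the residuals average to zero (up to rounding error), which is why the checks below look at their shape and spread rather than their mean.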
We test for linearity with residual plots. We show just one residual plot here; you should make one for each predictor. Seaborn has a function for just this purpose. (See also how to compute the residuals of a linear model.)
import seaborn as sns
import matplotlib.pyplot as plt
# residplot fits its own simple regression of y on x and plots those residuals.
# The "lowess" parameter adds a smooth line through the data:
sns.residplot(x='wt', y='mpg', data=df, lowess=True)
plt.xlabel("Weight")
plt.title('Miles per gallon')
plt.show()
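When reading a residual plot, the thing to look for is whether the smooth line stays close to zero. As a rough numerical stand-in for that visual check (not a replacement for the plot), the sketch below averages residuals within equal-sized bins of a predictor. The helper names `binned_residual_means` and `straight_line_residuals` and the synthetic data are assumptions made for illustration.

```python
import numpy as np

def binned_residual_means(x, residuals, n_bins=4):
    """Mean residual within equal-count bins of x, ordered by x."""
    order = np.argsort(x)
    chunks = np.array_split(np.asarray(residuals)[order], n_bins)
    return np.array([chunk.mean() for chunk in chunks])

def straight_line_residuals(x, y):
    """Residuals from a simple straight-line fit of y on x."""
    slope, intercept = np.polyfit(x, y, deg=1)
    return y - (slope * x + intercept)

rng = np.random.default_rng(1)
x = rng.uniform(0, 10, size=200)

# Case 1: the response really is linear in x.
y_linear = 3.0 + 2.0 * x + rng.normal(0, 1.0, size=200)
# Case 2: the response is curved, so a straight line is the wrong model.
y_curved = 3.0 + 0.5 * (x - 5.0) ** 2 + rng.normal(0, 1.0, size=200)

means_linear = binned_residual_means(x, straight_line_residuals(x, y_linear))
means_curved = binned_residual_means(x, straight_line_residuals(x, y_curved))

print("linear case:", np.round(means_linear, 2))
print("curved case:", np.round(means_curved, 2))
```

In the linear case the bin means hover near zero; in the curved case they are positive at both ends and negative in the middle, the same U shape a lowess line would trace.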
3. After making the model, the residuals seem normally distributed.
We can check this by constructing a QQ-plot, which compares the distribution of the residuals to a normal distribution. Here we use SciPy, but there are other methods; see how to create a QQ-plot.
from scipy import stats
residuals = response - predictions # Compute the residuals
stats.probplot(residuals, dist="norm", plot=plt)
plt.title("Normal Q-Q Plot")
plt.show()
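A QQ-plot is judged by eye, but `stats.probplot` also returns the correlation between the ordered data and the theoretical normal quantiles, which gives a rough numerical summary of how straight the plotted points are. The sketch below compares that correlation for a normal sample and a skewed sample; both samples are synthetic, and `qq_correlation` is a hypothetical helper name.

```python
import numpy as np
from scipy import stats

def qq_correlation(sample):
    """Correlation between ordered data and normal quantiles (closer to 1 is straighter)."""
    # Without the plot argument, probplot only computes the values.
    (_, _), (_, _, r) = stats.probplot(sample, dist="norm")
    return r

rng = np.random.default_rng(2)
normal_like = rng.normal(0, 1, size=200)   # what well-behaved residuals resemble
skewed = rng.exponential(1.0, size=200)    # clearly non-normal

r_normal = qq_correlation(normal_like)
r_skewed = qq_correlation(skewed)

print(f"normal sample: r = {r_normal:.3f}")
print(f"skewed sample: r = {r_skewed:.3f}")
```

A value very close to 1 is consistent with normality, while a noticeably lower value matches a QQ-plot that bends away from the reference line. This is a heuristic, not a formal test.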
4. After making the model, the residuals seem homoscedastic.
This assumption is sometimes called “equal variance,” and can be checked by the regplot function in Seaborn. We must first standardize the residuals, which we can do with NumPy. We want to see a plot with no clear pattern; a cone shape to the data would indicate heteroscedasticity, the opposite of homoscedasticity.
import numpy as np
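# Standardize the residuals by dividing by their standard deviation
# (a simple version that ignores leverage), then take the square root
# of the absolute value, as in a Scale-Location plot:
standardized_residuals = residuals / np.std(residuals)
sqrt_abs_residuals = np.sqrt(np.abs(standardized_residuals))
sns.regplot(x = predictions, y = sqrt_abs_residuals, scatter=True, lowess=True)
plt.ylabel("sqrt(|Standardized residuals|)")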
plt.xlabel("Fitted value")
plt.title("Scale-Location")
plt.show()
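Standardized residuals are often defined with a correction for leverage, so that each residual is divided by its own estimated standard deviation. The sketch below shows one way to compute that version with NumPy, under the assumption of an ordinary least squares fit with an intercept; the data are synthetic and the variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 40

# Synthetic predictors and response (made-up coefficients).
x1 = rng.uniform(0, 10, size=n)
x2 = rng.uniform(0, 5, size=n)
y = 1.0 + 2.0 * x1 - 1.0 * x2 + rng.normal(0, 1.0, size=n)

# Design matrix with an intercept column; p counts all fitted parameters.
X = np.column_stack([np.ones(n), x1, x2])
p = X.shape[1]

coef, *_ = np.linalg.lstsq(X, y, rcond=None)
fitted = X @ coef
residuals = y - fitted

# Leverage of each observation: the diagonal of the hat matrix X (X'X)^-1 X'.
hat_matrix = X @ np.linalg.solve(X.T @ X, X.T)
leverage = np.diag(hat_matrix)

# Residual standard error, using n - p degrees of freedom.
sigma_hat = np.sqrt(residuals @ residuals / (n - p))

# Standardized residuals, then the quantity plotted in a Scale-Location plot.
standardized = residuals / (sigma_hat * np.sqrt(1.0 - leverage))
scale_location = np.sqrt(np.abs(standardized))
```

Plotting `scale_location` against `fitted` gives a leverage-adjusted version of the plot above. For datasets where no observation has unusually high leverage, the two versions should look very similar.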
Content last modified on 24 July 2023.
Solution, in R
When performing a linear regression, the following assumptions should be checked.
1. We have two or more columns of numerical data of the same length.
The solution below uses an example dataset about car design and fuel consumption from a 1974 Motor Trend magazine. (See how to quickly load some sample data.) We can see that our columns all have the same length.
df <- mtcars
head(df)
                   mpg cyl disp  hp drat    wt  qsec vs am gear carb
Mazda RX4         21.0   6  160 110 3.90 2.620 16.46  0  1    4    4
Mazda RX4 Wag     21.0   6  160 110 3.90 2.875 17.02  0  1    4    4
Datsun 710        22.8   4  108  93 3.85 2.320 18.61  1  1    4    1
Hornet 4 Drive    21.4   6  258 110 3.08 3.215 19.44  1  0    3    1
Hornet Sportabout 18.7   8  360 175 3.15 3.440 17.02  0  0    3    2
Valiant           18.1   6  225 105 2.76 3.460 20.22  1  0    3    1
2. Scatter plots we’ve made suggest a linear relationship.
Scatterplots are covered in how to create basic plots, but after making the model, we can also examine the residuals.
So let’s make the model. Our predictors will be the number of cylinders and the weight of the car and the response will be miles per gallon. (See also how to fit a linear model to two columns of data.)
model <- lm(mpg ~ cyl + wt, data = df)
We test for linearity with residual plots. We show just one residual plot here, of the residuals against the fitted values; you should also make one for each predictor. R’s plot function knows how to create residual plots. (See also how to compute the residuals of a linear model.)
plot(model, which = 1)
3. After making the model, the residuals seem normally distributed.
We can check this by constructing a QQ-plot, which compares the distribution of the residuals to a normal distribution. Here we use R’s built-in plot function on the model, but there are other methods; see how to create a QQ-plot.
plot(model, which = 2)
4. After making the model, the residuals seem homoscedastic.
This assumption is sometimes called “equal variance,” and can be checked with the Scale-Location plot that R’s plot function produces for a model. That plot shows the square root of the absolute standardized residuals against the fitted values, so we do not need to standardize the residuals ourselves. We want to see a plot with no clear pattern; a cone shape to the data would indicate heteroscedasticity, the opposite of homoscedasticity.
plot(model, which = 3) # assumption of equal variance
Opportunities
This website does not yet contain a solution for this task in any of the following software packages.
- Excel
- Julia
If you can contribute a solution using any of these pieces of software, see our Contributing page for how to help extend this website.