Linear Discriminant Analysis (LDA) is closely related to Principal Component Analysis (PCA), with one key difference: LDA is a SUPERVISED algorithm (it uses the class labels), whereas PCA is unsupervised. Besides being a classifier in its own right, LDA is commonly used for dimensionality reduction, typically to reduce the risk of over-fitting and to improve computational performance.
We'll leave the math until later. At a high level, LDA looks for the projection that maximizes the ratio of the dispersion between groups to the dispersion within groups (NB: the LDA literature also uses the term *class* to refer to a group), i.e. $$ \max \frac{D_{between}}{D_{within}} $$
Let's get a working example up and running before diving into the algorithm.
# The usual imports
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
%matplotlib inline
from sklearn.datasets import make_blobs
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# Small toy data set: 15 two-dimensional points scattered around 3 centres.
# The fixed random_state makes every figure below reproducible.
X, y = make_blobs(
    n_samples=15,
    n_features=2,
    centers=3,
    cluster_std=0.5,
    random_state=2,
)
# Plot helpers
def plot_decision_boundary(pred_func, *, data=None, padding=0.5, h=0.01, cmap=None):
    """Shade the plane according to the label predicted by *pred_func*.

    A regular grid with spacing *h* is laid over the bounding box of *data*,
    enlarged by *padding* on every side. Every grid point is classified with
    *pred_func* and the resulting label map is drawn with ``plt.contourf``,
    so the region boundaries show where the prediction changes. The data
    points themselves are not drawn; scatter them separately afterwards.

    Parameters
    ----------
    pred_func : callable
        Receives an (n_points, 2) array of grid coordinates and returns one
        label per point, e.g. a fitted estimator's ``predict`` method.
    data : array-like of shape (n_samples, >= 2), optional
        Points whose first two columns define the plotted region. Defaults
        to the module-level ``X`` so existing calls keep working.
    padding : float, default 0.5
        Margin added around the data's bounding box.
    h : float, default 0.01
        Grid step. The number of predictions grows like 1 / h**2.
    cmap : matplotlib colormap, optional
        Colormap for the filled contours; defaults to ``plt.cm.Greys``.
    """
    if data is None:
        # Original behaviour: use the module-level training data.
        data = X
    data = np.asarray(data)
    if cmap is None:
        cmap = plt.cm.Greys

    # Bounding box of the data plus some padding.
    x_min, x_max = data[:, 0].min() - padding, data[:, 0].max() + padding
    y_min, y_max = data[:, 1].min() - padding, data[:, 1].max() + padding

    # Grid of points with spacing h between them.
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

    # Predict a label for every grid point and fold the labels back onto the grid.
    Z = np.asarray(pred_func(np.c_[xx.ravel(), yy.ravel()])).reshape(xx.shape)

    plt.contourf(xx, yy, Z, cmap=cmap)
# Fit an LDA classifier (SVD solver) on the toy data. store_covariance=True
# asks scikit-learn to also keep the within-class covariance estimate.
lda = LinearDiscriminantAnalysis(
    solver="svd",
    store_covariance=True,
)
# fit() returns the fitted estimator itself, so `model` and `lda` are the same object.
model = lda.fit(X, y)
# Classify every point of a regular 100 x 100 grid covering [-5, 5] x [-5, 5].
a = np.linspace(-5, 5, 100)
b = np.linspace(-5, 5, 100)
# A dense meshgrid is required here: stacking the axes of a *sparse* meshgrid
# side by side yields only the 100 diagonal points (a[i], b[i]), not a grid.
aa, bb = np.meshgrid(a, b)                          # both (100, 100)
gmesh = np.column_stack((aa.ravel(), bb.ravel()))   # (10000, 2) grid points
ypred = model.predict(gmesh)                        # one label per grid point
# Shade each region by the class LDA predicts there ...
plot_decision_boundary(pred_func=model.predict)
# ... then overlay the original training points, coloured by their true label.
plt.scatter(X[:, 0], X[:, 1], c=y, cmap='Pastel1', s=50);
# Shading background: filled contours of sin(r^2) / r^2, where r^2 = x^2 + y^2,
# on a 100 x 100 grid over [-5, 5) x [-5, 5).
a = np.arange(-5, 5, 0.1)
b = np.arange(-5, 5, 0.1)
aa, bb = np.meshgrid(a, b, sparse=True)  # (1, 100) and (100, 1); broadcast below
r2 = aa**2 + bb**2                       # (100, 100)
# sin(r2) / r2 -> 1 as r2 -> 0. Use that limit wherever r2 is exactly 0, so a
# grid point at the origin yields 1 instead of nan plus a divide-by-zero warning.
c = np.divide(np.sin(r2), r2, out=np.ones_like(r2), where=r2 != 0)
h = plt.contourf(a, b, c)