
Polynomial Regression and Pipeline in scikit-learn

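import numpy as np
import matplotlib.pyplot as plt

# 100 samples drawn uniformly from [-3, 3), reshaped into a single-feature column
x = np.random.uniform(-3, 3, size=100)
X = x.reshape(-1, 1)
# Quadratic ground truth plus standard normal noise
y = 0.5 * x ** 2 + x + 2 + np.random.normal(0, 1, 100)

from sklearn.preprocessing import PolynomialFeatures

# Expand the single feature into polynomial terms up to degree 2
poly = PolynomialFeatures(degree=2)
poly.fit(X)
X2 = poly.transform(X)
X2.shape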

(100, 3)

X2[:5]

array([[ 1.        , -2.42731039,  5.89183574],
       [ 1.        , -1.1892334 ,  1.41427607],
       [ 1.        ,  2.97919223,  8.87558632],
       [ 1.        , -0.3104026 ,  0.09634978],
       [ 1.        , -2.28611994,  5.22634437]])

X[:5]

array([[-2.42731039],
       [-1.1892334 ],
       [ 2.97919223],
       [-0.3104026 ],
       [-2.28611994]])
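Comparing the two outputs, each row of X2 holds the powers x^0, x^1 and x^2 of the corresponding row of X. A minimal sketch that checks this by hand, assuming a fixed seed and illustrative names (`rng`, `X_demo`, `manual`) that are not part of the original notebook:

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures

# Illustrative data; the seed and variable names are assumptions for reproducibility
rng = np.random.default_rng(0)
X_demo = rng.uniform(-3, 3, size=(5, 1))

X2_demo = PolynomialFeatures(degree=2).fit_transform(X_demo)

# Build the same matrix by hand: a column of ones, x, and x squared
manual = np.hstack([np.ones_like(X_demo), X_demo, X_demo ** 2])

# True when the library output matches the hand-built columns
print(np.allclose(X2_demo, manual))
```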

from sklearn.linear_model import LinearRegression

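# Fit ordinary least squares on the expanded features [1, x, x^2]
lin_reg = LinearRegression()
lin_reg.fit(X2, y)
y_predict = lin_reg.predict(X2)

# Sort by x so the fitted curve is drawn left to right as a single line
plt.scatter(x, y)
plt.plot(np.sort(x), y_predict[np.argsort(x)], color='r')
plt.show()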

[Figure: scatter plot of the samples with the fitted quadratic curve in red]

lin_reg.coef_

array([0. , 1.01522347, 0.48684016])

lin_reg.intercept_

1.9568056125014162
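The first coefficient is 0 because the constant column generated by PolynomialFeatures duplicates the intercept that LinearRegression already fits on its own. The sketch below is one way to avoid the redundant column, using the `include_bias=False` parameter of PolynomialFeatures. It assumes noise-free data (an illustrative setup, not the notebook's data) so that the recovered values can be compared directly with the generating formula 0.5x^2 + x + 2.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures

# Noise-free quadratic so the fit recovers the generating coefficients (illustrative setup)
x_demo = np.linspace(-3, 3, 50)
X_demo = x_demo.reshape(-1, 1)
y_demo = 0.5 * x_demo ** 2 + x_demo + 2

# include_bias=False omits the column of ones, leaving only x and x^2
X_poly = PolynomialFeatures(degree=2, include_bias=False).fit_transform(X_demo)
reg = LinearRegression().fit(X_poly, y_demo)

print(reg.coef_)       # coefficients of x and x^2
print(reg.intercept_)  # constant term
```

With the bias column dropped, `coef_` has one entry per remaining feature and the constant term lives only in `intercept_`.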

About PolynomialFeatures

With degree = 3

If the original features are

\[x_1, x_2\]

then the degree-0 term is:

\[1\]

The degree-1 terms:

\[x_1, x_2\]

The degree-2 terms:

\[x_1^2, x_2^2, x_1x_2\]

The degree-3 terms:

\[x_1^3, x_2^3, x_1^2x_2, x_1x_2^2\]

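That makes ten terms in total. The number of terms grows quickly with degree: n features expanded to degree d yield C(n + d, d) terms (here C(5, 3) = 10). This makes the expansion expressive, but it also brings problems, such as higher computational cost and a greater risk of overfitting.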
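The count of ten can be verified by enumerating the monomials directly. The sketch below uses only the standard library; `count_terms` is a hypothetical helper introduced here for illustration, and the closed form it is compared against is the binomial coefficient `comb(n_features + degree, degree)`.

```python
from itertools import combinations_with_replacement
from math import comb


def count_terms(n_features, degree):
    """Count monomials of total degree <= degree by enumerating them."""
    return sum(
        1
        for d in range(degree + 1)
        for _ in combinations_with_replacement(range(n_features), d)
    )


for n_features, degree in [(2, 3), (10, 3), (10, 5)]:
    # The enumeration and the closed form should agree
    enumerated = count_terms(n_features, degree)
    closed_form = comb(n_features + degree, degree)
    print(n_features, degree, enumerated, closed_form)
```

Two features at degree 3 give 10 terms, while ten features at degree 5 already give 3003, which is why high degrees become expensive on wide datasets.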

X = np.arange(1, 11).reshape(-1, 2)
X

array([[ 1,  2],
       [ 3,  4],
       [ 5,  6],
       [ 7,  8],
       [ 9, 10]])

poly = PolynomialFeatures(degree=3)
poly.fit(X)
X3 = poly.transform(X)
X3

array([[   1.,    1.,    2.,    1.,    2.,    4.,    1.,    2.,    4.,    8.],
       [   1.,    3.,    4.,    9.,   12.,   16.,   27.,   36.,   48.,   64.],
       [   1.,    5.,    6.,   25.,   30.,   36.,  125.,  150.,  180.,  216.],
       [   1.,    7.,    8.,   49.,   56.,   64.,  343.,  392.,  448.,  512.],
       [   1.,    9.,   10.,   81.,   90.,  100.,  729.,  810.,  900., 1000.]])

Pipeline

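A Pipeline combines the three steps (polynomial feature expansion, feature scaling, and linear regression) into a single estimator.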

x = np.random.uniform(-3, 3, size=100)
X = x.reshape(-1, 1)
y = 0.5 * x ** 2 + x + 2 + np.random.normal(0, 1, 100)
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Each step is a (name, estimator) pair; steps run in order on fit and predict
poly_reg = Pipeline([
    ("poly", PolynomialFeatures(degree=2)),
    ("std_scaler", StandardScaler()),
    ("lin_reg", LinearRegression())
])

poly_reg.fit(X, y)
y_predict = poly_reg.predict(X)
plt.scatter(x, y)
plt.plot(np.sort(x), y_predict[np.argsort(x)], color='r')
plt.show()

[Figure: scatter plot of the samples with the pipeline's fitted curve in red]
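A fitted pipeline can also be applied to inputs it has not seen: `predict` pushes them through the same fitted expansion and scaling before the regression. The sketch below is a self-contained illustration with a fixed seed; `pipe`, `X_new`, and the seed value are assumptions made here, not part of the original notebook.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures, StandardScaler

# Same data-generating formula as above, with a fixed seed for reproducibility
rng = np.random.default_rng(42)
x_train = rng.uniform(-3, 3, size=100)
y_train = 0.5 * x_train ** 2 + x_train + 2 + rng.normal(0, 1, 100)

pipe = Pipeline([
    ("poly", PolynomialFeatures(degree=2)),
    ("std_scaler", StandardScaler()),
    ("lin_reg", LinearRegression()),
])
pipe.fit(x_train.reshape(-1, 1), y_train)

# New inputs go through the same fitted transforms before reaching the regressor
X_new = np.array([[-2.0], [0.0], [2.0]])
y_new = pipe.predict(X_new)
print(y_new)

# Individual steps remain accessible by name
print(pipe.named_steps["lin_reg"].coef_)
```

Note that the coefficients stored in the `lin_reg` step refer to the standardized features, so they are not expected to match the 0.5 and 1 of the generating formula.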