更新:2024/12/05
【scikit-learn】ラッソ回帰の意味と使い方について


はるか
ラッソ回帰…線形回帰の一種で特徴量を選択できる。自動で不要なものを削除。

ふゅか
特徴量をゼロにして、モデルをスリムにするのがポイントね!
目次
1. ラッソ回帰
1.1. ラッソ回帰とは?
ラッソ回帰(Lasso Regression)は、線形回帰モデルの1つで、データの特徴量選択を自動的に行うことができる点が特徴です。具体的には、モデルの精度を保ちながら、重要でない特徴量の重みをゼロにして削除します。
1.2. ラッソ回帰の数式
ラッソ回帰は、線形回帰に正則化項(ペナルティ)を追加したものです。損失関数は次のようになります。
\[ L(\beta) = \frac{1}{2n} \sum_{i=1}^{n} (y_i – \hat{y}_i)^2 + \alpha \sum_{j=1}^{p} |\beta_j| \]
- \( y_i \): 実際の値
- \( \hat{y}_i \): 線形回帰による予測値
- \( \beta_j \): 回帰係数
- \( \alpha \): 正則化強度を調整するハイパーパラメータ
- $p$:特徴量の総数
正則化項 \( \alpha \sum_{j=1}^{p} |\beta_j| \) が追加されることで、特徴量の一部がゼロになる(不要な特徴量が削除される)効果を持ちます。
2. Scikit-learnでラッソ回帰を実装
Scikit-learnのLasso
クラスを使うと、ラッソ回帰を簡単に実装できます。

はるか
Scikit-learnを使えば簡単に試せる。例えば
Lasso
クラス。
ふゅか
今回は糖尿病のデータセットを使ってみよう!目的変数は1年後の疾患進行度らしいね!
2.1. 必要なライブラリのインポート
まずはライブラリをインポートします。
from sklearn.linear_model import Lasso
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.datasets import load_diabetes
2.2. データの準備
load_diabetes
データセットを使用します。このデータセットは、糖尿病に関する10個の特徴量を持つ回帰問題のデータです。
# データの読み込み
data = load_diabetes()
X = data.data # 特徴量
y = data.target # 目的変数
# データ分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
2.3. ラッソ回帰モデルの構築と学習
alpha
を調整することで正則化の強度を変えることができます。
# ラッソ回帰モデルの作成
lasso = Lasso(alpha=0.1) # alphaは正則化強度
lasso.fit(X_train, y_train)
# 学習済みモデルで予測
y_pred = lasso.predict(X_test)
2.4. モデルの評価
モデルの精度を確認します。
# 平均二乗誤差(MSE)の計算
mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse:.3f}")
2.5. 回帰係数を確認
ラッソ回帰では、一部の特徴量の係数がゼロになっていることがわかります。これがラッソ回帰の特徴であり、不要な特徴量をスパース化(ゼロ化)して特徴量を減らしてくれます。
# 回帰係数を出力
print("回帰係数:", lasso.coef_)
3. ハイパーパラメータ調整
過剰に強い正則化を適用すると、重要な特徴量まで削除される可能性があります。GridSearchCV
などを使って適切な値を見つけるのがおすすめです。
from sklearn.model_selection import GridSearchCV
# alphaの候補値
param_grid = {'alpha': [0.01, 0.1, 1, 10]}
# グリッドサーチ
lasso_cv = GridSearchCV(Lasso(), param_grid, scoring='neg_mean_squared_error', cv=5)
lasso_cv.fit(X_train, y_train)
# 最適なalpha
print(f"Best alpha: {lasso_cv.best_params_['alpha']}")
今回の場合は、best_alphaは0.1になりました。
PR