본문 바로가기
étude/AI

Model Explanation : SHAP values

by mummoo 2023. 11. 22.

SHAP (SHapley Additive exPlanations)

: 모든 ML 모델의 결과를 출력하기 위한 게임 이론적인 방식

 

SHAP의 특징

  • 각 feature가 예측값에 어떠한 영향을 주는지를 알려줌
  • 어떠한 feature가 모델 성능에 있어서 가장 중요한지 알려주고, 이게 어떻게 결과에 영향을 주는지 알려줌

게임 이론적인 방식?

  • 각 player의 기여도와 최종 결과물 간의 관계를 측정하는 이론적 접근
  • ML에서는 각 열에 중요도 (importance value)가 할당되고, 이게 모델의 결과물에 기여하는 정도를 알려준다. 

SHAP 범위 

  • SHAP value > 0 : 예측값에 긍정적인 영향을 줌
  • SHAP value < 0 : 예측값의 성능에 긍정적이지 않은 영향을 줌 
  • SHAP value 절댓값 (magnitude = 규모)은 그 영향의 크기를 알려준다

 

Summary Plot : SHAP은 어떻게 생겼을까?

  • x축 : SHAP value (log)
  • y축 : 중요도에 따라 위에서 아래로 feature 이름을 나열

** log odds?

https://tyami.github.io/machine%20learning/machine-learning-1-odds-log-odds/#google_vignette

 

Machine learning: Odds와 Log(Odds)

Machine learning의 기본적인 개념 중 하나인 Odds와 Log(Odds)에 대해 정리해봅시다

tyami.github.io

 

SHAP CODE

import shap -> explainer -> summary_plot -> force_plot 으로 코드 생성 시

 

1. import shap & 모델 train 등등 

from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

import shap

# 데이터 불러오는 코드는 알아서 작성

X, y = shap.datasets.california(n_points=1000)

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
model = RandomForestRegressor().fit(X_train, y_train)

 

2. explainer

explainer = shap.Explainer(model)
shap_values = explainer.shap_values(X_test)

 

3. summary_plot

shap.summary_plot(shap_values, X_test)

  • 위의 그래프에서 빨간색은 high value, 파란색은 low value를 나타낸다.
  • 각 point는 특정 data의 한 row를 나타냄 

각 feature가 target value와 어떤 상관 관계가 있는지 쉽게 확인할 수 있다. 

Feature와 target 간 상관 관계 해석의 용이성이 summary plot의 장점이다!

 

4. force_plot

개별 data에서 각 feature가 target에 어떤 영향을 줬는지 해석 가능

전체 row를 사용하는 것이 아닌, 한 개의 row만을 사용해 개별 데이터와 모델의 관계를 예측

shap.force_plot(explainer.expected_value, shap_values[20,:], X_val.iloc[20,:])

21번째 row에 속하는 data만을 이용한 코드! 

  • 굵은 글씨 : target값, base value = X_val에 속한 target value의 평균값
  • 빨강이 글씨는 target 변수의 값을 높게 만드는 feature 
  • 파랑이 글씨는 target 변수의 값을 낮게 만드는 feature