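The method used here is SHAP analysis. SHAP is a game-theoretic approach, based on Shapley values from cooperative game theory, and it is available as a Python library. It is a major advance in the field of explainable AI (XAI). A high-profile article recently published in Nature Communications, "Explainable artificial intelligence model to predict acute critical illness from electronic health records", uses exactly this method. In short, SHAP shows how much each input variable pushes each individual prediction up (promotes it) or down (suppresses it). Detailed information on SHAP is available at https://shap.readthedocs.io/en/latest/, and this article also explains the basic functions it uses.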
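The game-theoretic idea is that a feature's contribution is its marginal effect on the prediction, averaged over every order in which features could be added. The brute-force sketch below makes this concrete for a hypothetical three-feature toy model. It rests on a simplifying assumption: "missing" features are set to a fixed baseline value. By contrast, shap's KernelExplainer averages over a background dataset and uses far more efficient approximations than exhaustive enumeration.

```python
from itertools import combinations
from math import factorial


def shapley_values(f, x, baseline):
    """Exact Shapley values of model f at point x, by enumerating all coalitions.

    Assumption for illustration: a feature absent from a coalition takes its baseline value.
    """
    n = len(x)
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            for subset in combinations(others, size):
                # Shapley weight |S|! (n - |S| - 1)! / n!
                weight = factorial(size) * factorial(n - size - 1) / factorial(n)
                with_i = [x[j] if (j in subset or j == i) else baseline[j] for j in range(n)]
                without_i = [x[j] if j in subset else baseline[j] for j in range(n)]
                phi[i] += weight * (f(with_i) - f(without_i))
    return phi


def toy_model(z):
    # Hypothetical model: two linear terms plus an interaction between features 0 and 2
    return 2.0 * z[0] - 1.0 * z[1] + 0.5 * z[0] * z[2]


x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]
phi = shapley_values(toy_model, x, baseline)
print(phi)  # roughly [2.75, -2.0, 0.75], up to floating-point rounding
print(sum(phi), toy_model(x) - toy_model(baseline))  # both roughly 1.5
```

Features 0 and 2 split the interaction term equally. Feature 1 gets a negative value because it pulls this prediction down. The values add up to f(x) − f(baseline), and this additivity is what lets a SHAP plot attribute each single prediction to its input variables.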
First, let's build and train a simple model. You can substitute any model you have already built:
# Generate random data
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1000, n_features=20, n_informative=2, n_redundant=10, n_classes=2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Build and train a neural network model
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

model = Sequential([
    Dense(128, activation='relu', input_shape=(X_train.shape[1],)),
    Dense(64, activation='relu'),
    Dense(1, activation='sigmoid')
])
# The model must be compiled and fitted before it can be evaluated
# (optimizer, epochs and batch size here are illustrative choices)
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.fit(X_train, y_train, epochs=10, batch_size=32, verbose=0)

# Evaluate the model
loss, accuracy = model.evaluate(X_test, y_test, verbose=0)
print(f"Test Loss: {loss}\nTest Accuracy: {accuracy}")
# Compute SHAP values for the first 10 test samples
# (assumption: the explainer type is not specified here; KernelExplainer works with any
#  prediction function, and shap.DeepExplainer is a faster alternative for Keras models)
import shap
background = shap.kmeans(X_train, 10)  # summarize the training set into 10 weighted points
explainer = shap.KernelExplainer(lambda data: model.predict(data, verbose=0), background)
shap_values_output = explainer.shap_values(X_test[:10])

# Older SHAP versions return a list with one array per model output;
# newer versions return one array of shape (n_samples, n_features, n_outputs)
print(type(shap_values_output))
print(np.array(shap_values_output).shape)

# The model has a single sigmoid output, so there is no shap_values[1]:
# take the only output and reshape it to (10, 20) so it matches X_test[:10]
if isinstance(shap_values_output, list):
    correct_shap_values = np.array(shap_values_output[0])
else:
    correct_shap_values = np.array(shap_values_output)
reshaped_shap_values = correct_shap_values.reshape(X_test[:10].shape)
print("Shape of SHAP values:", reshaped_shap_values.shape)
print("Shape of features:", X_test[:10].shape)

# Plot a summary of the SHAP values
shap.summary_plot(reshaped_shap_values, X_test[:10], feature_names=[f'Feature {i}' for i in range(X.shape[1])])
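A linear model with independent features has closed-form SHAP values, φ_i = w_i (x_i − E[x_i]). This makes it easy to see what the summary plot encodes. Positive values push a prediction above the base value and negative values push it below. Features are ordered by their mean absolute SHAP value, largest first. The numpy-only sketch below illustrates both points. Its weights and data are hypothetical, and it is not the neural network above.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))   # hypothetical data: 200 samples, 3 features
w = np.array([2.0, -0.5, 0.0])  # hypothetical linear-model weights
b = 1.0


def predict(data):
    return data @ w + b


# Exact SHAP values of a linear model, assuming independent features and
# using the data mean as the background expectation
shap_values = (X - X.mean(axis=0)) * w
base_value = predict(X).mean()

# Additivity: base value + sum of a sample's SHAP values equals its prediction
reconstructed = base_value + shap_values.sum(axis=1)
print(np.allclose(reconstructed, predict(X)))  # True

# Global importance as used to order features in a summary plot: mean |SHAP| per feature
importance = np.abs(shap_values).mean(axis=0)
ranking = np.argsort(importance)[::-1]
print(importance, ranking)
```

Feature 2 has weight 0, so all its SHAP values are zero and it ranks last. Feature 1's negative weight means that samples with above-average values of feature 1 get negative SHAP values, that is, the feature suppresses their predictions.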