from sklearn.metrics import confusion_matrix
import seaborn as sns
def plot_cm(labels, predictions):
cm = confusion_matrix(labels, predictions)
plt.figure(figsize=(5,5))
sns.heatmap(cm, annot=True, fmt="d")
plt.title('Confusion matrix @p')
plt.ylabel('Actual label')
plt.xlabel('Predicted label')
test_predictions = test_model.predict(test_features, batch_size=16)
#print(test_predictions)
plot_cm(test_labels,test_predictions)
在生产一个多分类的混淆矩阵时会出现报错:
ValueError: Classification metrics can't handle a mix of multilabel-indicator and continuous-multioutput targets
报错位置是最后一句调用混淆矩阵绘制函数时产生的,test_labels 划分类别时采用的是one-hot编码形式,如下:
[[1. 0. 0. 0. 0.]
[1. 0. 0. 0. 0.]
[1. 0. 0. 0. 0.]
[1. 0. 0. 0. 0.]
[1. 0. 0. 0. 0.]
[1. 0. 0. 0. 0.]
[1. 0. 0. 0. 0.]
[1. 0. 0. 0. 0.]
[1. 0. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
....
[0. 0. 0. 1. 0.]
[0. 0. 0. 1. 0.]
[0. 0. 0. 1. 0.]
[0. 0. 0. 1. 0.]
[0. 0. 0. 1. 0.]
[0. 0. 0. 0. 1.]
[0. 0. 0. 0. 1.]
[0. 0. 0. 0. 1.]
[0. 0. 0. 0. 1.]
test_predictions 为预测结果,是概率矩阵的形式。
[[0.869262 0.5665759 0.146371 0.24039719 0.7748723 ]
[0.9530567 0.5486128 0.31749046 0.01823309 0.59434575]
[0.9656232 0.4479042 0.2059783 0.04152995 0.5600178 ]
[0.9739768 0.31491834 0.35822213 0.05632868 0.65389943]
[0.99412966 0.08985636 0.49540597 0.15879542 0.11264369]
[0.9549756 0.7419628 0.16873392 0.05887893 0.3303986 ]
[0.96005356 0.86933553 0.02683157 0.2589711 0.4449162 ]
[0.97514784 0.59909743 0.10652998 0.10352147 0.5752037 ]
[0.9747724 0.61836743 0.13382941 0.03205872 0.7293294 ]
[0.71841097 0.9828553 0.0815956 0.08285439 0.33566174]
[0.84335256 0.940132 0.11253729 0.06208599 0.27973926]
[0.7956094 0.9140527 0.11103621 0.11854422 0.34825724]
[0.6131778 0.98456514 0.1272377 0.06850973 0.40009892]
[0.65834457 0.98713803 0.09057501 0.15498066 0.2677393 ]
[0.74587846 0.9386983 0.04109141 0.25518125 0.75702965]
[0.73277116 0.96753037 0.06807715 0.15403154 0.36670652]
如何更改能运行呢?
在test_labels和test_predictions后面添加.argmax(axis=1)就可以运行了
.argmax(axis=1)相当于转化成为一个十进制的数字,相当于从one-hot的逆编码
plot_cm(test_labels.argmax(axis=1),test_predictions.argmax(axis=1))
print(confusion_matrix(test_labels.argmax(axis=1),test_predictions.argmax(axis=1)))
输出结果混淆矩阵及绘制结果:
[[ 9 0 0 0 0]
[ 0 9 0 0 0]
[ 0 0 10 0 0]
[ 0 0 0 10 0]
[ 0 0 1 0 9]]
Original: https://blog.csdn.net/m0_64748541/article/details/124235857
Author: 青椒炒代码
Title: 混淆矩阵不支持multilabel-indicator
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/665930/
转载文章受原作者版权保护。转载请注明原作者出处!