混淆矩阵不支持multilabel-indicator

from sklearn.metrics import confusion_matrix
import seaborn as sns

def plot_cm(labels, predictions):

    cm = confusion_matrix(labels, predictions)
    plt.figure(figsize=(5,5))
    sns.heatmap(cm, annot=True, fmt="d")
    plt.title('Confusion matrix @p')
    plt.ylabel('Actual label')
    plt.xlabel('Predicted label')

test_predictions = test_model.predict(test_features, batch_size=16)
#print(test_predictions)

plot_cm(test_labels,test_predictions)

在生产一个多分类的混淆矩阵时会出现报错:

ValueError: Classification metrics can't handle a mix of multilabel-indicator and continuous-multioutput targets

报错位置是最后一句调用混淆矩阵绘制函数时产生的,test_labels 划分类别时采用的是one-hot编码形式,如下:

[[1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 ....
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 1.]

test_predictions 为预测结果,是概率矩阵的形式。

[[0.869262   0.5665759  0.146371   0.24039719 0.7748723 ]
 [0.9530567  0.5486128  0.31749046 0.01823309 0.59434575]
 [0.9656232  0.4479042  0.2059783  0.04152995 0.5600178 ]
 [0.9739768  0.31491834 0.35822213 0.05632868 0.65389943]
 [0.99412966 0.08985636 0.49540597 0.15879542 0.11264369]
 [0.9549756  0.7419628  0.16873392 0.05887893 0.3303986 ]
 [0.96005356 0.86933553 0.02683157 0.2589711  0.4449162 ]
 [0.97514784 0.59909743 0.10652998 0.10352147 0.5752037 ]
 [0.9747724  0.61836743 0.13382941 0.03205872 0.7293294 ]
 [0.71841097 0.9828553  0.0815956  0.08285439 0.33566174]
 [0.84335256 0.940132   0.11253729 0.06208599 0.27973926]
 [0.7956094  0.9140527  0.11103621 0.11854422 0.34825724]
 [0.6131778  0.98456514 0.1272377  0.06850973 0.40009892]
 [0.65834457 0.98713803 0.09057501 0.15498066 0.2677393 ]
 [0.74587846 0.9386983  0.04109141 0.25518125 0.75702965]
 [0.73277116 0.96753037 0.06807715 0.15403154 0.36670652]

如何更改能运行呢?

在test_labels和test_predictions后面添加.argmax(axis=1)就可以运行了
.argmax(axis=1)相当于转化成为一个十进制的数字,相当于从one-hot的逆编码

plot_cm(test_labels.argmax(axis=1),test_predictions.argmax(axis=1))

print(confusion_matrix(test_labels.argmax(axis=1),test_predictions.argmax(axis=1)))

输出结果混淆矩阵及绘制结果:

[[ 9  0  0  0  0]
 [ 0  9  0  0  0]
 [ 0  0 10  0  0]
 [ 0  0  0 10  0]
 [ 0  0  1  0  9]]

混淆矩阵不支持multilabel-indicator

Original: https://blog.csdn.net/m0_64748541/article/details/124235857
Author: 青椒炒代码
Title: 混淆矩阵不支持multilabel-indicator

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/665930/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球