CART 分类决策树

1. Cart树简介

Cart模型是一种决策树模型,它即可以用于分类,也可以用于回归,其学习算法分为下面两步:

(1)决策树生成:用训练数据生成决策树,生成树尽可能大

(2)决策树剪枝:基于损失函数最小化的剪枝,用验证数据对生成的数据进行剪枝。

分类和回归树模型采用不同的最优化策略。Cart回归树使用平方误差最小化策略,Cart分类生成树采用的基尼指数最小化策略。

Scikit-learn中有两类决策树,他们均采用优化的Cart决策树算法。一个是DecisionTreeClassifier一个是DecisionTreeRegressor回归。

2. 基尼指数计算公式

3. 基尼指数计算举例

计算过程如下:根据是否有房将目标值划分为两部分:

是否有房,是否有房,Gini(是否有房,yes )=1−(03)2−(33)2=0

是否有房,是否有房,Gini⁡(是否有房,no )=1−(37)2−(47)2=0.4898

是否有房是否有房Gini-⁡index⁡(D, 是否有房 )=710∗0.4898+310∗0=0.343

结婚的基尼值,有 2、4、6、9 共 4 个样本,并且对应目标值全部为 no:

Gini_index⁡(D,{married})=0

不结婚的基尼值,有 1、3、5、7、8、10 共 6 个样本,并且对应 3 个 no,3 个 yes:

Gini_index⁡(D, {single,divorced} )=1−(36)2−(36)2=0.5

以 married 作为分裂点的基尼指数:

Gini_index⁡(D, married )=410∗0+610∗[1−(36)2−(36)2]=0.3

婚姻状况婚姻状况Gini_index⁡(D,婚姻状况)=410∗0.5+610∗[1−(16)2−(56)2]=0.367

婚姻状况婚姻状况Gini_index⁡(D, 婚姻状况 )=210∗0.5+810∗[1−(28)2−(68)2]=0.4

先将数值型属性升序排列,以相邻中间值作为待确定分裂点:

以年收入 65 将样本分为两部分,计算基尼指数:

节点为时年收入节点为时年收入节点为65时:年收入=110∗0+910∗[1−(69)2−(39)2]=0.4

以此类推计算所有分割点的基尼指数,我们发现最小的基尼指数为 0.3。

此时,我们发现:

最小基尼指数有两个分裂点,我们随机选择一个即可,假设婚姻状况,则可确定决策树如下:

重复上面步骤,直到每个叶子结点纯度达到最高.

4. Cart分类树原理

如果目标变量是离散变量,则是classfication Tree分类树。

分类树是使用树结构算法将数据分成离散类的方法。

(1)分类树两个关键点:

将训练样本进行递归地划分自变量空间进行建树‚用验证数据进行剪枝。

(2)对于离散变量X(x1…xn)处理:

分别取X变量各值的不同组合,将其分到树的左枝或右枝,并对不同组合而产生的树,进行评判,找出最佳组合。如果只有两个取值,直接根据这两个值就可以划分树。取值多于两个的情况就复杂一些了,如变量年纪,其值有”少年”、”中年”、”老年”,则分别生产{少年,中年}和{老年},{少年、老年}和{中年},{中年,老年}和{少年},这三种组合,最后评判对目标区分最佳的组合。因为CART二分的特性,当训练数据具有两个以上的类别,CART需考虑将目标类别合并成两个超类别,这个过程称为双化。这里可以说一个公式,n个属性,可以分出(2^n-2)/2种情况。

CART树生成

输入:数据集 D ,特征 A ,样本个数阈值、基尼系数阈值

输出:CART决策树T

(1)对于当前节点的数据集为D,如果样本个数小于阈值或者没有特征,则返回决策子树,当前节点停止递归;

(2)计算样本集D的基尼系数,如果基尼系数小于阈值,则返回决策树子树,当前节点停止递归;

(3)计算当前节点现有的各个特征的各个特征值对数据集D的基尼系数;

(4)在计算出来的各个特征的各个特征值对数据集D的基尼系数中,选择基尼系数最小的特征A和对应的特征值α。根据这个最优特征和最优特征值,把数据集划分成两部分D1和,D2同时建立当前节点的左右节点,左节点的数据集D为D1,右节点的数据集D为D2;

(5)对左右的子节点递归的调用前面1-4步,生成决策树。

CART树剪枝

我们知道,决策树算法对训练集很容易过拟合,导致泛化能力很差,为解决此问题,需要对CART树进行剪枝。CART剪枝算法从”完全生长”的决策树的底端剪去一些子树,使决策树变小,从而能够对未知数据有更准确的预测,也就是说CART使用的是后剪枝法。一般分为两步:先生成决策树,产生所有可能的剪枝后的CART树,然后使用交叉验证来检验各种剪枝的效果,最后选择泛化能力好的剪枝策略。

import numpy as np
import matplotlib.pyplot as plt

from sklearn import datasets

iris = datasets.load_iris()
X = iris.data[:,2:]
y = iris.target

from sklearn.tree import DecisionTreeClassifier

#注意:此处传入的是"gini"而不是"entropy",默认criterion='gini'
tree = DecisionTreeClassifier(max_depth=2,criterion="gini")
tree.fit(X,y)

def plot_decision_boundary(model,axis):
    x0,x1 = np.meshgrid(
        np.linspace(axis[0],axis[1],int((axis[1]-axis[0])*100)).reshape(-1,1),
        np.linspace(axis[2],axis[3],int((axis[3]-axis[2])*100)).reshape(-1,1)
    )
    X_new = np.c_[x0.ravel(),x1.ravel()]
    y_predict = model.predict(X_new)
    zz = y_predict.reshape(x0.shape)

    from matplotlib.colors import ListedColormap
    custom_map = ListedColormap(["#EF9A9A","#FFF59D","#90CAF9"])

    plt.contourf(x0,x1,zz,linewidth=5,cmap=custom_map)

plot_decision_boundary(tree,axis=[0.5,7.5,0,3])
plt.scatter(X[y==0,0],X[y==0,1])
plt.scatter(X[y==1,0],X[y==1,1])
plt.scatter(X[y==2,0],X[y==2,1])
plt.show()

分析上图可知:

  • X[1]
  • X[1]>0.8的,依据 X[1]

Original: https://blog.csdn.net/weixin_46556352/article/details/123924545
Author: AI耽误的大厨
Title: CART 分类决策树

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/618755/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球