决策树(Decision Tree)是什么东西呢?它是怎么用于分类的呢?它其实很简单,请看下图。 上图就是一颗决策树,椭圆是判断模块(特征属性),从判断模块引出的左右箭头称作分支,它可以到达另一个判断模块或终止模块(类别值)。上图构造的决策树,根据颜色、价格、大小来判断是否喜欢所选择的礼物。从上图可以看出决策树的数据形式及分类过程很好理解,不像其他分类算法,比如SVM、K最近邻,无法给出数据的内在形式。 决策树构造 决策树用样本的属性作为节点,用属性的取值作为分支的树结构。 决策树方法最早产生于上世纪60年代,到70年代末。由J RossQuinlan提出了ID3算法,此算法的目的在于减少树的深度。但是忽略了叶子数目的研究。C4.5算法在ID3算法的基础上进行了改进,对于预测变量的缺值处理、剪枝技术、派生规则等方面作了较大改进,既适合于分类问题,又适合于回归问题。 决策树算法用构造决策树来发现数据中蕴涵的分类规则。如何构造精度高、规模小的决策树是决策树算法的核心内容。决策树构造可以分两步进行:第一步,决策树的生成,由训练样本集生成决策树的过程;第二步,决策树的剪技,决策树的剪枝是对上一阶段生成的决策树进行检验、校正和修下的过程,主要是用测试数据集校验决策树生成过程中产生的初步规则,将那些影响预衡准确性的分枝剪除。 那么决策树生成过程哪些节点作为根节点,哪些节点作为中间节点呢?中间节点是信息量最大的属性,中间节点是子树所包含样本子集中信息量最大的属性,叶节点是类别值。 ID3算法:(1)计算每个属性的信息增益。将信息增益最大的点作为根节点。 C4.5算法:ID3算法的改进,用信息增益率来选择属性。 用信息增益来选择属性存在一个问题:假设某个属性存在大量的不同值,如ID编号(在上面例子中加一列为ID,编号为a ~ n),在划分时将每个值成为一个结点。那么用这个属性划分后的熵会很小,因为每个概率变小了。就导致信息增益很大。就倾向于选择该属性作为节点。就会导致过拟合。 确定递归建树的停止条件:否则会使节点过多,导致过拟合。 1. 每个子节点只有一种类型的记录时停止,这样会使得节点过多,导致过拟合。 2. 可行方法:当前节点中的记录数低于一个阈值时,那就停止分割。 过拟合原因: (1)噪音数据,某些节点用噪音数据作为了分割标准。 (2)缺少代表性的数据,训练数据没有包含所有具有代表性的数据,导致某类数据无法很好匹配。 (3)还就是上面的停止条件设置不好。 优化方案:剪枝,cross-alidation,randomforest #######下面是使用方法####### Python版: 使用的类: class sklearn.tree.DecisionTreeClassifier(criterion='gini',splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1,min_weight_fraction_leaf=0.0, max_features=None, random_state=None,max_leaf_nodes=None, class_weight=None, presort=False) 参数介绍: criterion: ”gini” or “entropy”(default=”gini”)是计算属性的gini(基尼不纯度)还是entropy(信息增益),来选择最合适的节点。 splitter: ”best” or “random”(default=”best”)随机选择属性还是选择不纯度最大的属性,建议用默认。 max_features: 选择最适属性时划分的特征不能超过此值。 当为整数时,即最大特征数;当为小数时,训练集特征数*小数; if “auto”, then max_features=sqrt(n_features). If “sqrt”, thenmax_features=sqrt(n_features). If “log2”, thenmax_features=log2(n_features). If None, then max_features=n_features. max_depth: (default=None)设置树的最大深度,默认为None,这样建树时,会使每一个叶节点只有一个类别,或是达到min_samples_split。 min_samples_split:根据属性划分节点时,每个划分最少的样本数。 min_samples_leaf:叶子节点最少的样本数。 max_leaf_nodes: (default=None)叶子树的最大样本数。 min_weight_fraction_leaf : (default=0) Theminimum weighted fraction of the input samples required to be at a leaf node. 类中方法: apply(X[,check_input]) Returnsthe index of the leaf that each sample is predicted as. fit(X,y[, sample_weight, check_input, ...]) 拟合训练数据,建立模型。 fit_transform(X[,y]) Fit to data, then transform it. predict(X[,check_input]) 做预测。 predict_proba(X[,check_input]) Predictclass probabilities of the input samples X. predict_log_proba(X) Predict class log-probabilities of the input samples X. score(X,y[, sample_weight]) 返回测试数据的准确率。 get_params([deep]) Get parameters for this estimator. set_params(**params) Set the parameters of this estimator. transform(*args, **kwargs) DEPRECATED: Support to use estimators asfeature selectors will be removed in version 0.19. 使用实例: ########################### R语言版: (1)使用rpart包 > library(rpart) #这里还是使用之前的数据集 #iris3是R自带的数据,是一个三维数组存储的三个地区的数据 #训练数据,我们取iris3前类数据的特征数据 > train =rbind(iris3[1:40,,1],iris3[1:40,,2],iris3[1:40,,3]) #将矩阵转化成数据框 > train = as.data.frame(train) #为训练数据添加上分类属性 > label = factor(c(rep(0,40),rep(1,40),rep(2,40))) > train$label = label #测试数据集 > test =rbind(iris3[41:50,,1],iris3[41:50,,2],iris3[41:50,,3]) > test = as.data.frame(test) > true_class = c(rep(0,10),rep(1,10),rep(2,10)) #拟合数据构建树 > fit = rpart(label ~.,method='class',data = train) #打印出树的信息 > printcp(fit) #查看此树信息 > summary(fit) #画出此树 > plot(fit,uniform=TRUE,main='DecisionTree') > text(fit,use.n=T,all=T,cex=0.8) #画出此树,一种更好看的方式 > post(fit, file = 'c:/tree.ps', title = 'Classification Tree foriris') #剪枝 > pfit = prune(fit,cp =fit$cptable[which.min(fit$cptable[,'xerror']),'CP']) #画出修剪后的树 > plot(pfit, uniform=TRUE, main='PrunedClassification Tree for iris') > text(pfit, use.n=TRUE, all=TRUE, cex=.8) #画到ps格式,更好看些 > post(pfit, file = 'c:/ptree.ps', title= 'Pruned Classification Tree for Kyphosis') #用测试数据集合测试 > p_class =predict(fit,test,type='class') > table(p_class, true_class) (2)使用party包 > library(party) > fit2 = ctree(label ~ .,data=train) #画出的图好看很多(如下图) > plot(fit2,main='Classification Tree foriris') #测试数据 > p_class2 = predict(fit2,test,type='response') > table(p_class2, true_class) #可以看到这个包和上面一个包的准确率是一样的。 参考文章: [1] http://baike.haosou.com/doc/6986193.html [2] http://www.cnblogs.com/bourneli/archive/2013/03/15/2961568.html [3] http://www./advstats/cart.html |
|
来自: web3佬总图书馆 > 《区块链技术、虚拟数字化、万物联网》