第三节识别毒蘑菇
Python的scikit-learn(简称sklearn)模块可以非常好地支持决策树分类,尽管它使用的是更加复杂的决策树算法,但依然是在上一节讲述的原理基础上的优化算法。本节将使用scikit-learn的决策树分类来识别“蘑菇数据”中的毒蘑菇。
首先介绍本案例使用的数据。原始数据来源于加州大学欧文分校用于机器学习的数据库(UCI数据库),本书使用的是经过必要处理的数据,此数据可从教材资源平台下载。
接下来在命令行界面安装scikit-learn。
pipinstall-uscikit-learn
安装scikit-learn之前需要确保模块numpy和scipy已经安装,否则可以使用pipinstall安装。
pipinstall-unumpy
pipinstall-uscipy
进入Python之后,使用如下命令调用scikit-learn的决策树方法。
fromsklearnimporttree
模块scikit-learreeClassifier对象,它能够解决二分类问题(如蘑菇是否可食用),也可以解决多分类问题。在使用中,对输入训练数据的维度要求如下。
输入X:样本数量×特征属性数量;
输入Y:样本类别标签,与X要一一对应。
下面先用一个简单的例子来熟悉使用方法。对四个点X=[[0,0],[0,1],[1,0],[1,1]]使用决策树进行分类,四个点分别属于两个类别0和1,它们对应的类别标签是Y=[0,0,1,1]。这可以通过如下简单代码实现。
In[1]:fromsklearnimporttree
X=[[0,0],[0,1],[1,0],[1,1]]
Y=[0,0,1,1]
clf=tree。DeTreeClassifier()
clf=clf。fit(X,Y)
clf命令用来构造决策树,通过clf。fit(X,Y)实现了决策树的构建,此时clf已经是能够进行分类的决策树了,可以用它来进行新数据的分类。例如,输出[[0。3,0],[0。8,1],[1。2,1]]的类别,代码如下。
I=[[0。3,0],[0。8,1],[1。2,1]]
In[3]:clf。predict(test)
Out[3]:array([0,1,1])
通过上述输出可以看到决策树给出的分类结果是[0。3,0]类别为0;[0。8,1]类别为1;[1。2,1]类别为1。读者可以在平面直角坐标系上画出这些点,看看分类是否合理。
读者还可以尝试下列与决策树相关的函数,看看它们具有什么功能。
In[4]:dir(clf)
Out[4]:
[……
'apply',
'class_weight',
'classes_',
'',
'de_path',
'feature_importances_',
'fit',
'fit_transform',
'get_params',
'max_depth',
'max_features',