学习地址
第 61 天: 决策树 (1. 准备工作)
决策树的生成主要分以下两步,这两步通常通过学习已经知道分类结果的样本来实现。
1.节点的分裂:一般当一个节点所代表的属性无法给出判断时,则选择将这一节点分成2个子节点(如果不是二叉树的情况会分成n个子节点)
2.阈值的确定:选择适当的阈值使得分类错误率最小 (Training Error)。
ID3: 由增熵(Entropy)原理来决定哪个做父节点,哪个节点需要分裂。对于一组数据,熵越小说明分类结果越好。
熵定义如下:
Entropy=- sum [p(xi) * log2(P(xi) ]
其中p(xi) 为xi出现的概率。
假如是2分类问题,当A类和B类各占50%的时候:
Entropy = - (0.5×log2(0.5)+0.5×log2(0.5))= 1
当只有A类,或只有B类的时候:
Entropy= - (1×log2(1)+0)=0
所以当Entropy最大为1的时候,是分类效果最差的状态,当它最小为0的时候,是完全分类的状态。因为熵等于零是理想状态,一般实际情况下,熵介于0和1之间。 熵的不断最小化,实际上就是提高分类正确率的过程。
决策树是最经典的机器学习算法,它有非常好的可解释性。 1.数据仅有一份. 分裂后的数据子集仅需要保存 availableInstances 和 availableAttributes 两个数组. 2.两个构造方法, 一个读入文件获得根节点, 另一个建立根据数据分裂的获得. 3.判断数据集是否纯, 即所有的类标签是否相同, 如果是就不用分裂了. 4.每个节点 (包括非叶节点) 都需要一个标签, 这样, 遇到未见过的属性就可以直接分类了. 为获得该标签, 可以通过投票的方式, 即 getMajorityClass(). 5.最大化信息增益, 与最小化条件信息熵, 两者是等价的. 6.分裂的数据块有可能是空的, 这时使用长度为 0 的数组而不是 null.
代码:
package xjx;
import java.io.FileReader;
import java.util.Arrays;
import weka.core.*;
public class ID3 {
Instances dataset;
boolean pure;
int numClasses;
int[] availableInstances;
int[] availableAttributes;
int splitAttribute;
ID3[] children;
int label;
int[] predicts;
static int smallBlockThreshold = 3;
public ID3(String paraFilename) {
dataset = null;
try {
FileReader fileReader = new FileReader(paraFilename);
dataset = new Instances(fileReader);
fileReader.close();
} catch (Exception ee) {
System.out.println("Cannot read the file: " + paraFilename + "\r\n" + ee);
System.exit(0);
}
dataset.setClassIndex(dataset.numAttributes() - 1);
numClasses = dataset.classAttribute().numValues();
availableInstances = new int[dataset.numInstances()];
for (int i = 0; i < availableInstances.length; i++) {
availableInstances[i] = i;
}
availableAttributes = new int[dataset.numAttributes() - 1];
for (int i = 0; i < availableAttributes.length; i++) {
availableAttributes[i] = i;
}
children = null;
label = getMajorityClass(availableInstances);
pure = pureJudge(availableInstances);
}
public ID3(Instances paraDataset, int[] paraAvailableInstances, int[] paraAvailableAttributes) {
dataset = paraDataset;
availableInstances = paraAvailableInstances;
availableAttributes = paraAvailableAttributes;
children = null;
label = getMajorityClass(availableInstances);
pure = pureJudge(availableInstances);
}
public boolean pureJudge(int[] paraBlock) {
pure = true;
for (int i = 1; i < paraBlock.length; i++) {
if (dataset.instance(paraBlock[i]).classValue() != dataset.instance(paraBlock[0])
.classValue()) {
pure = false;
break;
}
}
return pure;
}
public int getMajorityClass(int[] paraBlock) {
int[] tempClassCounts = new int[dataset.numClasses()];
for (int i = 0; i < paraBlock.length; i++) {
tempClassCounts[(int) dataset.instance(paraBlock[i]).classValue()]++;
}
int resultMajorityClass = -1;
int tempMaxCount = -1;
for (int i = 0; i < tempClassCounts.length; i++) {
if (tempMaxCount < tempClassCounts[i]) {
resultMajorityClass = i;
tempMaxCount = tempClassCounts[i];
}
}
return resultMajorityClass;
}
public int selectBestAttribute() {
splitAttribute = -1;
double tempMinimalEntropy = 10000;
double tempEntropy;
for (int i = 0; i < availableAttributes.length; i++) {
tempEntropy = conditionalEntropy(availableAttributes[i]);
if (tempMinimalEntropy > tempEntropy) {
tempMinimalEntropy = tempEntropy;
splitAttribute = availableAttributes[i];
}
}
return splitAttribute;
}
public double conditionalEntropy(int paraAttribute) {
int tempNumClasses = dataset.numClasses();
int tempNumValues = dataset.attribute(paraAttribute).numValues();
int tempNumInstances = availableInstances.length;
double[] tempValueCounts = new double[tempNumValues];
double[][] tempCountMatrix = new double[tempNumValues][tempNumClasses];
int tempClass, tempValue;
for (int i = 0; i < tempNumInstances; i++) {
tempClass = (int) dataset.instance(availableInstances[i]).classValue();
tempValue = (int) dataset.instance(availableInstances[i]).value(paraAttribute);
tempValueCounts[tempValue]++;
tempCountMatrix[tempValue][tempClass]++;
}
double resultEntropy = 0;
double tempEntropy, tempFraction;
for (int i = 0; i < tempNumValues; i++) {
if (tempValueCounts[i] == 0) {
continue;
}
tempEntropy = 0;
for (int j = 0; j < tempNumClasses; j++) {
tempFraction = tempCountMatrix[i][j] / tempValueCounts[i];
if (tempFraction == 0) {
continue;
}
tempEntropy += -tempFraction * Math.log(tempFraction);
}
resultEntropy += tempValueCounts[i] / tempNumInstances * tempEntropy;
}
return resultEntropy;
}
public int[][] splitData(int paraAttribute) {
int tempNumValues = dataset.attribute(paraAttribute).numValues();
int[][] resultBlocks = new int[tempNumValues][];
int[] tempSizes = new int[tempNumValues];
int tempValue;
for (int i = 0; i < availableInstances.length; i++) {
tempValue = (int) dataset.instance(availableInstances[i]).value(paraAttribute);
tempSizes[tempValue]++;
}
for (int i = 0; i < tempNumValues; i++) {
resultBlocks[i] = new int[tempSizes[i]];
}
Arrays.fill(tempSizes, 0);
for (int i = 0; i < availableInstances.length; i++) {
tempValue = (int) dataset.instance(availableInstances[i]).value(paraAttribute);
resultBlocks[tempValue][tempSizes[tempValue]] = availableInstances[i];
tempSizes[tempValue]++;
}
return resultBlocks;
}
}
|