对Mahout中随机森林的决策树建立过程源码分析

这几天学习了Mahout中的随机森林算法,今天阅读了随机森林中决策树的建立过程,主要是DecisionTreeBuilder.java中的代码。具体的分析过程我写在了源码的注释中。

(代码清单的行号栏为网页抓取残留,已省略;完整源码见下文。)
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.classifier.df.builder;

import com.google.common.collect.Sets;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.Instance;
import org.apache.mahout.classifier.df.data.conditions.Condition;
import org.apache.mahout.classifier.df.node.CategoricalNode;
import org.apache.mahout.classifier.df.node.Leaf;
import org.apache.mahout.classifier.df.node.Node;
import org.apache.mahout.classifier.df.node.NumericalNode;
import org.apache.mahout.classifier.df.split.IgSplit;
import org.apache.mahout.classifier.df.split.OptIgSplit;
import org.apache.mahout.classifier.df.split.RegressionSplit;
import org.apache.mahout.classifier.df.split.Split;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collection;
import java.util.Random;

/**
* Builds a classification tree or regression tree<br>
* A classification tree is built when the criterion variable is the categorical attribute.<br>
* A regression tree is built when the criterion variable is the numerical attribute.
*/
public class DecisionTreeBuilder implements TreeBuilder {

  private static final Logger log = LoggerFactory.getLogger(DecisionTreeBuilder.class);

  /** returned by randomAttributes() when every attribute has already been selected */
  private static final int[] NO_ATTRIBUTES = new int[0];
  /** information-gain values below this are treated as zero (no useful split) */
  private static final double EPSILON = 1.0e-6;

  /**
   * indicates which CATEGORICAL attributes have already been selected in the parent nodes
   */
  private boolean[] selected; // shared across the recursion; saved and restored around child builds
  /**
   * number of attributes to select randomly at each node
   */
  private int m; // 0 means "use the default heuristic" (see build())
  /**
   * IgSplit implementation
   */
  private IgSplit igSplit; // computes the best split of a node via an information-gain measure
  /**
   * tree is complemented
   */
  private boolean complemented = true; // if true, categorical nodes branch over ALL values seen in the full set
  /**
   * minimum number for split
   */
  private double minSplitNum = 2.0; // a branch with fewer instances than this is not split further
  /**
   * minimum proportion of the total variance for split
   */
  private double minVarianceProportion = 1.0e-3; // regression only: fraction of root variance below which we stop
  /**
   * full set data
   */
  private Data fullSet; // the data seen at the root, kept to enumerate all categorical values
  /**
   * minimum variance for split
   */
  private double minVariance = Double.NaN; // lazily derived from minVarianceProportion; NaN = not yet computed

  /** sets the number of attributes randomly considered at each node */
  public void setM(int m) {
    this.m = m;
  }

  /** sets the split implementation used to find the best split */
  public void setIgSplit(IgSplit igSplit) {
    this.igSplit = igSplit;
  }

  public void setComplemented(boolean complemented) {
    this.complemented = complemented;
  }

  public void setMinSplitNum(int minSplitNum) {
    this.minSplitNum = minSplitNum;
  }

  public void setMinVarianceProportion(double minVarianceProportion) {
    this.minVarianceProportion = minVarianceProportion;
  }

  /**
   * Recursively builds a decision tree (regression tree when the label is numerical,
   * classification tree otherwise) over {@code data}.
   * NOTE(review): this method mutates and restores the shared {@code selected} state
   * around each recursive call, so the exact statement order matters.
   */
  @Override
  public Node build(Random rng, Data data) {
    // lazily initialize the attribute-selection state; the label attribute is
    // permanently marked selected so it is never considered as a split candidate
    if (selected == null) {
      selected = new boolean[data.getDataset().nbAttributes()];
      selected[data.getDataset().getLabelId()] = true; // never select the label
    }
    if (m == 0) {
      // set default m: choose from the nbAttributes - 1 non-label attributes
      double e = data.getDataset().nbAttributes() - 1;
      if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
        // regression: usual heuristic is one third of the attributes
        m = (int) Math.ceil(e / 3.0);
      } else {
        // classification: usual heuristic is the square root of the attributes
        m = (int) Math.ceil(Math.sqrt(e));
      }
    }

    if (data.isEmpty()) {
      // no instances left: return an "undecided" leaf
      return new Leaf(Double.NaN);
    }

    double sum = 0.0; // running sum of labels (regression only)
    if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
      // regression
      // sum and sum squared of a label is computed
      double sumSquared = 0.0;
      for (int i = 0; i < data.size(); i++) {
        double label = data.getDataset().getLabel(data.get(i));
        sum += label;
        sumSquared += label * label;
      }

      // computes the variance (times data.size(); divided out below)
      double var = sumSquared - (sum * sum) / data.size();

      // computes the minimum variance threshold once, from the first (root) node
      // (comparing against NaN detects "not yet computed")
      if (Double.compare(minVariance, Double.NaN) == 0) {
        minVariance = var / data.size() * minVarianceProportion;
        log.debug("minVariance:{}", minVariance);
      }

      // variance is compared with minimum variance; below the threshold the node
      // becomes a leaf predicting the mean label
      if ((var / data.size()) < minVariance) {
        log.debug("variance({}) < minVariance({}) Leaf({})", var / data.size(), minVariance, sum / data.size());
        return new Leaf(sum / data.size());
      }
    } else {
      // classification
      if (isIdentical(data)) {
        // all remaining vectors are identical: predict the majority label
        return new Leaf(data.majorityLabel(rng));
      }
      if (data.identicalLabel()) {
        // all remaining vectors share the same label: predict it
        return new Leaf(data.getDataset().getLabel(data.get(0)));
      }
    }

    // store full set data (the root node's data), used later to enumerate every
    // possible value of a categorical attribute when the tree is complemented
    if (fullSet == null) {
      fullSet = data;
    }

    int[] attributes = randomAttributes(rng, selected, m); // randomly pick m candidate attributes
    if (attributes == null || attributes.length == 0) {
      // we tried all the attributes and could not split the data anymore
      double label;
      if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
        // regression: predict the mean label
        label = sum / data.size();
      } else {
        // classification: predict the majority label
        label = data.majorityLabel(rng);
      }
      log.warn("attribute which can be selected is not found Leaf({})", label);
      return new Leaf(label);
    }

    if (igSplit == null) {
      // lazily pick the split implementation matching the label type
      if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
        // regression
        igSplit = new RegressionSplit();
      } else {
        // classification: optimized IgSplit implementation
        igSplit = new OptIgSplit();
      }
    }

    // find the best split: highest information gain among the candidate attributes
    Split best = null;
    for (int attr : attributes) {
      Split split = igSplit.computeSplit(data, attr);
      if (best == null || best.getIg() < split.getIg()) {
        best = split;
      }
    }

    // information gain is near to zero: splitting further is pointless
    if (best.getIg() < EPSILON) {
      double label;
      if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
        label = sum / data.size();
      } else {
        label = data.majorityLabel(rng);
      }
      log.debug("ig is near to zero Leaf({})", label);
      return new Leaf(label);
    }

    log.debug("best split attr:{}, split:{}, ig:{}", best.getAttr(), best.getSplit(), best.getIg());

    // remember the attribute's previous selection state so it can be restored
    // after the recursive child builds
    boolean alreadySelected = selected[best.getAttr()];
    if (alreadySelected) {
      // attribute already selected
      log.warn("attribute {} already selected in a parent node", best.getAttr());
    }

    Node childNode;
    if (data.getDataset().isNumerical(best.getAttr())) { // NUMERICAL attribute
      boolean[] temp = null;

      // instances below the split value go to the "lo" child, the rest to "hi"
      Data loSubset = data.subset(Condition.lesser(best.getAttr(), best.getSplit()));
      Data hiSubset = data.subset(Condition.greaterOrEquals(best.getAttr(), best.getSplit()));

      if (loSubset.isEmpty() || hiSubset.isEmpty()) {
        // the selected attribute did not change the data, avoid using it in the child nodes
        selected[best.getAttr()] = true;
      } else {
        // the data changed, so we can unselect all previously selected NUMERICAL attributes
        temp = selected;
        selected = cloneCategoricalAttributes(data.getDataset(), selected);
      }

      // either subset is smaller than minSplitNum (default 2): stop splitting here
      if (loSubset.size() < minSplitNum || hiSubset.size() < minSplitNum) {
        // branch is not split
        double label;
        if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
          label = sum / data.size();
        } else {
          label = data.majorityLabel(rng);
        }
        log.debug("branch is not split Leaf({})", label);
        return new Leaf(label);
      }

      Node loChild = build(rng, loSubset); // recursively build the left subtree
      Node hiChild = build(rng, hiSubset); // recursively build the right subtree

      // restore the selection state of the attributes
      if (temp != null) {
        selected = temp;
      } else {
        selected[best.getAttr()] = alreadySelected;
      }

      childNode = new NumericalNode(best.getAttr(), best.getSplit(), loChild, hiChild);
    } else { // CATEGORICAL attribute
      double[] values = data.values(best.getAttr());

      // tree is complemented: branch over every value of this attribute seen in
      // the FULL data set, not just the values present in this node's subset
      Collection<Double> subsetValues = null;
      if (complemented) {
        subsetValues = Sets.newHashSet();
        for (double value : values) {
          subsetValues.add(value);
        }
        values = fullSet.values(best.getAttr());
      }

      int cnt = 0; // number of branches large enough to be split
      Data[] subsets = new Data[values.length];
      for (int index = 0; index < values.length; index++) {
        // values absent from this node's subset get no data here (handled below)
        if (complemented && !subsetValues.contains(values[index])) {
          continue;
        }
        subsets[index] = data.subset(Condition.equals(best.getAttr(), values[index]));
        if (subsets[index].size() >= minSplitNum) {
          cnt++;
        }
      }

      // fewer than two branches reach minSplitNum: stop splitting here
      if (cnt < 2) {
        // branch is not split
        double label;
        if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
          label = sum / data.size();
        } else {
          label = data.majorityLabel(rng);
        }
        log.debug("branch is not split Leaf({})", label);
        return new Leaf(label);
      }

      // a categorical attribute is used at most once along any path
      selected[best.getAttr()] = true;

      Node[] children = new Node[values.length];
      for (int index = 0; index < values.length; index++) {
        if (complemented && (subsetValues == null || !subsetValues.contains(values[index]))) {
          // tree is complemented: values unseen in this subset get a leaf
          // predicting this node's own label
          double label;
          if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
            label = sum / data.size();
          } else {
            label = data.majorityLabel(rng);
          }
          log.debug("complemented Leaf({})", label);
          children[index] = new Leaf(label);
          continue;
        }
        children[index] = build(rng, subsets[index]); // recursively build each branch
      }

      // restore the selection state of the attribute
      selected[best.getAttr()] = alreadySelected;

      childNode = new CategoricalNode(best.getAttr(), values, children);
    }

    return childNode;
  }

  /**
   * checks if all the vectors have identical attribute values. Ignore selected attributes.
   *
   * @return true if all the vectors are identical or the data is empty<br>
   * false otherwise
   */
  private boolean isIdentical(Data data) {
    if (data.isEmpty()) {
      return true;
    }

    // compare every instance against the first one, attribute by attribute
    Instance instance = data.get(0);
    for (int attr = 0; attr < selected.length; attr++) {
      if (selected[attr]) {
        continue;
      }

      for (int index = 1; index < data.size(); index++) {
        if (data.get(index).get(attr) != instance.get(attr)) {
          return false;
        }
      }
    }

    return true;
  }

  /**
   * Make a copy of the selection state of the attributes, unselect all numerical attributes
   *
   * @param selected selection state to clone
   * @return cloned selection state
   */
  private static boolean[] cloneCategoricalAttributes(Dataset dataset, boolean[] selected) {
    boolean[] cloned = new boolean[selected.length];

    // keep only the categorical selections; numerical attributes become reusable
    for (int i = 0; i < selected.length; i++) {
      cloned[i] = !dataset.isNumerical(i) && selected[i];
    }
    // the label must always stay selected
    cloned[dataset.getLabelId()] = true;

    return cloned;
  }

  /**
   * Randomly selects m attributes to consider for split, excludes IGNORED and LABEL attributes
   *
   * @param rng random-numbers generator
   * @param selected attributes' state (selected or not)
   * @param m number of attributes to choose
   * @return list of selected attributes' indices, or null if all attributes have already been selected
   */
  private static int[] randomAttributes(Random rng, boolean[] selected, int m) {
    int nbNonSelected = 0; // number of non selected attributes
    for (boolean sel : selected) {
      if (!sel) {
        nbNonSelected++;
      }
    }

    if (nbNonSelected == 0) {
      log.warn("All attributes are selected !");
      return NO_ATTRIBUTES;
    }

    int[] result;
    if (nbNonSelected <= m) {
      // return all non selected attributes
      result = new int[nbNonSelected];
      int index = 0;
      for (int attr = 0; attr < selected.length; attr++) {
        if (!selected[attr]) {
          result[index++] = attr;
        }
      }
    } else {
      result = new int[m];
      for (int index = 0; index < m; index++) {
        // randomly choose a "non selected" attribute
        int rind;
        do {
          rind = rng.nextInt(selected.length);
        } while (selected[rind]);

        result[index] = rind;
        selected[rind] = true; // temporarily set the chosen attribute to be selected
      }

      // the chosen attributes are not yet selected: undo the temporary marking
      for (int attr : result) {
        selected[attr] = false;
      }
    }

    return result;
  }
}

本文作者:Qiu Qingyu
版权声明:本博客所有文章除特别声明外,均采用CC BY-NC-SA 3.0 CN许可协议。转载请注明出处!
本文永久链接:http://qiuqingyu.cn/2016/01/27/对Mahout中随机森林的决策树建立过程源码分析/