本文實例講述了Java實現的決策樹算法。分享給大家供大家參考,具體如下:
決策樹算法是一種逼近離散函數值的方法。它是一種典型的分類方法,首先對數據進行處理,利用歸納算法生成可讀的規則和決策樹,然后使用決策對新數據進行分析。本質上決策樹是通過一系列規則對數據進行分類的過程。
決策樹構造可以分兩步進行。第一步,決策樹的生成:由訓練樣本集生成決策樹的過程。一般情況下,訓練樣本數據集是根據實際需要有歷史的、有一定綜合程度的,用于數據分析處理的數據集。第二步,決策樹的剪枝:決策樹的剪枝是對上一階段生成的決策樹進行檢驗、校正和修下的過程,主要是用新的樣本數據集(稱為測試數據集)中的數據校驗決策樹生成過程中產生的初步規則,將那些影響預衡準確性的分枝剪除。
Java實現代碼如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
|
package demo;

import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

/**
 * Builds a decision tree from a small, hard-coded, already-classified sample
 * set and prints it to standard output.
 *
 * <p>At each node the attribute chosen for the split is the one for which the
 * weighted sum of the per-branch entropies is smallest. A tree is represented
 * as an {@code Object}: either a {@link Tree} (inner node) or a category value
 * (leaf).
 */
public class DicisionTree {

    /** Natural logarithm of 2, used to convert {@code Math.log} results to base 2. */
    private static final double LOG_2 = Math.log(2);

    /**
     * Reads the built-in samples, builds the decision tree and prints it.
     *
     * @param args ignored
     */
    public static void main(String[] args) throws Exception {
        // NOTE(review): this banner literal is reproduced exactly as found; it
        // looks like it picked up extraction artifacts -- confirm the intended text.
        System.out.print("服務(wù)器之家測試結(jié)果:");
        String[] attrNames = new String[] { "AGE", "INCOME", "STUDENT", "CREDIT_RATING" };
        // Load the sample set.
        Map<Object, List<Sample>> samples = readSamples(attrNames);
        // Build the decision tree.
        Object decisionTree = generateDecisionTree(samples, attrNames);
        // Print the decision tree.
        outputDecisionTree(decisionTree, 0, null);
    }

    /**
     * Reads the classified sample set.
     *
     * @param attrNames attribute names; {@code attrNames[i]} names column
     *                  {@code i} of every raw data row
     * @return a map from category to the list of samples in that category
     */
    static Map<Object, List<Sample>> readSamples(String[] attrNames) {
        // Attribute values followed by the category (the last element of each
        // row is the category the sample belongs to).
        Object[][] rawData = new Object[][] {
            { "<30 ", "High ", "No ", "Fair ", "0" },
            { "<30 ", "High ", "No ", "Excellent", "0" },
            { "30-40", "High ", "No ", "Fair ", "1" },
            { ">40 ", "Medium", "No ", "Fair ", "1" },
            { ">40 ", "Low ", "Yes", "Fair ", "1" },
            { ">40 ", "Low ", "Yes", "Excellent", "0" },
            { "30-40", "Low ", "Yes", "Excellent", "1" },
            { "<30 ", "Medium", "No ", "Fair ", "0" },
            { "<30 ", "Low ", "Yes", "Fair ", "1" },
            { ">40 ", "Medium", "Yes", "Fair ", "1" },
            { "<30 ", "Medium", "Yes", "Excellent", "1" },
            { "30-40", "Medium", "No ", "Excellent", "1" },
            { "30-40", "High ", "Yes", "Fair ", "1" },
            { ">40 ", "Medium", "No ", "Excellent", "0" }
        };

        // Turn every row into a Sample and group the samples by category.
        Map<Object, List<Sample>> ret = new HashMap<Object, List<Sample>>();
        for (Object[] row : rawData) {
            int categoryIndex = row.length - 1;
            Sample sample = new Sample();
            for (int i = 0; i < categoryIndex; i++) {
                sample.setAttribute(attrNames[i], row[i]);
            }
            Object category = row[categoryIndex];
            sample.setCategory(category);
            bucketFor(ret, category).add(sample);
        }
        return ret;
    }

    /**
     * Builds a decision tree for the given samples.
     *
     * @param categoryToSamples samples grouped by category
     * @param attrNames         attributes that may still be used for splitting
     * @return a category value when the node is a leaf, otherwise a
     *         {@link Tree}; {@code null} when no attributes are left and
     *         {@code categoryToSamples} holds no samples
     */
    static Object generateDecisionTree(
            Map<Object, List<Sample>> categoryToSamples, String[] attrNames) {
        // Only one category is present: every sample agrees, so that category
        // is the answer.
        if (categoryToSamples.size() == 1) {
            return categoryToSamples.keySet().iterator().next();
        }

        // No attribute left to decide on: vote, i.e. take the category that
        // has the most samples.
        if (attrNames.length == 0) {
            return majorityCategory(categoryToSamples);
        }

        // Pick the test attribute for this node.
        Object[] rst = chooseBestTestAttribute(categoryToSamples, attrNames);
        int bestIndex = (Integer) rst[0];

        // Root of this (sub)tree; it branches on the chosen attribute.
        Tree tree = new Tree(attrNames[bestIndex]);

        // An attribute that has been used must not be chosen again below.
        String[] subA = new String[attrNames.length - 1];
        for (int i = 0, j = 0; i < attrNames.length; i++) {
            if (i != bestIndex) {
                subA[j++] = attrNames[i];
            }
        }

        // Element 2 is always the split map created by chooseBestTestAttribute.
        @SuppressWarnings("unchecked")
        Map<Object, Map<Object, List<Sample>>> splits =
                (Map<Object, Map<Object, List<Sample>>>) rst[2];

        // One branch per value of the chosen attribute.
        for (Entry<Object, Map<Object, List<Sample>>> entry : splits.entrySet()) {
            Object child = generateDecisionTree(entry.getValue(), subA);
            tree.setChild(entry.getKey(), child);
        }
        return tree;
    }

    /**
     * Chooses the best test attribute: the one for which the information still
     * needed to classify a sample, summed over the branches and weighted by
     * branch size, is smallest.
     *
     * @param categoryToSamples samples grouped by category
     * @param attrNames         candidate attributes
     * @return a three-element array: index of the chosen attribute
     *         ({@code Integer}, {@code -1} if {@code attrNames} is empty), its
     *         information value ({@code Double}), and the split
     *         {@code Map} (attribute value -> (category -> samples))
     */
    static Object[] chooseBestTestAttribute(
            Map<Object, List<Sample>> categoryToSamples, String[] attrNames) {
        int minIndex = -1; // index of the best attribute so far
        double minValue = Double.MAX_VALUE; // smallest information value so far
        Map<Object, Map<Object, List<Sample>>> minSplits = null; // its split

        // The total does not depend on the attribute, so count it once.
        int allCount = countSamples(categoryToSamples);

        for (int attrIndex = 0; attrIndex < attrNames.length; attrIndex++) {
            Map<Object, Map<Object, List<Sample>>> curSplits =
                    splitByAttribute(categoryToSamples, attrNames[attrIndex]);
            double curValue = expectedInformation(curSplits, allCount);

            // Strict comparison: on a tie the earlier attribute is kept.
            if (minValue > curValue) {
                minIndex = attrIndex;
                minValue = curValue;
                minSplits = curSplits;
            }
        }
        return new Object[] { minIndex, minValue, minSplits };
    }

    /**
     * Prints a decision tree to standard output.
     *
     * @param obj   a {@link Tree} or a category value
     * @param level nesting depth, used for the indentation prefix
     * @param from  label of the branch leading here, or {@code null} at the root
     */
    static void outputDecisionTree(Object obj, int level, Object from) {
        for (int i = 0; i < level; i++) {
            System.out.print("|-----");
        }
        if (from != null) {
            System.out.printf("(%s):", from);
        }
        if (obj instanceof Tree) {
            Tree tree = (Tree) obj;
            String attrName = tree.getAttribute();
            System.out.printf("[%s = ?]\n", attrName);
            for (Object attrValue : tree.getAttributeValues()) {
                Object child = tree.getChild(attrValue);
                outputDecisionTree(child, level + 1, attrName + " = " + attrValue);
            }
        } else {
            System.out.printf("[CATEGORY = %s]\n", obj);
        }
    }

    /** Returns the list stored under {@code key}, creating and storing an empty one if absent. */
    private static List<Sample> bucketFor(Map<Object, List<Sample>> map, Object key) {
        List<Sample> bucket = map.get(key);
        if (bucket == null) {
            bucket = new LinkedList<Sample>();
            map.put(key, bucket);
        }
        return bucket;
    }

    /** Returns the first category (in map iteration order) with the most samples, or {@code null} if none has any. */
    private static Object majorityCategory(Map<Object, List<Sample>> categoryToSamples) {
        int max = 0;
        Object maxCategory = null;
        for (Entry<Object, List<Sample>> entry : categoryToSamples.entrySet()) {
            int cur = entry.getValue().size();
            if (cur > max) {
                max = cur;
                maxCategory = entry.getKey();
            }
        }
        return maxCategory;
    }

    /** Returns the total number of samples over all categories. */
    private static int countSamples(Map<Object, List<Sample>> categoryToSamples) {
        int count = 0;
        for (List<Sample> samples : categoryToSamples.values()) {
            count += samples.size();
        }
        return count;
    }

    /** Regroups the samples as: value of {@code attrName} -> (category -> samples). */
    private static Map<Object, Map<Object, List<Sample>>> splitByAttribute(
            Map<Object, List<Sample>> categoryToSamples, String attrName) {
        Map<Object, Map<Object, List<Sample>>> splits =
                new HashMap<Object, Map<Object, List<Sample>>>();
        for (Entry<Object, List<Sample>> entry : categoryToSamples.entrySet()) {
            Object category = entry.getKey();
            for (Sample sample : entry.getValue()) {
                Object attrValue = sample.getAttribute(attrName);
                Map<Object, List<Sample>> split = splits.get(attrValue);
                if (split == null) {
                    split = new HashMap<Object, List<Sample>>();
                    splits.put(attrValue, split);
                }
                bucketFor(split, category).add(sample);
            }
        }
        return splits;
    }

    /** Returns the sum over all branches of (branch size / {@code allCount}) * branch entropy in bits. */
    private static double expectedInformation(
            Map<Object, Map<Object, List<Sample>>> splits, int allCount) {
        double total = 0.0;
        for (Map<Object, List<Sample>> branch : splits.values()) {
            // Kept as a double so the divisions below are floating-point.
            double branchCount = 0;
            for (List<Sample> list : branch.values()) {
                branchCount += list.size();
            }

            // Lists are only created when a sample is added, so p is never 0.
            double branchEntropy = 0.0;
            for (List<Sample> list : branch.values()) {
                double p = list.size() / branchCount;
                branchEntropy -= p * (Math.log(p) / LOG_2);
            }
            total += (branchCount / allCount) * branchEntropy;
        }
        return total;
    }

    /**
     * A sample: a set of named attribute values plus the category the sample
     * belongs to.
     */
    static class Sample {

        private final Map<String, Object> attributes = new HashMap<String, Object>();

        private Object category;

        public Object getAttribute(String name) {
            return attributes.get(name);
        }

        public void setAttribute(String name, Object value) {
            attributes.put(name, value);
        }

        public Object getCategory() {
            return category;
        }

        public void setCategory(Object category) {
            this.category = category;
        }

        /** Returns the string form of the attribute map; the category is not included. */
        @Override
        public String toString() {
            return attributes.toString();
        }
    }

    /**
     * A decision tree inner (non-leaf) node. It holds the attribute it branches
     * on and one child per attribute value; a child is either another
     * {@code Tree} or a category value.
     */
    static class Tree {

        private final String attribute;

        private final Map<Object, Object> children = new HashMap<Object, Object>();

        public Tree(String attribute) {
            this.attribute = attribute;
        }

        public String getAttribute() {
            return attribute;
        }

        public Object getChild(Object attrValue) {
            return children.get(attrValue);
        }

        public void setChild(Object attrValue, Object child) {
            children.put(attrValue, child);
        }

        /** Returns the key set of the child map (the attribute values that have a branch). */
        public Set<Object> getAttributeValues() {
            return children.keySet();
        }
    }
}
運行結果:
希望本文所述對大家Java程序設計有所幫助。
原文鏈接:http://blog.csdn.net/u013058160/article/details/50035693