1 | #region License Information
|
---|
2 | /* HeuristicLab
|
---|
3 | * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
|
---|
4 | *
|
---|
5 | * This file is part of HeuristicLab.
|
---|
6 | *
|
---|
7 | * HeuristicLab is free software: you can redistribute it and/or modify
|
---|
8 | * it under the terms of the GNU General Public License as published by
|
---|
9 | * the Free Software Foundation, either version 3 of the License, or
|
---|
10 | * (at your option) any later version.
|
---|
11 | *
|
---|
12 | * HeuristicLab is distributed in the hope that it will be useful,
|
---|
13 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
|
---|
14 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
---|
15 | * GNU General Public License for more details.
|
---|
16 | *
|
---|
17 | * You should have received a copy of the GNU General Public License
|
---|
18 | * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
|
---|
19 | */
|
---|
20 | #endregion
|
---|
21 |
|
---|
22 | using System;
|
---|
23 | using System.Collections.Generic;
|
---|
24 | using System.Linq;
|
---|
25 | using System.Threading;
|
---|
26 | using HeuristicLab.Common;
|
---|
27 | using HeuristicLab.Core;
|
---|
28 | using HeuristicLab.Data;
|
---|
29 | using HeuristicLab.Encodings.PermutationEncoding;
|
---|
30 | using HeuristicLab.Optimization;
|
---|
31 | using HeuristicLab.Parameters;
|
---|
32 | using HeuristicLab.PluginInfrastructure;
|
---|
33 | using HeuristicLab.Problems.DataAnalysis;
|
---|
34 | using HeuristicLab.Random;
|
---|
35 | using HEAL.Attic;
|
---|
36 |
|
---|
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Learns a regression tree or, when <see cref="GenerateRules"/> is enabled, a regression rule set.
  /// The leaf model, the split function and the pruning strategy are selected via constrained parameters
  /// whose candidate values are all discovered <see cref="ILeafModel"/>, <see cref="ISplitter"/> and
  /// <see cref="IPruning"/> implementations.
  /// </summary>
  [StorableType("FC8D8E5A-D16D-41BB-91CF-B2B35D17ADD7")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 95)]
  [Item("Decision Tree Regression (DT)", "A regression tree / rule set learner")]
  public sealed class DecisionTreeRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    /// <summary>
    /// Always <c>true</c>. The build state is kept in the storable state scope created by
    /// <c>Initialize</c>, and <c>Run</c> builds the model from that scope.
    /// </summary>
    public override bool SupportsPause {
      get { return true; }
    }

    // Names of the variables that InitializeScope stores in the state scope and Build reads back.
    public const string RegressionTreeParameterVariableName = "RegressionTreeParameters";
    public const string ModelVariableName = "Model";
    public const string PruningSetVariableName = "PruningSet";
    public const string TrainingSetVariableName = "TrainingSet";

    #region Parameter names
    private const string GenerateRulesParameterName = "GenerateRules";
    private const string HoldoutSizeParameterName = "HoldoutSize";
    private const string SplitterParameterName = "Splitter";
    private const string MinimalNodeSizeParameterName = "MinimalNodeSize";
    private const string LeafModelParameterName = "LeafModel";
    private const string PruningTypeParameterName = "PruningType";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string UseHoldoutParameterName = "UseHoldout";
    #endregion

    #region Parameter properties
    public IFixedValueParameter<BoolValue> GenerateRulesParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[GenerateRulesParameterName]; }
    }
    public IFixedValueParameter<PercentValue> HoldoutSizeParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[HoldoutSizeParameterName]; }
    }
    public IConstrainedValueParameter<ISplitter> SplitterParameter {
      get { return (IConstrainedValueParameter<ISplitter>)Parameters[SplitterParameterName]; }
    }
    public IFixedValueParameter<IntValue> MinimalNodeSizeParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MinimalNodeSizeParameterName]; }
    }
    public IConstrainedValueParameter<ILeafModel> LeafModelParameter {
      get { return (IConstrainedValueParameter<ILeafModel>)Parameters[LeafModelParameterName]; }
    }
    public IConstrainedValueParameter<IPruning> PruningTypeParameter {
      get { return (IConstrainedValueParameter<IPruning>)Parameters[PruningTypeParameterName]; }
    }
    public IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }
    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }
    public IFixedValueParameter<BoolValue> UseHoldoutParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UseHoldoutParameterName]; }
    }
    #endregion

    #region Properties
    /// <summary>Whether a rule set (<c>true</c>) or a single decision tree (<c>false</c>) is created.</summary>
    public bool GenerateRules {
      get { return GenerateRulesParameter.Value.Value; }
      set { GenerateRulesParameter.Value.Value = value; }
    }
    /// <summary>
    /// Fraction of the training rows reserved for pruning. Only used when <see cref="UseHoldout"/> is set.
    /// </summary>
    public double HoldoutSize {
      get { return HoldoutSizeParameter.Value.Value; }
      set { HoldoutSizeParameter.Value.Value = value; }
    }
    /// <summary>The split function used to create node splits.</summary>
    public ISplitter Splitter {
      get { return SplitterParameter.Value; }
      // no setter because this is a constrained parameter
    }
    /// <summary>The minimal number of samples in a leaf node.</summary>
    public int MinimalNodeSize {
      get { return MinimalNodeSizeParameter.Value.Value; }
      set { MinimalNodeSizeParameter.Value.Value = value; }
    }
    /// <summary>The type of model used for the nodes.</summary>
    public ILeafModel LeafModel {
      get { return LeafModelParameter.Value; }
    }
    /// <summary>The pruning strategy.</summary>
    public IPruning Pruning {
      get { return PruningTypeParameter.Value; }
    }
    /// <summary>
    /// Seed of the random number generator. It is overwritten in <c>Initialize</c> when
    /// <see cref="SetSeedRandomly"/> is set.
    /// </summary>
    public int Seed {
      get { return SeedParameter.Value.Value; }
      set { SeedParameter.Value.Value = value; }
    }
    /// <summary>Whether a new random seed is drawn at every initialization.</summary>
    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
      set { SetSeedRandomlyParameter.Value.Value = value; }
    }
    /// <summary>
    /// Whether a separate holdout (pruning) set is split off the training rows. When <c>false</c>,
    /// splitting and pruning are performed on the same rows.
    /// </summary>
    public bool UseHoldout {
      get { return UseHoldoutParameter.Value.Value; }
      set { UseHoldoutParameter.Value.Value = value; }
    }
    #endregion

    #region State
    // Holds everything Build needs: tree parameters, the (unbuilt) model, and the training/pruning rows.
    // Created in Initialize.
    [Storable]
    private IScope stateScope;
    #endregion

    #region Constructors and Cloning
    [StorableConstructor]
    private DecisionTreeRegression(StorableConstructorFlag _) : base(_) { }
    private DecisionTreeRegression(DecisionTreeRegression original, Cloner cloner) : base(original, cloner) {
      stateScope = cloner.Clone(stateScope);
    }
    /// <summary>
    /// Creates the algorithm with its default parameters: LinearLeaf leaves, Splitter splits,
    /// ComplexityPruning, minimal node size 1, holdout disabled (20% holdout size), and a random seed.
    /// </summary>
    public DecisionTreeRegression() {
      var modelSet = new ItemSet<ILeafModel>(ApplicationManager.Manager.GetInstances<ILeafModel>());
      var pruningSet = new ItemSet<IPruning>(ApplicationManager.Manager.GetInstances<IPruning>());
      var splitterSet = new ItemSet<ISplitter>(ApplicationManager.Manager.GetInstances<ISplitter>());
      Parameters.Add(new FixedValueParameter<BoolValue>(GenerateRulesParameterName, "Whether a set of rules or a decision tree shall be created (default=false)", new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<PercentValue>(HoldoutSizeParameterName, "How much of the training set shall be reserved for pruning (default=20%).", new PercentValue(0.2)));
      Parameters.Add(new ConstrainedValueParameter<ISplitter>(SplitterParameterName, "The type of split function used to create node splits (default='Splitter').", splitterSet, splitterSet.OfType<Splitter>().First()));
      Parameters.Add(new FixedValueParameter<IntValue>(MinimalNodeSizeParameterName, "The minimal number of samples in a leaf node (default=1).", new IntValue(1)));
      Parameters.Add(new ConstrainedValueParameter<ILeafModel>(LeafModelParameterName, "The type of model used for the nodes (default='LinearLeaf').", modelSet, modelSet.OfType<LinearLeaf>().First()));
      Parameters.Add(new ConstrainedValueParameter<IPruning>(PruningTypeParameterName, "The type of pruning used (default='ComplexityPruning').", pruningSet, pruningSet.OfType<ComplexityPruning>().First()));
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
      Parameters.Add(new FixedValueParameter<BoolValue>(UseHoldoutParameterName, "True if a holdout set should be generated, false if splitting and pruning shall be performed on the same data (default=false).", new BoolValue(false)));
      Problem = new RegressionProblem();
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new DecisionTreeRegression(this, cloner);
    }
    #endregion

    /// <summary>
    /// Seeds the random number generator, creates the state scope from the current parameter values,
    /// and publishes that scope as the "StateScope" result.
    /// </summary>
    protected override void Initialize(CancellationToken cancellationToken) {
      base.Initialize(cancellationToken);
      var random = new MersenneTwister();
      if (SetSeedRandomly) Seed = RandomSeedGenerator.GetSeed();
      random.Reset(Seed);
      stateScope = InitializeScope(random, Problem.ProblemData, Pruning, MinimalNodeSize, LeafModel, Splitter, GenerateRules, UseHoldout, HoldoutSize);
      stateScope.Variables.Add(new Variable("Algorithm", this));
      Results.AddOrUpdateResult("StateScope", stateScope);
    }

    /// <summary>Builds the model from the state scope and publishes the resulting solution and diagnostics.</summary>
    protected override void Run(CancellationToken cancellationToken) {
      var model = Build(stateScope, Results, cancellationToken);
      AnalyzeSolution(model.CreateRegressionSolution(Problem.ProblemData), Results, Problem.ProblemData);
    }

    #region Static Interface
    /// <summary>
    /// Learns a decision tree (or rule set) model on the training rows of <paramref name="problemData"/>
    /// without creating an algorithm instance.
    /// </summary>
    /// <param name="problemData">Problem data. Its allowed inputs and target must be double-valued.</param>
    /// <param name="random">Random number generator handed to the tree parameters and used for holdout sampling.</param>
    /// <param name="leafModel">Leaf model. Defaults to a new <see cref="LinearLeaf"/>.</param>
    /// <param name="splitter">Split function. Defaults to a new <see cref="Splitter"/>.</param>
    /// <param name="pruning">Pruning strategy. Defaults to a new <see cref="ComplexityPruning"/>.</param>
    /// <param name="useHoldout">Whether a separate pruning set is split off the training rows.</param>
    /// <param name="holdoutSize">Fraction of training rows used for pruning when <paramref name="useHoldout"/> is set.</param>
    /// <param name="minimumLeafSize">Minimal number of samples in a leaf.</param>
    /// <param name="generateRules">Whether a rule set instead of a single tree is created.</param>
    /// <param name="results">Optional result collection that is passed through to the model build.</param>
    /// <param name="cancellationToken">Optional cancellation token. Defaults to <see cref="CancellationToken.None"/>.</param>
    /// <exception cref="NotSupportedException">
    /// Thrown if an input or the target is not double-valued, or contains NaN or infinity in the training rows.
    /// </exception>
    public static IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData, IRandom random, ILeafModel leafModel = null, ISplitter splitter = null, IPruning pruning = null,
      bool useHoldout = false, double holdoutSize = 0.2, int minimumLeafSize = 1, bool generateRules = false, ResultCollection results = null, CancellationToken? cancellationToken = null) {
      if (leafModel == null) leafModel = new LinearLeaf();
      if (splitter == null) splitter = new Splitter();
      if (pruning == null) pruning = new ComplexityPruning();
      var token = cancellationToken ?? CancellationToken.None;

      var stateScope = InitializeScope(random, problemData, pruning, minimumLeafSize, leafModel, splitter, generateRules, useHoldout, holdoutSize);
      var model = Build(stateScope, results, token);
      return model.CreateRegressionSolution(problemData);
    }

    /// <summary>
    /// Updates an already built <paramref name="model"/> on the training rows of <paramref name="problemData"/>
    /// using <paramref name="leafModel"/>. It does this by calling <c>model.Update</c> with a scope that holds
    /// fresh tree parameters.
    /// </summary>
    public static void UpdateModel(IDecisionTreeModel model, IRegressionProblemData problemData, IRandom random, ILeafModel leafModel, CancellationToken? cancellationToken = null) {
      var token = cancellationToken ?? CancellationToken.None;
      var regressionTreeParameters = new RegressionTreeParameters(leafModel, problemData, random);
      var scope = new Scope();
      scope.Variables.Add(new Variable(RegressionTreeParameterVariableName, regressionTreeParameters));
      leafModel.Initialize(scope);
      model.Update(problemData.TrainingIndices.ToList(), scope, token);
    }
    #endregion

    #region Helpers
    /// <summary>
    /// Creates the state scope for one model build. It reduces the problem data to the allowed inputs
    /// plus the target (columns) and to the training rows (rows). It stores the tree parameters and the
    /// unbuilt model, and splits the rows into training and pruning sets.
    /// </summary>
    /// <exception cref="NotSupportedException">
    /// Thrown if a used variable is not double-valued or contains NaN or infinity in the training rows.
    /// </exception>
    private static IScope InitializeScope(IRandom random, IRegressionProblemData problemData, IPruning pruning, int minLeafSize, ILeafModel leafModel, ISplitter splitter, bool generateRules, bool useHoldout, double holdoutSize) {
      var stateScope = new Scope("RegressionTreeStateScope");

      //reduce RegressionProblemData to AllowedInput & Target column wise and to TrainingSet row wise
      var doubleVars = new HashSet<string>(problemData.Dataset.DoubleVariables);
      var vars = problemData.AllowedInputVariables.Concat(new[] { problemData.TargetVariable }).ToArray();
      if (vars.Any(v => !doubleVars.Contains(v))) throw new NotSupportedException("Decision tree regression supports only double valued input or output features.");
      var doubles = vars.Select(v => problemData.Dataset.GetDoubleValues(v, problemData.TrainingIndices).ToArray()).ToArray();
      if (doubles.Any(v => v.Any(x => double.IsNaN(x) || double.IsInfinity(x))))
        throw new NotSupportedException("Decision tree regression does not support NaN or infinity values in the input dataset.");
      var trainingData = new Dataset(vars, doubles);
      var pd = new RegressionProblemData(trainingData, problemData.AllowedInputVariables, problemData.TargetVariable);
      // every row of the reduced dataset is a training row; the test partition is empty
      pd.TrainingPartition.End = pd.TestPartition.Start = pd.TestPartition.End = pd.Dataset.Rows;
      pd.TrainingPartition.Start = 0;

      //store regression tree parameters
      var regressionTreeParams = new RegressionTreeParameters(pruning, minLeafSize, leafModel, pd, random, splitter);
      stateScope.Variables.Add(new Variable(RegressionTreeParameterVariableName, regressionTreeParams));

      //initialize tree operators
      pruning.Initialize(stateScope);
      splitter.Initialize(stateScope);
      leafModel.Initialize(stateScope);

      //store unbuilt model
      IItem model;
      if (generateRules) {
        model = RegressionRuleSetModel.CreateRuleModel(problemData.TargetVariable, regressionTreeParams);
        RegressionRuleSetModel.Initialize(stateScope);
      }
      else {
        model = RegressionNodeTreeModel.CreateTreeModel(problemData.TargetVariable, regressionTreeParams);
      }
      stateScope.Variables.Add(new Variable(ModelVariableName, model));

      //store training & pruning indices
      IReadOnlyList<int> trainingSet, pruningSet;
      GeneratePruningSet(pd.TrainingIndices.ToArray(), random, useHoldout, holdoutSize, out trainingSet, out pruningSet);
      stateScope.Variables.Add(new Variable(TrainingSetVariableName, new IntArray(trainingSet.ToArray())));
      stateScope.Variables.Add(new Variable(PruningSetVariableName, new IntArray(pruningSet.ToArray())));

      return stateScope;
    }

    /// <summary>
    /// Builds the model stored in <paramref name="stateScope"/> and returns it. When the training set is
    /// too small for the tree learner, a constant model is returned instead.
    /// </summary>
    private static IRegressionModel Build(IScope stateScope, ResultCollection results, CancellationToken cancellationToken) {
      var regressionTreeParams = (RegressionTreeParameters)stateScope.Variables[RegressionTreeParameterVariableName].Value;
      var model = (IDecisionTreeModel)stateScope.Variables[ModelVariableName].Value;
      var trainingRows = ((IntArray)stateScope.Variables[TrainingSetVariableName].Value).ToArray();
      var pruningRows = ((IntArray)stateScope.Variables[PruningSetVariableName].Value).ToArray();

      // nothing to learn from: constant model predicting 0
      if (trainingRows.Length < 1)
        return new PreconstructedLinearModel(new Dictionary<string, double>(), 0, regressionTreeParams.TargetVariable);

      // fewer rows than a single leaf needs: constant model predicting the mean target.
      // Only the training rows are averaged, so holdout (pruning) rows do not leak into the model.
      if (regressionTreeParams.MinLeafSize > trainingRows.Length) {
        var targetMean = regressionTreeParams.Data.GetDoubleValues(regressionTreeParams.TargetVariable, trainingRows).Average();
        return new PreconstructedLinearModel(new Dictionary<string, double>(), targetMean, regressionTreeParams.TargetVariable);
      }

      model.Build(trainingRows, pruningRows, stateScope, results, cancellationToken);
      return model;
    }

    /// <summary>
    /// Splits <paramref name="allrows"/> into a training set and a pruning set.
    /// Without holdout, both sets are <paramref name="allrows"/> itself. With holdout, the rows are
    /// reordered by a random permutation. The first <c>(int)(holdoutSize * Count)</c> rows form the
    /// pruning set and the remaining rows form the training set, so the two sets are disjoint.
    /// </summary>
    private static void GeneratePruningSet(IReadOnlyList<int> allrows, IRandom random, bool useHoldout, double holdoutSize, out IReadOnlyList<int> training, out IReadOnlyList<int> pruning) {
      if (!useHoldout) {
        training = allrows;
        pruning = allrows;
        return;
      }
      // perm holds the positions 0..Count-1 in an order drawn with `random`
      var perm = new Permutation(PermutationTypes.Absolute, allrows.Count, random);
      var cut = (int)(holdoutSize * allrows.Count);
      pruning = perm.Take(cut).Select(i => allrows[i]).ToArray();
      // the training set is the complement of the pruning set (previously Take(cut) duplicated the pruning rows)
      training = perm.Skip(cut).Select(i => allrows[i]).ToArray();
    }

    /// <summary>
    /// Adds a clone of the solution to <paramref name="results"/>, plus diagnostics that depend on the
    /// model type. For trees: leaf depth histogram, node analysis, and (with ComplexityPruning) the
    /// pruning chart. For rule sets: rules and coverage diagram. For both: relative variable frequencies.
    /// </summary>
    private void AnalyzeSolution(IRegressionSolution solution, ResultCollection results, IRegressionProblemData problemData) {
      results.Add(new Result("RegressionSolution", (IItem)solution.Clone()));

      Dictionary<string, int> frequencies = null;

      var tree = solution.Model as RegressionNodeTreeModel;
      if (tree != null) {
        results.Add(RegressionTreeAnalyzer.CreateLeafDepthHistogram(tree));
        frequencies = RegressionTreeAnalyzer.GetTreeVariableFrequences(tree);
        RegressionTreeAnalyzer.AnalyzeNodes(tree, results, problemData);
      }

      var ruleSet = solution.Model as RegressionRuleSetModel;
      if (ruleSet != null) {
        results.Add(RegressionTreeAnalyzer.CreateRulesResult(ruleSet, problemData, "Rules", true));
        frequencies = RegressionTreeAnalyzer.GetRuleVariableFrequences(ruleSet);
        results.Add(RegressionTreeAnalyzer.CreateCoverageDiagram(ruleSet, problemData));
      }

      //Variable frequencies
      if (frequencies != null) {
        var sum = frequencies.Values.Sum();
        // avoid division by zero when no variable is used at all
        sum = sum == 0 ? 1 : sum;
        var impactArray = new DoubleArray(frequencies.Select(i => (double)i.Value / sum).ToArray()) {
          ElementNames = frequencies.Select(i => i.Key)
        };
        results.Add(new Result("Variable Frequences", "relative frequencies of variables in rules and tree nodes", impactArray));
      }

      var pruning = Pruning as ComplexityPruning;
      if (pruning != null && tree != null)
        RegressionTreeAnalyzer.PruningChart(tree, pruning, results);
    }
    #endregion
  }
}
---|