Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.Algorithms.DataAnalysis.DecisionTrees/3.4/DecisionTreeRegression.cs @ 17460

Last change on this file since 17460 was 17181, checked in by swagner, 5 years ago

#2875: Merged r17180 from trunk to stable

File size: 16.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22 using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Threading;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.PermutationEncoding;
30using HeuristicLab.Optimization;
31using HeuristicLab.Parameters;
32using HeuristicLab.PluginInfrastructure;
33using HeuristicLab.Problems.DataAnalysis;
34using HeuristicLab.Random;
35using HEAL.Attic;
36
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Regression tree / rule set learner. The algorithm keeps its working state in a
  /// scope (see <see cref="stateScope"/>) and reports that it supports pausing.
  /// </summary>
  [StorableType("FC8D8E5A-D16D-41BB-91CF-B2B35D17ADD7")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 95)]
  [Item("Decision Tree Regression (DT)", "A regression tree / rule set learner")]
  public sealed class DecisionTreeRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    /// <summary>Always true: this algorithm reports that it can be paused.</summary>
    public override bool SupportsPause {
      get { return true; }
    }

    // Names of the variables that InitializeScope stores in the state scope and that Build reads back.
    public const string RegressionTreeParameterVariableName = "RegressionTreeParameters";
    public const string ModelVariableName = "Model";
    public const string PruningSetVariableName = "PruningSet";
    public const string TrainingSetVariableName = "TrainingSet";

    #region Parameter names
    // Keys under which the parameters are registered in the Parameters collection (see the default constructor).
    private const string GenerateRulesParameterName = "GenerateRules";
    private const string HoldoutSizeParameterName = "HoldoutSize";
    private const string SplitterParameterName = "Splitter";
    private const string MinimalNodeSizeParameterName = "MinimalNodeSize";
    private const string LeafModelParameterName = "LeafModel";
    private const string PruningTypeParameterName = "PruningType";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string UseHoldoutParameterName = "UseHoldout";
    #endregion

    #region Parameter properties
    // Typed accessors for the entries of the Parameters collection.
    public IFixedValueParameter<BoolValue> GenerateRulesParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[GenerateRulesParameterName]; }
    }
    public IFixedValueParameter<PercentValue> HoldoutSizeParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[HoldoutSizeParameterName]; }
    }
    public IConstrainedValueParameter<ISplitter> SplitterParameter {
      get { return (IConstrainedValueParameter<ISplitter>)Parameters[SplitterParameterName]; }
    }
    public IFixedValueParameter<IntValue> MinimalNodeSizeParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MinimalNodeSizeParameterName]; }
    }
    public IConstrainedValueParameter<ILeafModel> LeafModelParameter {
      get { return (IConstrainedValueParameter<ILeafModel>)Parameters[LeafModelParameterName]; }
    }
    public IConstrainedValueParameter<IPruning> PruningTypeParameter {
      get { return (IConstrainedValueParameter<IPruning>)Parameters[PruningTypeParameterName]; }
    }
    public IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }
    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }
    public IFixedValueParameter<BoolValue> UseHoldoutParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UseHoldoutParameterName]; }
    }
    #endregion

    #region Properties
    /// <summary>Whether a set of rules (true) or a decision tree (false) is created.</summary>
    public bool GenerateRules {
      get { return GenerateRulesParameter.Value.Value; }
      set { GenerateRulesParameter.Value.Value = value; }
    }
    /// <summary>Share of the training rows reserved for pruning; only used when <see cref="UseHoldout"/> is true.</summary>
    public double HoldoutSize {
      get { return HoldoutSizeParameter.Value.Value; }
      set { HoldoutSizeParameter.Value.Value = value; }
    }
    /// <summary>The split function used to create node splits.</summary>
    public ISplitter Splitter {
      get { return SplitterParameter.Value; }
      // no setter because this is a constrained parameter
    }
    /// <summary>The minimal number of samples in a leaf node.</summary>
    public int MinimalNodeSize {
      get { return MinimalNodeSizeParameter.Value.Value; }
      set { MinimalNodeSizeParameter.Value.Value = value; }
    }
    /// <summary>The type of model used for the nodes (constrained parameter, hence no setter).</summary>
    public ILeafModel LeafModel {
      get { return LeafModelParameter.Value; }
    }
    /// <summary>The type of pruning used (constrained parameter, hence no setter).</summary>
    public IPruning Pruning {
      get { return PruningTypeParameter.Value; }
    }
    /// <summary>The seed used to reset the random number generator in <see cref="Initialize"/>.</summary>
    public int Seed {
      get { return SeedParameter.Value.Value; }
      set { SeedParameter.Value.Value = value; }
    }
    /// <summary>If true, <see cref="Initialize"/> overwrites <see cref="Seed"/> with a freshly generated seed.</summary>
    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
      set { SetSeedRandomlyParameter.Value.Value = value; }
    }
    /// <summary>True if a holdout set is split off for pruning; false if splitting and pruning use the same rows.</summary>
    public bool UseHoldout {
      get { return UseHoldoutParameter.Value.Value; }
      set { UseHoldoutParameter.Value.Value = value; }
    }
    #endregion

    #region State
    // Scope created by Initialize (via InitializeScope) and consumed by Run/Build. It holds the
    // regression tree parameters, the model and the training/pruning row indices.
    [Storable]
    private IScope stateScope;
    #endregion

    #region Constructors and Cloning
    [StorableConstructor]
    private DecisionTreeRegression(StorableConstructorFlag _) : base(_) { }
    private DecisionTreeRegression(DecisionTreeRegression original, Cloner cloner) : base(original, cloner) {
      // Clone the ORIGINAL's state scope. The field of this new instance has not been assigned yet,
      // so cloning it (as the previous code did) never carried the original's state over.
      stateScope = cloner.Clone(original.stateScope);
    }
    /// <summary>
    /// Creates the algorithm with its default parameters and a new <c>RegressionProblem</c>.
    /// The selectable leaf models, prunings and splitters are the instances returned by the
    /// application manager for the respective interfaces.
    /// </summary>
    public DecisionTreeRegression() {
      var modelSet = new ItemSet<ILeafModel>(ApplicationManager.Manager.GetInstances<ILeafModel>());
      var pruningSet = new ItemSet<IPruning>(ApplicationManager.Manager.GetInstances<IPruning>());
      var splitterSet = new ItemSet<ISplitter>(ApplicationManager.Manager.GetInstances<ISplitter>());
      Parameters.Add(new FixedValueParameter<BoolValue>(GenerateRulesParameterName, "Whether a set of rules or a decision tree shall be created (default=false)", new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<PercentValue>(HoldoutSizeParameterName, "How much of the training set shall be reserved for pruning (default=20%).", new PercentValue(0.2)));
      Parameters.Add(new ConstrainedValueParameter<ISplitter>(SplitterParameterName, "The type of split function used to create node splits (default='Splitter').", splitterSet, splitterSet.OfType<Splitter>().First()));
      Parameters.Add(new FixedValueParameter<IntValue>(MinimalNodeSizeParameterName, "The minimal number of samples in a leaf node (default=1).", new IntValue(1)));
      Parameters.Add(new ConstrainedValueParameter<ILeafModel>(LeafModelParameterName, "The type of model used for the nodes (default='LinearLeaf').", modelSet, modelSet.OfType<LinearLeaf>().First()));
      Parameters.Add(new ConstrainedValueParameter<IPruning>(PruningTypeParameterName, "The type of pruning used (default='ComplexityPruning').", pruningSet, pruningSet.OfType<ComplexityPruning>().First()));
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
      Parameters.Add(new FixedValueParameter<BoolValue>(UseHoldoutParameterName, "True if a holdout set should be generated, false if splitting and pruning shall be performed on the same data (default=false).", new BoolValue(false)));
      Problem = new RegressionProblem();
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new DecisionTreeRegression(this, cloner);
    }
    #endregion

    /// <summary>
    /// Seeds a new random number generator (drawing a new <see cref="Seed"/> first if
    /// <see cref="SetSeedRandomly"/> is set), creates the state scope from the current parameter
    /// values and publishes it as the "StateScope" result.
    /// </summary>
    protected override void Initialize(CancellationToken cancellationToken) {
      base.Initialize(cancellationToken);
      var random = new MersenneTwister();
      if (SetSeedRandomly) Seed = RandomSeedGenerator.GetSeed();
      random.Reset(Seed);
      stateScope = InitializeScope(random, Problem.ProblemData, Pruning, MinimalNodeSize, LeafModel, Splitter, GenerateRules, UseHoldout, HoldoutSize);
      // the scope also carries a reference to this algorithm instance
      stateScope.Variables.Add(new Variable("Algorithm", this));
      Results.AddOrUpdateResult("StateScope", stateScope);
    }

    /// <summary>
    /// Builds the model from the state scope and adds the analysis of the resulting
    /// regression solution to the results.
    /// </summary>
    protected override void Run(CancellationToken cancellationToken) {
      var model = Build(stateScope, Results, cancellationToken);
      AnalyzeSolution(model.CreateRegressionSolution(Problem.ProblemData), Results, Problem.ProblemData);
    }

    #region Static Interface
    /// <summary>
    /// Builds a model for <paramref name="problemData"/> without an algorithm instance and returns
    /// the regression solution created from it.
    /// </summary>
    /// <param name="problemData">The regression problem data to learn from.</param>
    /// <param name="random">Random number generator handed to the tree parameters and used for the holdout split.</param>
    /// <param name="leafModel">Leaf model; a new <c>LinearLeaf</c> is used when null.</param>
    /// <param name="splitter">Splitter; a new <c>Splitter</c> is used when null.</param>
    /// <param name="pruning">Pruning; a new <c>ComplexityPruning</c> is used when null.</param>
    /// <param name="useHoldout">Whether a separate pruning set is split off the training rows.</param>
    /// <param name="holdoutSize">Share of the training rows used for the pruning set when <paramref name="useHoldout"/> is true.</param>
    /// <param name="minimumLeafSize">The minimal leaf size passed to the tree parameters.</param>
    /// <param name="generateRules">Whether a rule set model instead of a tree model is created.</param>
    /// <param name="results">Result collection passed through to the model's build step; may be null as far as this method is concerned.</param>
    /// <param name="cancellationToken">Cancellation token; <c>CancellationToken.None</c> is used when null.</param>
    /// <returns>The regression solution of the built model on <paramref name="problemData"/>.</returns>
    /// <exception cref="NotSupportedException">Thrown by the scope initialization for non-double variables or NaN/infinity values.</exception>
    public static IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData, IRandom random, ILeafModel leafModel = null, ISplitter splitter = null, IPruning pruning = null,
      bool useHoldout = false, double holdoutSize = 0.2, int minimumLeafSize = 1, bool generateRules = false, ResultCollection results = null, CancellationToken? cancellationToken = null) {
      if (leafModel == null) leafModel = new LinearLeaf();
      if (splitter == null) splitter = new Splitter();
      if (cancellationToken == null) cancellationToken = CancellationToken.None;
      if (pruning == null) pruning = new ComplexityPruning();

      var stateScope = InitializeScope(random, problemData, pruning, minimumLeafSize, leafModel, splitter, generateRules, useHoldout, holdoutSize);
      var model = Build(stateScope, results, cancellationToken.Value);
      return model.CreateRegressionSolution(problemData);
    }

    /// <summary>
    /// Updates an existing model using the training rows of <paramref name="problemData"/> and the
    /// given leaf model. A temporary scope holding only the regression tree parameters is used.
    /// </summary>
    /// <param name="model">The model whose <c>Update</c> method is invoked.</param>
    /// <param name="problemData">Problem data providing the training indices.</param>
    /// <param name="random">Random number generator handed to the tree parameters.</param>
    /// <param name="leafModel">Leaf model that is initialized on the temporary scope.</param>
    /// <param name="cancellationToken">Cancellation token; <c>CancellationToken.None</c> is used when null.</param>
    public static void UpdateModel(IDecisionTreeModel model, IRegressionProblemData problemData, IRandom random, ILeafModel leafModel, CancellationToken? cancellationToken = null) {
      if (cancellationToken == null) cancellationToken = CancellationToken.None;
      var regressionTreeParameters = new RegressionTreeParameters(leafModel, problemData, random);
      var scope = new Scope();
      scope.Variables.Add(new Variable(RegressionTreeParameterVariableName, regressionTreeParameters));
      leafModel.Initialize(scope);
      model.Update(problemData.TrainingIndices.ToList(), scope, cancellationToken.Value);
    }
    #endregion

    #region Helpers
    /// <summary>
    /// Creates the state scope: a reduced copy of the problem data, the regression tree parameters,
    /// the unbuilt model and the training and pruning row indices.
    /// </summary>
    /// <exception cref="NotSupportedException">
    /// Thrown if an allowed input or the target is not a double variable, or if the training rows
    /// contain NaN or infinity values.
    /// </exception>
    private static IScope InitializeScope(IRandom random, IRegressionProblemData problemData, IPruning pruning, int minLeafSize, ILeafModel leafModel, ISplitter splitter, bool generateRules, bool useHoldout, double holdoutSize) {
      var stateScope = new Scope("RegressionTreeStateScope");

      //reduce RegressionProblemData to AllowedInput & Target column wise and to TrainingSet row wise
      var doubleVars = new HashSet<string>(problemData.Dataset.DoubleVariables);
      var vars = problemData.AllowedInputVariables.Concat(new[] {problemData.TargetVariable}).ToArray();
      if (vars.Any(v => !doubleVars.Contains(v))) throw new NotSupportedException("Decision tree regression supports only double valued input or output features.");
      var doubles = vars.Select(v => problemData.Dataset.GetDoubleValues(v, problemData.TrainingIndices).ToArray()).ToArray();
      if (doubles.Any(v => v.Any(x => double.IsNaN(x) || double.IsInfinity(x))))
        throw new NotSupportedException("Decision tree regression does not support NaN or infinity values in the input dataset.");
      var trainingData = new Dataset(vars, doubles);
      var pd = new RegressionProblemData(trainingData, problemData.AllowedInputVariables, problemData.TargetVariable);
      // every row of the reduced dataset is a training row; the test partition is empty
      pd.TrainingPartition.End = pd.TestPartition.Start = pd.TestPartition.End = pd.Dataset.Rows;
      pd.TrainingPartition.Start = 0;

      //store regression tree parameters
      var regressionTreeParams = new RegressionTreeParameters(pruning, minLeafSize, leafModel, pd, random, splitter);
      stateScope.Variables.Add(new Variable(RegressionTreeParameterVariableName, regressionTreeParams));

      //initialize tree operators
      pruning.Initialize(stateScope);
      splitter.Initialize(stateScope);
      leafModel.Initialize(stateScope);

      //store unbuilt model
      IItem model;
      if (generateRules) {
        model = RegressionRuleSetModel.CreateRuleModel(problemData.TargetVariable, regressionTreeParams);
        RegressionRuleSetModel.Initialize(stateScope);
      }
      else {
        model = RegressionNodeTreeModel.CreateTreeModel(problemData.TargetVariable, regressionTreeParams);
      }
      stateScope.Variables.Add(new Variable(ModelVariableName, model));

      //store training & pruning indices
      IReadOnlyList<int> trainingSet, pruningSet;
      GeneratePruningSet(pd.TrainingIndices.ToArray(), random, useHoldout, holdoutSize, out trainingSet, out pruningSet);
      stateScope.Variables.Add(new Variable(TrainingSetVariableName, new IntArray(trainingSet.ToArray())));
      stateScope.Variables.Add(new Variable(PruningSetVariableName, new IntArray(pruningSet.ToArray())));

      return stateScope;
    }

    /// <summary>
    /// Builds the model stored in <paramref name="stateScope"/> on the stored training and pruning rows.
    /// </summary>
    /// <returns>
    /// The built model, or a <c>PreconstructedLinearModel</c> without coefficients if there are no
    /// training rows or fewer training rows than the minimal leaf size.
    /// </returns>
    private static IRegressionModel Build(IScope stateScope, ResultCollection results, CancellationToken cancellationToken) {
      var regressionTreeParams = (RegressionTreeParameters)stateScope.Variables[RegressionTreeParameterVariableName].Value;
      var model = (IDecisionTreeModel)stateScope.Variables[ModelVariableName].Value;
      var trainingRows = (IntArray)stateScope.Variables[TrainingSetVariableName].Value;
      var pruningRows = (IntArray)stateScope.Variables[PruningSetVariableName].Value;
      // no training rows at all: fall back to a model with no coefficients and the constant 0
      if (1 > trainingRows.Length)
        return new PreconstructedLinearModel(new Dictionary<string, double>(), 0, regressionTreeParams.TargetVariable);
      // too few rows for a single leaf: fall back to the average of the target values in the stored data
      if (regressionTreeParams.MinLeafSize > trainingRows.Length) {
        var targets = regressionTreeParams.Data.GetDoubleValues(regressionTreeParams.TargetVariable).ToArray();
        return new PreconstructedLinearModel(new Dictionary<string, double>(), targets.Average(), regressionTreeParams.TargetVariable);
      }
      model.Build(trainingRows.ToArray(), pruningRows.ToArray(), stateScope, results, cancellationToken);
      return model;
    }

    /// <summary>
    /// Splits <paramref name="allrows"/> into training and pruning rows.
    /// Without holdout both sets are <paramref name="allrows"/>. With holdout the rows are shuffled
    /// by a random permutation; the first <c>(int)(holdoutSize * allrows.Count)</c> of them form the
    /// pruning set and the remaining ones form the training set, so the two sets are disjoint.
    /// </summary>
    private static void GeneratePruningSet(IReadOnlyList<int> allrows, IRandom random, bool useHoldout, double holdoutSize, out IReadOnlyList<int> training, out IReadOnlyList<int> pruning) {
      if (!useHoldout) {
        training = allrows;
        pruning = allrows;
        return;
      }
      var perm = new Permutation(PermutationTypes.Absolute, allrows.Count, random);
      var cut = (int)(holdoutSize * allrows.Count);
      pruning = perm.Take(cut).Select(i => allrows[i]).ToArray();
      // Skip (not Take): training must be the complement of the pruning set. Taking the first
      // 'cut' entries again would make both sets identical and discard all remaining rows.
      training = perm.Skip(cut).Select(i => allrows[i]).ToArray();
    }

    /// <summary>
    /// Adds a clone of the solution and model-specific analysis results (leaf depth histogram and
    /// node analysis for trees; rules and coverage diagram for rule sets; relative variable
    /// frequencies; pruning chart for trees with complexity pruning) to <paramref name="results"/>.
    /// </summary>
    private void AnalyzeSolution(IRegressionSolution solution, ResultCollection results, IRegressionProblemData problemData) {
      results.Add(new Result("RegressionSolution", (IItem)solution.Clone()));

      Dictionary<string, int> frequencies = null;

      var tree = solution.Model as RegressionNodeTreeModel;
      if (tree != null) {
        results.Add(RegressionTreeAnalyzer.CreateLeafDepthHistogram(tree));
        frequencies = RegressionTreeAnalyzer.GetTreeVariableFrequences(tree);
        RegressionTreeAnalyzer.AnalyzeNodes(tree, results, problemData);
      }

      var ruleSet = solution.Model as RegressionRuleSetModel;
      if (ruleSet != null) {
        results.Add(RegressionTreeAnalyzer.CreateRulesResult(ruleSet, problemData, "Rules", true));
        frequencies = RegressionTreeAnalyzer.GetRuleVariableFrequences(ruleSet);
        results.Add(RegressionTreeAnalyzer.CreateCoverageDiagram(ruleSet, problemData));
      }

      //Variable frequencies
      if (frequencies != null) {
        var sum = frequencies.Values.Sum();
        // avoid a division by zero when no variable occurs at all
        sum = sum == 0 ? 1 : sum;
        var impactArray = new DoubleArray(frequencies.Select(i => (double)i.Value / sum).ToArray()) {
          ElementNames = frequencies.Select(i => i.Key)
        };
        results.Add(new Result("Variable Frequences", "relative frequencies of variables in rules and tree nodes", impactArray));
      }

      var pruning = Pruning as ComplexityPruning;
      if (pruning != null && tree != null)
        RegressionTreeAnalyzer.PruningChart(tree, pruning, results);
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.