
source: stable/HeuristicLab.Algorithms.DataAnalysis.DecisionTrees/3.4/DecisionTreeRegression.cs @ 17460

Last change on this file since 17460 was 17181, checked in by swagner, 5 years ago

#2875: Merged r17180 from trunk to stable

File size: 16.4 KB
#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.PermutationEncoding;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.PluginInfrastructure;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Random;
using HEAL.Attic;

namespace HeuristicLab.Algorithms.DataAnalysis {
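  // Fixed data-analysis algorithm that learns either a single regression tree or a set of regression
  // rules. The leaf model, the split function and the pruning strategy are pluggable via constrained
  // parameters, and the algorithm keeps its state in a storable scope so that runs can be paused and resumed.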
  [StorableType("FC8D8E5A-D16D-41BB-91CF-B2B35D17ADD7")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 95)]
  [Item("Decision Tree Regression (DT)", "A regression tree / rule set learner")]
  public sealed class DecisionTreeRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    public override bool SupportsPause {
      get { return true; }
    }

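    // Names of the variables kept in the algorithm's state scope (see InitializeScope below).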
    public const string RegressionTreeParameterVariableName = "RegressionTreeParameters";
    public const string ModelVariableName = "Model";
    public const string PruningSetVariableName = "PruningSet";
    public const string TrainingSetVariableName = "TrainingSet";

    #region Parameter names
    private const string GenerateRulesParameterName = "GenerateRules";
    private const string HoldoutSizeParameterName = "HoldoutSize";
    private const string SplitterParameterName = "Splitter";
    private const string MinimalNodeSizeParameterName = "MinimalNodeSize";
    private const string LeafModelParameterName = "LeafModel";
    private const string PruningTypeParameterName = "PruningType";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string UseHoldoutParameterName = "UseHoldout";
    #endregion

    #region Parameter properties
    public IFixedValueParameter<BoolValue> GenerateRulesParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[GenerateRulesParameterName]; }
    }
    public IFixedValueParameter<PercentValue> HoldoutSizeParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[HoldoutSizeParameterName]; }
    }
    public IConstrainedValueParameter<ISplitter> SplitterParameter {
      get { return (IConstrainedValueParameter<ISplitter>)Parameters[SplitterParameterName]; }
    }
    public IFixedValueParameter<IntValue> MinimalNodeSizeParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MinimalNodeSizeParameterName]; }
    }
    public IConstrainedValueParameter<ILeafModel> LeafModelParameter {
      get { return (IConstrainedValueParameter<ILeafModel>)Parameters[LeafModelParameterName]; }
    }
    public IConstrainedValueParameter<IPruning> PruningTypeParameter {
      get { return (IConstrainedValueParameter<IPruning>)Parameters[PruningTypeParameterName]; }
    }
    public IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }
    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }
    public IFixedValueParameter<BoolValue> UseHoldoutParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UseHoldoutParameterName]; }
    }
    #endregion

    #region Properties
    public bool GenerateRules {
      get { return GenerateRulesParameter.Value.Value; }
      set { GenerateRulesParameter.Value.Value = value; }
    }
    public double HoldoutSize {
      get { return HoldoutSizeParameter.Value.Value; }
      set { HoldoutSizeParameter.Value.Value = value; }
    }
    public ISplitter Splitter {
      get { return SplitterParameter.Value; }
      // no setter because this is a constrained parameter
    }
    public int MinimalNodeSize {
      get { return MinimalNodeSizeParameter.Value.Value; }
      set { MinimalNodeSizeParameter.Value.Value = value; }
    }
    public ILeafModel LeafModel {
      get { return LeafModelParameter.Value; }
    }
    public IPruning Pruning {
      get { return PruningTypeParameter.Value; }
    }
    public int Seed {
      get { return SeedParameter.Value.Value; }
      set { SeedParameter.Value.Value = value; }
    }
    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
      set { SetSeedRandomlyParameter.Value.Value = value; }
    }
    public bool UseHoldout {
      get { return UseHoldoutParameter.Value.Value; }
      set { UseHoldoutParameter.Value.Value = value; }
    }
    #endregion

    #region State
    [Storable]
    private IScope stateScope;
    #endregion

    #region Constructors and Cloning
    [StorableConstructor]
    private DecisionTreeRegression(StorableConstructorFlag _) : base(_) { }
    private DecisionTreeRegression(DecisionTreeRegression original, Cloner cloner) : base(original, cloner) {
      stateScope = cloner.Clone(original.stateScope);
    }
    public DecisionTreeRegression() {
      var modelSet = new ItemSet<ILeafModel>(ApplicationManager.Manager.GetInstances<ILeafModel>());
      var pruningSet = new ItemSet<IPruning>(ApplicationManager.Manager.GetInstances<IPruning>());
      var splitterSet = new ItemSet<ISplitter>(ApplicationManager.Manager.GetInstances<ISplitter>());
      Parameters.Add(new FixedValueParameter<BoolValue>(GenerateRulesParameterName, "Whether a set of rules or a decision tree shall be created (default=false).", new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<PercentValue>(HoldoutSizeParameterName, "How much of the training set shall be reserved for pruning (default=20%).", new PercentValue(0.2)));
      Parameters.Add(new ConstrainedValueParameter<ISplitter>(SplitterParameterName, "The type of split function used to create node splits (default='Splitter').", splitterSet, splitterSet.OfType<Splitter>().First()));
      Parameters.Add(new FixedValueParameter<IntValue>(MinimalNodeSizeParameterName, "The minimal number of samples in a leaf node (default=1).", new IntValue(1)));
      Parameters.Add(new ConstrainedValueParameter<ILeafModel>(LeafModelParameterName, "The type of model used for the nodes (default='LinearLeaf').", modelSet, modelSet.OfType<LinearLeaf>().First()));
      Parameters.Add(new ConstrainedValueParameter<IPruning>(PruningTypeParameterName, "The type of pruning used (default='ComplexityPruning').", pruningSet, pruningSet.OfType<ComplexityPruning>().First()));
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
      Parameters.Add(new FixedValueParameter<BoolValue>(UseHoldoutParameterName, "True if a holdout set should be generated, false if splitting and pruning shall be performed on the same data (default=false).", new BoolValue(false)));
      Problem = new RegressionProblem();
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new DecisionTreeRegression(this, cloner);
    }
    #endregion

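    // Prepares a run: seeds the Mersenne Twister PRNG, builds the state scope holding everything
    // needed for tree construction, and exposes that scope in the results.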
    protected override void Initialize(CancellationToken cancellationToken) {
      base.Initialize(cancellationToken);
      var random = new MersenneTwister();
      if (SetSeedRandomly) Seed = RandomSeedGenerator.GetSeed();
      random.Reset(Seed);
      stateScope = InitializeScope(random, Problem.ProblemData, Pruning, MinimalNodeSize, LeafModel, Splitter, GenerateRules, UseHoldout, HoldoutSize);
      stateScope.Variables.Add(new Variable("Algorithm", this));
      Results.AddOrUpdateResult("StateScope", stateScope);
    }

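    // Builds the model from the prepared state scope, turns it into a regression solution and
    // adds the analysis results.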
    protected override void Run(CancellationToken cancellationToken) {
      var model = Build(stateScope, Results, cancellationToken);
      AnalyzeSolution(model.CreateRegressionSolution(Problem.ProblemData), Results, Problem.ProblemData);
    }

    #region Static Interface
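    // Convenience entry point for building a decision tree (or rule set) regression solution without
    // instantiating the algorithm; unspecified components default to LinearLeaf, Splitter and
    // ComplexityPruning. Illustrative call (assumes an IRegressionProblemData instance 'problemData'):
    //   var solution = DecisionTreeRegression.CreateRegressionSolution(
    //     problemData, new MersenneTwister(42), generateRules: false, minimumLeafSize: 5);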
    public static IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData, IRandom random, ILeafModel leafModel = null, ISplitter splitter = null, IPruning pruning = null,
      bool useHoldout = false, double holdoutSize = 0.2, int minimumLeafSize = 1, bool generateRules = false, ResultCollection results = null, CancellationToken? cancellationToken = null) {
      if (leafModel == null) leafModel = new LinearLeaf();
      if (splitter == null) splitter = new Splitter();
      if (cancellationToken == null) cancellationToken = CancellationToken.None;
      if (pruning == null) pruning = new ComplexityPruning();

      var stateScope = InitializeScope(random, problemData, pruning, minimumLeafSize, leafModel, splitter, generateRules, useHoldout, holdoutSize);
      var model = Build(stateScope, results, cancellationToken.Value);
      return model.CreateRegressionSolution(problemData);
    }

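    // Updates an existing decision tree / rule set model on the given problem data: a fresh scope with
    // new RegressionTreeParameters is created, the leaf model is initialized on it, and the actual work
    // is delegated to the model's Update method over all training indices.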
    public static void UpdateModel(IDecisionTreeModel model, IRegressionProblemData problemData, IRandom random, ILeafModel leafModel, CancellationToken? cancellationToken = null) {
      if (cancellationToken == null) cancellationToken = CancellationToken.None;
      var regressionTreeParameters = new RegressionTreeParameters(leafModel, problemData, random);
      var scope = new Scope();
      scope.Variables.Add(new Variable(RegressionTreeParameterVariableName, regressionTreeParameters));
      leafModel.Initialize(scope);
      model.Update(problemData.TrainingIndices.ToList(), scope, cancellationToken.Value);
    }
    #endregion

    #region Helpers
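    // Builds the state scope for a run: a reduced copy of the dataset (double-valued input/target
    // columns, training rows only), the shared RegressionTreeParameters, the initialized pruning,
    // splitter and leaf-model operators, the still unbuilt tree or rule-set model, and the
    // training/pruning row index sets.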
    private static IScope InitializeScope(IRandom random, IRegressionProblemData problemData, IPruning pruning, int minLeafSize, ILeafModel leafModel, ISplitter splitter, bool generateRules, bool useHoldout, double holdoutSize) {
      var stateScope = new Scope("RegressionTreeStateScope");

      // reduce RegressionProblemData to AllowedInput & Target column wise and to TrainingSet row wise
      var doubleVars = new HashSet<string>(problemData.Dataset.DoubleVariables);
      var vars = problemData.AllowedInputVariables.Concat(new[] {problemData.TargetVariable}).ToArray();
      if (vars.Any(v => !doubleVars.Contains(v))) throw new NotSupportedException("Decision tree regression supports only double valued input or output features.");
      var doubles = vars.Select(v => problemData.Dataset.GetDoubleValues(v, problemData.TrainingIndices).ToArray()).ToArray();
      if (doubles.Any(v => v.Any(x => double.IsNaN(x) || double.IsInfinity(x))))
        throw new NotSupportedException("Decision tree regression does not support NaN or infinity values in the input dataset.");
      var trainingData = new Dataset(vars, doubles);
      var pd = new RegressionProblemData(trainingData, problemData.AllowedInputVariables, problemData.TargetVariable);
      pd.TrainingPartition.End = pd.TestPartition.Start = pd.TestPartition.End = pd.Dataset.Rows;
      pd.TrainingPartition.Start = 0;

      // store regression tree parameters
      var regressionTreeParams = new RegressionTreeParameters(pruning, minLeafSize, leafModel, pd, random, splitter);
      stateScope.Variables.Add(new Variable(RegressionTreeParameterVariableName, regressionTreeParams));

      // initialize tree operators
      pruning.Initialize(stateScope);
      splitter.Initialize(stateScope);
      leafModel.Initialize(stateScope);

      // store unbuilt model
      IItem model;
      if (generateRules) {
        model = RegressionRuleSetModel.CreateRuleModel(problemData.TargetVariable, regressionTreeParams);
        RegressionRuleSetModel.Initialize(stateScope);
      }
      else {
        model = RegressionNodeTreeModel.CreateTreeModel(problemData.TargetVariable, regressionTreeParams);
      }
      stateScope.Variables.Add(new Variable(ModelVariableName, model));

      // store training & pruning indices
      IReadOnlyList<int> trainingSet, pruningSet;
      GeneratePruningSet(pd.TrainingIndices.ToArray(), random, useHoldout, holdoutSize, out trainingSet, out pruningSet);
      stateScope.Variables.Add(new Variable(TrainingSetVariableName, new IntArray(trainingSet.ToArray())));
      stateScope.Variables.Add(new Variable(PruningSetVariableName, new IntArray(pruningSet.ToArray())));

      return stateScope;
    }

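    // Builds the model stored in the state scope. Degenerate cases are handled up front: an empty
    // training set yields a constant zero model, and a training set smaller than the minimal leaf
    // size yields a constant model predicting the mean target value.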
    private static IRegressionModel Build(IScope stateScope, ResultCollection results, CancellationToken cancellationToken) {
      var regressionTreeParams = (RegressionTreeParameters)stateScope.Variables[RegressionTreeParameterVariableName].Value;
      var model = (IDecisionTreeModel)stateScope.Variables[ModelVariableName].Value;
      var trainingRows = (IntArray)stateScope.Variables[TrainingSetVariableName].Value;
      var pruningRows = (IntArray)stateScope.Variables[PruningSetVariableName].Value;
      if (trainingRows.Length < 1)
        return new PreconstructedLinearModel(new Dictionary<string, double>(), 0, regressionTreeParams.TargetVariable);
      if (regressionTreeParams.MinLeafSize > trainingRows.Length) {
        var targets = regressionTreeParams.Data.GetDoubleValues(regressionTreeParams.TargetVariable).ToArray();
        return new PreconstructedLinearModel(new Dictionary<string, double>(), targets.Average(), regressionTreeParams.TargetVariable);
      }
      model.Build(trainingRows.ToArray(), pruningRows.ToArray(), stateScope, results, cancellationToken);
      return model;
    }

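    // Splits the training rows into a training part and a pruning (holdout) part. Without a holdout
    // both sets refer to the same rows; otherwise a random permutation of the rows is cut at
    // holdoutSize, the first part becoming the pruning set and the remainder the training set.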
    private static void GeneratePruningSet(IReadOnlyList<int> allrows, IRandom random, bool useHoldout, double holdoutSize, out IReadOnlyList<int> training, out IReadOnlyList<int> pruning) {
      if (!useHoldout) {
        training = allrows;
        pruning = allrows;
        return;
      }
      var perm = new Permutation(PermutationTypes.Absolute, allrows.Count, random);
      var cut = (int)(holdoutSize * allrows.Count);
      pruning = perm.Take(cut).Select(i => allrows[i]).ToArray();
      training = perm.Skip(cut).Select(i => allrows[i]).ToArray();
    }

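    // Collects analysis results for the final solution: the solution itself, tree- or rule-set-specific
    // outputs (leaf depth histogram, node analysis, rules and coverage diagram), relative variable
    // frequencies, and a pruning chart when complexity pruning was used on a tree.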
    private void AnalyzeSolution(IRegressionSolution solution, ResultCollection results, IRegressionProblemData problemData) {
      results.Add(new Result("RegressionSolution", (IItem)solution.Clone()));

      Dictionary<string, int> frequencies = null;

      var tree = solution.Model as RegressionNodeTreeModel;
      if (tree != null) {
        results.Add(RegressionTreeAnalyzer.CreateLeafDepthHistogram(tree));
        frequencies = RegressionTreeAnalyzer.GetTreeVariableFrequences(tree);
        RegressionTreeAnalyzer.AnalyzeNodes(tree, results, problemData);
      }

      var ruleSet = solution.Model as RegressionRuleSetModel;
      if (ruleSet != null) {
        results.Add(RegressionTreeAnalyzer.CreateRulesResult(ruleSet, problemData, "Rules", true));
        frequencies = RegressionTreeAnalyzer.GetRuleVariableFrequences(ruleSet);
        results.Add(RegressionTreeAnalyzer.CreateCoverageDiagram(ruleSet, problemData));
      }

      // variable frequencies
      if (frequencies != null) {
        var sum = frequencies.Values.Sum();
        sum = sum == 0 ? 1 : sum;
        var impactArray = new DoubleArray(frequencies.Select(i => (double)i.Value / sum).ToArray()) {
          ElementNames = frequencies.Select(i => i.Key)
        };
        results.Add(new Result("Variable Frequencies", "relative frequencies of variables in rules and tree nodes", impactArray));
      }

      var pruning = Pruning as ComplexityPruning;
      if (pruning != null && tree != null)
        RegressionTreeAnalyzer.PruningChart(tree, pruning, results);
    }
    #endregion
  }
}