Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2847_M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/Pruning/ComplexityPruning.cs @ 16847

Last change on this file since 16847 was 16847, checked in by gkronber, 5 years ago

#2847: made some minor changes while reviewing

File size: 15.3 KB
RevLine 
[15830]1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2017 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Threading;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Parameters;
30using HeuristicLab.Problems.DataAnalysis;
[16847]31using HEAL.Attic;
[15830]32
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Pruning strategy for regression trees. For every inner node a single model is built on the
  /// rows reaching that node; its error on the pruning rows is compared with the combined error of
  /// the node's subtree, corrected by the respective numbers of parameters. Inner nodes whose
  /// resulting threshold is not at least <see cref="PruningStrength"/> are converted to leaves.
  /// </summary>
  /// <remarks>
  /// The work is split into phases which are tracked in <see cref="PruningState.Code"/>, so that a
  /// cancelled (paused) run can be continued by calling <see cref="Prune(RegressionNodeTreeModel, IReadOnlyList{int}, IReadOnlyList{int}, IScope, CancellationToken)"/>
  /// again with the same state scope.
  /// </remarks>
  [StorableType("B643D965-D13F-415D-8589-3F3527460347")]
  public class ComplexityPruning : ParameterizedNamedItem, IPruning {
    /// <summary>Name of the scope variable that holds the <see cref="PruningState"/>.</summary>
    public const string PruningStateVariableName = "PruningState";

    public const string PruningStrengthParameterName = "PruningStrength";
    public const string PruningDecayParameterName = "PruningDecay";
    public const string FastPruningParameterName = "FastPruning";

    public IFixedValueParameter<DoubleValue> PruningStrengthParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[PruningStrengthParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> PruningDecayParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[PruningDecayParameterName]; }
    }
    public IFixedValueParameter<BoolValue> FastPruningParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[FastPruningParameterName]; }
    }

    /// <summary>Threshold that an inner node has to reach in order to be kept (default 2.0).</summary>
    public double PruningStrength {
      get { return PruningStrengthParameter.Value.Value; }
      set { PruningStrengthParameter.Value.Value = value; }
    }
    /// <summary>Exponent applied to the complexity correction of the threshold (default 1.0).</summary>
    public double PruningDecay {
      get { return PruningDecayParameter.Value.Value; }
      set { PruningDecayParameter.Value.Value = value; }
    }
    /// <summary>
    /// If true, a <see cref="LinearLeaf"/> is used for the per-node pruning models instead of the
    /// configured leaf model (default true).
    /// </summary>
    public bool FastPruning {
      get { return FastPruningParameter.Value.Value; }
      set { FastPruningParameter.Value.Value = value; }
    }

    #region Constructors & Cloning
    [StorableConstructor]
    protected ComplexityPruning(StorableConstructorFlag _) : base(_) { }
    protected ComplexityPruning(ComplexityPruning original, Cloner cloner) : base(original, cloner) { }
    public ComplexityPruning() {
      Parameters.Add(new FixedValueParameter<DoubleValue>(PruningStrengthParameterName, "The strength of the pruning. Higher values force the algorithm to create simpler models (default=2.0).", new DoubleValue(2.0)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(PruningDecayParameterName, "Pruning decay allows nodes higher up in the tree to be more stable (default=1.0).", new DoubleValue(1.0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(FastPruningParameterName, "Accelerate pruning by using linear models instead of leaf models (default=true).", new BoolValue(true)));
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new ComplexityPruning(this, cloner);
    }
    #endregion

    #region IPruning
    /// <summary>
    /// Returns the minimal leaf size of the model type that is actually used for pruning, i.e. of a
    /// <see cref="LinearLeaf"/> when <see cref="FastPruning"/> is set and of
    /// <paramref name="leafModel"/> otherwise.
    /// </summary>
    public int MinLeafSize(IRegressionProblemData pd, ILeafModel leafModel) {
      return (FastPruning ? new LinearLeaf() : leafModel).MinLeafSize(pd);
    }

    /// <summary>Adds a fresh <see cref="PruningState"/> variable to <paramref name="states"/>.</summary>
    public void Initialize(IScope states) {
      states.Variables.Add(new Variable(PruningStateVariableName, new PruningState()));
    }

    /// <summary>
    /// Prunes <paramref name="treeModel"/> in place. Phases that were already completed according
    /// to the stored <see cref="PruningState.Code"/> are skipped, which allows resuming.
    /// </summary>
    /// <param name="treeModel">The tree to prune.</param>
    /// <param name="trainingRows">Rows on which the per-node pruning models are built.</param>
    /// <param name="pruningRows">Rows on which the per-node pruning models are evaluated.</param>
    /// <param name="statescope">
    /// Scope that contains the regression tree parameters (variable
    /// <c>M5Regression.RegressionTreeParameterVariableName</c>) and the pruning state (variable
    /// <see cref="PruningStateVariableName"/>).
    /// </param>
    /// <param name="cancellationToken">Token that is checked between and within the phases.</param>
    /// <exception cref="OperationCanceledException">Cancellation was requested.</exception>
    public void Prune(RegressionNodeTreeModel treeModel, IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, IScope statescope, CancellationToken cancellationToken) {
      var regressionTreeParams = (RegressionTreeParameters)statescope.Variables[M5Regression.RegressionTreeParameterVariableName].Value;
      var state = (PruningState)statescope.Variables[PruningStateVariableName].Value;

      var leaf = FastPruning ? new LinearLeaf() : regressionTreeParams.LeafModel;
      if (state.Code <= 1) {
        InstallModels(treeModel, state, trainingRows, pruningRows, leaf, regressionTreeParams, cancellationToken);
        cancellationToken.ThrowIfCancellationRequested();
      }
      if (state.Code <= 2) {
        AssignPruningThresholds(treeModel, state, PruningDecay);
        cancellationToken.ThrowIfCancellationRequested();
      }
      if (state.Code <= 3) {
        UpdateThreshold(treeModel, state);
        cancellationToken.ThrowIfCancellationRequested();
      }
      if (state.Code <= 4) {
        Prune(treeModel, state, PruningStrength);
        cancellationToken.ThrowIfCancellationRequested();
      }

      state.Code = 5;
    }
    #endregion

    /// <summary>
    /// Phase 1: builds a pruning model for every node (children before parents) and records its
    /// error and complexity in <paramref name="state"/>.
    /// </summary>
    private static void InstallModels(RegressionNodeTreeModel tree, PruningState state, IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, ILeafModel leaf, RegressionTreeParameters regressionTreeParams, CancellationToken cancellationToken) {
      if (state.Code == 0) {
        // FillBottomUp declares pruningRows BEFORE trainingRows. Named arguments are used so that
        // the two row lists cannot be swapped; a positional call in the order of this method's own
        // parameters would train the node models on the pruning rows and evaluate them on the
        // training rows.
        state.FillBottomUp(tree, pruningRows: pruningRows, trainingRows: trainingRows, data: regressionTreeParams.Data);
        state.Code = 1;
      }
      while (state.nodeQueue.Count != 0) {
        cancellationToken.ThrowIfCancellationRequested();
        // Only peek here: the entries are removed after the model was built successfully, so a
        // cancellation inside BuildPruningModel leaves the node queued for the next call.
        var n = state.nodeQueue.Peek();
        var training = state.trainingRowsQueue.Peek();
        var pruning = state.pruningRowsQueue.Peek();
        BuildPruningModel(n, leaf, training, pruning, state, regressionTreeParams, cancellationToken);
        state.nodeQueue.Dequeue();
        state.trainingRowsQueue.Dequeue();
        state.pruningRowsQueue.Dequeue();
      }
    }

    /// <summary>
    /// Phase 2: assigns every inner node its pruning threshold (see <see cref="PruningThreshold"/>).
    /// </summary>
    private static void AssignPruningThresholds(RegressionNodeTreeModel tree, PruningState state, double pruningDecay) {
      if (state.Code == 1) {
        state.FillBottomUp(tree);
        state.Code = 2;
      }
      while (state.nodeQueue.Count != 0) {
        var n = state.nodeQueue.Dequeue();
        if (n.IsLeaf) continue;
        var subtreeError = SubtreeError(n, state.pruningSizes, state.modelErrors);
        n.PruningStrength = PruningThreshold(state.pruningSizes[n], state.modelComplexities[n], state.nodeComplexities[n], state.modelErrors[n], subtreeError, pruningDecay);
      }
    }

    /// <summary>
    /// Phase 3: walks the tree top-down and limits the threshold of every inner node to the
    /// threshold of its parent, so that no node has a larger threshold than its ancestors.
    /// </summary>
    private static void UpdateThreshold(RegressionNodeTreeModel tree, PruningState state) {
      if (state.Code == 2) {
        state.FillTopDown(tree);
        state.Code = 3;
      }
      while (state.nodeQueue.Count != 0) {
        var n = state.nodeQueue.Dequeue();
        if (n.IsLeaf || n.Parent == null || double.IsNaN(n.Parent.PruningStrength)) continue;
        n.PruningStrength = Math.Min(n.PruningStrength, n.Parent.PruningStrength);
      }
    }

    /// <summary>
    /// Phase 4: converts every inner node to a leaf unless
    /// <paramref name="pruningStrength"/> &lt;= node threshold holds.
    /// </summary>
    private static void Prune(RegressionNodeTreeModel tree, PruningState state, double pruningStrength) {
      if (state.Code == 3) {
        state.FillTopDown(tree);
        state.Code = 4;
      }
      while (state.nodeQueue.Count != 0) {
        var n = state.nodeQueue.Dequeue();
        // The comparison is false for a NaN threshold, i.e. such nodes are converted as well.
        if (n.IsLeaf || pruningStrength <= n.PruningStrength) continue;
        n.ToLeaf();
      }
    }

    /// <summary>
    /// Builds a model of type <paramref name="leaf"/> on <paramref name="trainingRows"/>, evaluates
    /// it on <paramref name="pruningRows"/> and stores pruning size, model error, model complexity
    /// and node complexity for <paramref name="regressionNode"/> in <paramref name="state"/>.
    /// </summary>
    /// <remarks>
    /// For inner nodes the node complexities of both children are read from
    /// <paramref name="state"/>, therefore children have to be processed before their parent.
    /// </remarks>
    private static void BuildPruningModel(RegressionNodeModel regressionNode, ILeafModel leaf, IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, PruningState state, RegressionTreeParameters regressionTreeParams, CancellationToken cancellationToken) {
      //create regressionProblemdata from pruning data
      var vars = regressionTreeParams.AllowedInputVariables.Concat(new[] { regressionTreeParams.TargetVariable }).ToArray();
      var reducedData = new Dataset(vars, vars.Select(x => regressionTreeParams.Data.GetDoubleValues(x, pruningRows).ToList()));
      var pd = new RegressionProblemData(reducedData, regressionTreeParams.AllowedInputVariables, regressionTreeParams.TargetVariable);
      // empty training partition, all rows of the reduced dataset form the test partition
      pd.TrainingPartition.Start = pd.TrainingPartition.End = pd.TestPartition.Start = 0;
      pd.TestPartition.End = reducedData.Rows;

      //build pruning model
      int numModelParams;
      var model = leaf.BuildModel(trainingRows, regressionTreeParams, cancellationToken, out numModelParams);

      //record error and complexities
      var rmsModel = model.CreateRegressionSolution(pd).TestRootMeanSquaredError;
      state.pruningSizes.Add(regressionNode, pruningRows.Count);
      state.modelErrors.Add(regressionNode, rmsModel);
      state.modelComplexities.Add(regressionNode, numModelParams);
      if (regressionNode.IsLeaf) {
        state.nodeComplexities[regressionNode] = state.modelComplexities[regressionNode];
      } else {
        state.nodeComplexities.Add(regressionNode, state.nodeComplexities[regressionNode.Left] + state.nodeComplexities[regressionNode.Right] + 1);
      }
    }

    /// <summary>
    /// Calculates <c>(modelError / nodeError) / ((nodeParams + n) / (2 * (modelParams + n)))^w</c>
    /// with <c>n</c> = <paramref name="noInstances"/>. The error ratio is taken as 1 when both
    /// errors are almost equal, which also covers the case that both are zero.
    /// </summary>
    private static double PruningThreshold(double noInstances, double modelParams, double nodeParams, double modelError, double nodeError, double w) {
      var res = modelError / nodeError;
      if (modelError.IsAlmost(nodeError)) res = 1.0;
      res /= Math.Pow((nodeParams + noInstances) / (2 * (modelParams + noInstances)), w);
      return res;
    }

    /// <summary>
    /// Returns the recorded model error for a leaf. For an inner node the squared subtree errors of
    /// both children are weighted with the children's pruning sizes, summed up, divided by the
    /// pruning size of the node and the square root of the result is returned.
    /// </summary>
    private static double SubtreeError(RegressionNodeModel regressionNode, IDictionary<RegressionNodeModel, int> pruningSizes,
      IDictionary<RegressionNodeModel, double> modelErrors) {
      if (regressionNode.IsLeaf) return modelErrors[regressionNode];
      var errorL = SubtreeError(regressionNode.Left, pruningSizes, modelErrors);
      var errorR = SubtreeError(regressionNode.Right, pruningSizes, modelErrors);
      errorL = errorL * errorL * pruningSizes[regressionNode.Left];
      errorR = errorR * errorR * pruningSizes[regressionNode.Right];
      // NOTE(review): a pruning size of 0 makes this division yield NaN or infinity - confirm
      // whether nodes without pruning rows can occur.
      return Math.Sqrt((errorR + errorL) / pruningSizes[regressionNode]);
    }

    /// <summary>
    /// Intermediate results and work queues of <see cref="ComplexityPruning"/>. All members are
    /// storable so that the pruning can be continued from the persisted state.
    /// </summary>
    [StorableType("EAD60C7E-2C58-45C4-9697-6F735F518CFD")]
    public class PruningState : Item {
      // number of parameters of the pruning model built for a node
      [Storable]
      public IDictionary<RegressionNodeModel, int> modelComplexities = new Dictionary<RegressionNodeModel, int>();
      // leaf: model complexity; inner node: sum of both children's node complexities + 1
      [Storable]
      public IDictionary<RegressionNodeModel, int> nodeComplexities = new Dictionary<RegressionNodeModel, int>();
      // number of pruning rows that reach a node
      [Storable]
      public IDictionary<RegressionNodeModel, int> pruningSizes = new Dictionary<RegressionNodeModel, int>();
      // root mean squared error of a node's pruning model on the node's pruning rows
      [Storable]
      public IDictionary<RegressionNodeModel, double> modelErrors = new Dictionary<RegressionNodeModel, double>();

      // nodes that still have to be processed in the current phase
      [Storable]
      public Queue<RegressionNodeModel> nodeQueue = new Queue<RegressionNodeModel>();
      // row lists that belong to the entries of nodeQueue (same order); only used in phase 1
      [Storable]
      public Queue<IReadOnlyList<int>> trainingRowsQueue = new Queue<IReadOnlyList<int>>();
      [Storable]
      public Queue<IReadOnlyList<int>> pruningRowsQueue = new Queue<IReadOnlyList<int>>();

      //State.Code values denote the current action (for pausing)
      //0...nothing has been done;
      //1...building Models;
      //2...assigning threshold
      //3...adjusting threshold
      //4...pruning
      //5...finished
      [Storable]
      public int Code = 0;

      #region HLConstructors & Cloning
      [StorableConstructor]
      protected PruningState(StorableConstructorFlag _) : base(_) { }
      protected PruningState(PruningState original, Cloner cloner) : base(original, cloner) {
        modelComplexities = original.modelComplexities.ToDictionary(x => cloner.Clone(x.Key), x => x.Value);
        nodeComplexities = original.nodeComplexities.ToDictionary(x => cloner.Clone(x.Key), x => x.Value);
        pruningSizes = original.pruningSizes.ToDictionary(x => cloner.Clone(x.Key), x => x.Value);
        modelErrors = original.modelErrors.ToDictionary(x => cloner.Clone(x.Key), x => x.Value);
        nodeQueue = new Queue<RegressionNodeModel>(original.nodeQueue.Select(cloner.Clone));
        // the row lists are copied into new arrays
        trainingRowsQueue = new Queue<IReadOnlyList<int>>(original.trainingRowsQueue.Select(x => (IReadOnlyList<int>)x.ToArray()));
        pruningRowsQueue = new Queue<IReadOnlyList<int>>(original.pruningRowsQueue.Select(x => (IReadOnlyList<int>)x.ToArray()));
        Code = original.Code;
      }
      public PruningState() { }
      public override IDeepCloneable Clone(Cloner cloner) {
        return new PruningState(this, cloner);
      }
      #endregion

      /// <summary>
      /// Replaces the content of <see cref="nodeQueue"/> with all nodes of
      /// <paramref name="tree"/> in breadth-first order, starting with the root.
      /// </summary>
      public void FillTopDown(RegressionNodeTreeModel tree) {
        var helperQueue = new Queue<RegressionNodeModel>();
        nodeQueue.Clear();

        helperQueue.Enqueue(tree.Root);
        nodeQueue.Enqueue(tree.Root);

        while (helperQueue.Count != 0) {
          var n = helperQueue.Dequeue();
          if (n.IsLeaf) continue;
          helperQueue.Enqueue(n.Left);
          helperQueue.Enqueue(n.Right);
          nodeQueue.Enqueue(n.Left);
          nodeQueue.Enqueue(n.Right);
        }
      }

      /// <summary>
      /// Replaces the content of <see cref="nodeQueue"/>, <see cref="trainingRowsQueue"/> and
      /// <see cref="pruningRowsQueue"/> with all nodes of <paramref name="tree"/> in breadth-first
      /// order together with the rows assigned to each node. The rows of an inner node are divided
      /// among its children with <c>RegressionTreeUtilities.SplitRows</c>.
      /// </summary>
      /// <remarks>
      /// Note the parameter order: <paramref name="pruningRows"/> precedes
      /// <paramref name="trainingRows"/>.
      /// </remarks>
      public void FillTopDown(RegressionNodeTreeModel tree, IReadOnlyList<int> pruningRows, IReadOnlyList<int> trainingRows, IDataset data) {
        var helperQueue = new Queue<RegressionNodeModel>();
        var trainingHelperQueue = new Queue<IReadOnlyList<int>>();
        var pruningHelperQueue = new Queue<IReadOnlyList<int>>();
        nodeQueue.Clear();
        trainingRowsQueue.Clear();
        pruningRowsQueue.Clear();

        helperQueue.Enqueue(tree.Root);

        trainingHelperQueue.Enqueue(trainingRows);
        pruningHelperQueue.Enqueue(pruningRows);

        nodeQueue.Enqueue(tree.Root);
        trainingRowsQueue.Enqueue(trainingRows);
        pruningRowsQueue.Enqueue(pruningRows);


        while (helperQueue.Count != 0) {
          var n = helperQueue.Dequeue();
          var p = pruningHelperQueue.Dequeue();
          var t = trainingHelperQueue.Dequeue();
          if (n.IsLeaf) continue;

          IReadOnlyList<int> leftPruning, rightPruning;
          RegressionTreeUtilities.SplitRows(p, data, n.SplitAttribute, n.SplitValue, out leftPruning, out rightPruning);
          IReadOnlyList<int> leftTraining, rightTraining;
          RegressionTreeUtilities.SplitRows(t, data, n.SplitAttribute, n.SplitValue, out leftTraining, out rightTraining);

          helperQueue.Enqueue(n.Left);
          helperQueue.Enqueue(n.Right);
          trainingHelperQueue.Enqueue(leftTraining);
          trainingHelperQueue.Enqueue(rightTraining);
          pruningHelperQueue.Enqueue(leftPruning);
          pruningHelperQueue.Enqueue(rightPruning);

          nodeQueue.Enqueue(n.Left);
          nodeQueue.Enqueue(n.Right);
          trainingRowsQueue.Enqueue(leftTraining);
          trainingRowsQueue.Enqueue(rightTraining);
          pruningRowsQueue.Enqueue(leftPruning);
          pruningRowsQueue.Enqueue(rightPruning);
        }
      }

      /// <summary>
      /// Fills <see cref="nodeQueue"/> with all nodes of <paramref name="tree"/> in reversed
      /// breadth-first order, so that every node is preceded by its children.
      /// </summary>
      public void FillBottomUp(RegressionNodeTreeModel tree) {
        FillTopDown(tree);
        nodeQueue = new Queue<RegressionNodeModel>(nodeQueue.Reverse());
      }

      /// <summary>
      /// Fills the node queue and both row queues in reversed breadth-first order, so that every
      /// node is preceded by its children.
      /// </summary>
      /// <remarks>
      /// Note the parameter order: <paramref name="pruningRows"/> precedes
      /// <paramref name="trainingRows"/>.
      /// </remarks>
      public void FillBottomUp(RegressionNodeTreeModel tree, IReadOnlyList<int> pruningRows, IReadOnlyList<int> trainingRows, IDataset data) {
        FillTopDown(tree, pruningRows, trainingRows, data);
        nodeQueue = new Queue<RegressionNodeModel>(nodeQueue.Reverse());
        pruningRowsQueue = new Queue<IReadOnlyList<int>>(pruningRowsQueue.Reverse());
        trainingRowsQueue = new Queue<IReadOnlyList<int>>(trainingRowsQueue.Reverse());
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.