source: branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/Pruning/ComplexityPruning.cs @ 15830

Last change on this file since 15830 was 15830, checked in by bwerth, 16 months ago

#2847 adapted project to new rep structure; major changes to interfaces; restructures splitting and pruning

File size: 15.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2017 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Threading;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using HeuristicLab.Problems.DataAnalysis;
32
33namespace HeuristicLab.Algorithms.DataAnalysis {
34  public class ComplexityPruning : ParameterizedNamedItem, IPruning {
35    public const string PruningStateVariableName = "PruningState";
36
37    private const string PruningStrengthParameterName = "PruningStrength";
38    private const string PruningDecayParameterName = "PruningDecay";
39    private const string FastPruningParameterName = "FastPruning";
40
41    public IFixedValueParameter<DoubleValue> PruningStrengthParameter {
42      get { return (IFixedValueParameter<DoubleValue>)Parameters[PruningStrengthParameterName]; }
43    }
44    public IFixedValueParameter<DoubleValue> PruningDecayParameter {
45      get { return (IFixedValueParameter<DoubleValue>)Parameters[PruningDecayParameterName]; }
46    }
47    public IFixedValueParameter<BoolValue> FastPruningParameter {
48      get { return (IFixedValueParameter<BoolValue>)Parameters[FastPruningParameterName]; }
49    }
50
51    public double PruningStrength {
52      get { return PruningStrengthParameter.Value.Value; }
53    }
54    public double PruningDecay {
55      get { return PruningDecayParameter.Value.Value; }
56    }
57    public bool FastPruning {
58      get { return FastPruningParameter.Value.Value; }
59    }
60
61    #region Constructors & Cloning
62    [StorableConstructor]
63    protected ComplexityPruning(bool deserializing) : base(deserializing) { }
64    protected ComplexityPruning(ComplexityPruning original, Cloner cloner) : base(original, cloner) { }
65    public ComplexityPruning() {
66      Parameters.Add(new FixedValueParameter<DoubleValue>(PruningStrengthParameterName, "The strength of the pruning. Higher values force the algorithm to create simpler models", new DoubleValue(2.0)));
67      Parameters.Add(new FixedValueParameter<DoubleValue>(PruningDecayParameterName, "Pruning decay allows nodes higher up in the tree to be more stable.", new DoubleValue(1.0)));
68      Parameters.Add(new FixedValueParameter<BoolValue>(FastPruningParameterName, "Accelerate Pruning by using linear models instead of leaf models", new BoolValue(true)));
69    }
70    public override IDeepCloneable Clone(Cloner cloner) {
71      return new ComplexityPruning(this, cloner);
72    }
73    #endregion
74
75    #region IPruning
76    public int MinLeafSize(IRegressionProblemData pd, ILeafModel leafModel) {
77      return (FastPruning ? new LinearLeaf() : leafModel).MinLeafSize(pd);
78    }
79    public void Initialize(IScope states) {
80      states.Variables.Add(new Variable(PruningStateVariableName, new PruningState()));
81    }
82
83    public void Prune(RegressionNodeTreeModel treeModel, IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, IScope statescope, CancellationToken cancellationToken) {
84      var regressionTreeParams = (RegressionTreeParameters)statescope.Variables[M5Regression.RegressionTreeParameterVariableName].Value;
85      var state = (PruningState)statescope.Variables[PruningStateVariableName].Value;
86
87      var leaf = FastPruning ? new LinearLeaf() : regressionTreeParams.LeafModel;
88      if (state.Code <= 1) {
89        InstallModels(treeModel, state, trainingRows, pruningRows, leaf, regressionTreeParams, cancellationToken);
90        cancellationToken.ThrowIfCancellationRequested();
91      }
92      if (state.Code <= 2) {
93        AssignPruningThresholds(treeModel, state, PruningDecay);
94        cancellationToken.ThrowIfCancellationRequested();
95      }
96      if (state.Code <= 3) {
97        UpdateThreshold(treeModel, state);
98        cancellationToken.ThrowIfCancellationRequested();
99      }
100      if (state.Code <= 4) {
101        Prune(treeModel, state, PruningStrength);
102        cancellationToken.ThrowIfCancellationRequested();
103      }
104
105      state.Code = 5;
106    }
107    #endregion
108
109    private static void InstallModels(RegressionNodeTreeModel tree, PruningState state, IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, ILeafModel leaf, RegressionTreeParameters regressionTreeParams, CancellationToken cancellationToken) {
110      if (state.Code == 0) {
111        state.FillBottomUp(tree, trainingRows, pruningRows, regressionTreeParams.Data);
112        state.Code = 1;
113      }
114      while (state.nodeQueue.Count != 0) {
115        cancellationToken.ThrowIfCancellationRequested();
116        var n = state.nodeQueue.Peek();
117        var training = state.trainingRowsQueue.Peek();
118        var pruning = state.pruningRowsQueue.Peek();
119        BuildPruningModel(n, leaf, training, pruning, state, regressionTreeParams, cancellationToken);
120        state.nodeQueue.Dequeue();
121        state.trainingRowsQueue.Dequeue();
122        state.pruningRowsQueue.Dequeue();
123      }
124    }
125
126    private static void AssignPruningThresholds(RegressionNodeTreeModel tree, PruningState state, double pruningDecay) {
127      if (state.Code == 1) {
128        state.FillBottomUp(tree);
129        state.Code = 2;
130      }
131      while (state.nodeQueue.Count != 0) {
132        var n = state.nodeQueue.Dequeue();
133        if (n.IsLeaf) continue;
134        n.PruningStrength = PruningThreshold(state.pruningSizes[n], state.modelComplexities[n], state.nodeComplexities[n], state.modelErrors[n], SubtreeError(n, state.pruningSizes, state.modelErrors), pruningDecay);
135      }
136    }
137
138
139    private static void UpdateThreshold(RegressionNodeTreeModel tree, PruningState state) {
140      if (state.Code == 2) {
141        state.FillTopDown(tree);
142        state.Code = 3;
143      }
144      while (state.nodeQueue.Count != 0) {
145        var n = state.nodeQueue.Dequeue();
146        if (n.IsLeaf || n.Parent == null || double.IsNaN(n.Parent.PruningStrength)) continue;
147        n.PruningStrength = Math.Min(n.PruningStrength, n.Parent.PruningStrength);
148      }
149    }
150
151    private static void Prune(RegressionNodeTreeModel tree, PruningState state, double pruningStrength) {
152      if (state.Code == 3) {
153        state.FillTopDown(tree);
154        state.Code = 4;
155      }
156      while (state.nodeQueue.Count != 0) {
157        var n = state.nodeQueue.Dequeue();
158        if (n.IsLeaf || pruningStrength <= n.PruningStrength) continue;
159        n.ToLeaf();
160      }
161    }
162
163
164    private static void BuildPruningModel(RegressionNodeModel regressionNode, ILeafModel leaf, IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, PruningState state, RegressionTreeParameters regressionTreeParams, CancellationToken cancellationToken) {
165      //create regressionProblemdata from pruning data
166      var vars = regressionTreeParams.AllowedInputVariables.Concat(new[] {regressionTreeParams.TargetVariable}).ToArray();
167      var reducedData = new Dataset(vars, vars.Select(x => regressionTreeParams.Data.GetDoubleValues(x, pruningRows).ToList()));
168      var pd = new RegressionProblemData(reducedData, regressionTreeParams.AllowedInputVariables, regressionTreeParams.TargetVariable);
169      pd.TrainingPartition.Start = pd.TrainingPartition.End = pd.TestPartition.Start = 0;
170      pd.TestPartition.End = reducedData.Rows;
171
172      //build pruning model
173      int numModelParams;
174      var model = leaf.BuildModel(trainingRows, regressionTreeParams, cancellationToken, out numModelParams);
175
176      //record error and complexities
177      var rmsModel = model.CreateRegressionSolution(pd).TestRootMeanSquaredError;
178      state.pruningSizes.Add(regressionNode, pruningRows.Count);
179      state.modelErrors.Add(regressionNode, rmsModel);
180      state.modelComplexities.Add(regressionNode, numModelParams);
181      if (regressionNode.IsLeaf) { state.nodeComplexities[regressionNode] = state.modelComplexities[regressionNode]; }
182      else { state.nodeComplexities.Add(regressionNode, state.nodeComplexities[regressionNode.Left] + state.nodeComplexities[regressionNode.Right] + 1); }
183    }
184
185    private static double PruningThreshold(double noIntances, double modelParams, double nodeParams, double modelError, double nodeError, double w) {
186      var res = modelError / nodeError;
187      if (modelError.IsAlmost(nodeError)) res = 1.0;
188      res /= Math.Pow((nodeParams + noIntances) / (2 * (modelParams + noIntances)), w);
189      return res;
190    }
191
192    private static double SubtreeError(RegressionNodeModel regressionNode, IDictionary<RegressionNodeModel, int> pruningSizes,
193      IDictionary<RegressionNodeModel, double> modelErrors) {
194      if (regressionNode.IsLeaf) return modelErrors[regressionNode];
195      var errorL = SubtreeError(regressionNode.Left, pruningSizes, modelErrors);
196      var errorR = SubtreeError(regressionNode.Right, pruningSizes, modelErrors);
197      errorL = errorL * errorL * pruningSizes[regressionNode.Left];
198      errorR = errorR * errorR * pruningSizes[regressionNode.Right];
199      return Math.Sqrt((errorR + errorL) / pruningSizes[regressionNode]);
200    }
201
202    [StorableClass]
203    public class PruningState : Item {
204      [Storable]
205      public IDictionary<RegressionNodeModel, int> modelComplexities = new Dictionary<RegressionNodeModel, int>();
206      [Storable]
207      public IDictionary<RegressionNodeModel, int> nodeComplexities = new Dictionary<RegressionNodeModel, int>();
208      [Storable]
209      public IDictionary<RegressionNodeModel, int> pruningSizes = new Dictionary<RegressionNodeModel, int>();
210      [Storable]
211      public IDictionary<RegressionNodeModel, double> modelErrors = new Dictionary<RegressionNodeModel, double>();
212
213      [Storable]
214      public Queue<RegressionNodeModel> nodeQueue = new Queue<RegressionNodeModel>();
215      [Storable]
216      public Queue<IReadOnlyList<int>> trainingRowsQueue = new Queue<IReadOnlyList<int>>();
217      [Storable]
218      public Queue<IReadOnlyList<int>> pruningRowsQueue = new Queue<IReadOnlyList<int>>();
219
220      //State.Code values denote the current action (for pausing)
221      //0...nothing has been done;
222      //1...building Models;
223      //2...assigning threshold
224      //3...adjusting threshold
225      //4...pruning
226      //5...finished
227      [Storable]
228      public int Code = 0;
229
230      #region HLConstructors & Cloning
231      [StorableConstructor]
232      protected PruningState(bool deserializing) : base(deserializing) { }
233      protected PruningState(PruningState original, Cloner cloner) : base(original, cloner) {
234        modelComplexities = original.modelComplexities.ToDictionary(x => cloner.Clone(x.Key), x => x.Value);
235        nodeComplexities = original.nodeComplexities.ToDictionary(x => cloner.Clone(x.Key), x => x.Value);
236        pruningSizes = original.pruningSizes.ToDictionary(x => cloner.Clone(x.Key), x => x.Value);
237        modelErrors = original.modelErrors.ToDictionary(x => cloner.Clone(x.Key), x => x.Value);
238        nodeQueue = new Queue<RegressionNodeModel>(original.nodeQueue.Select(cloner.Clone));
239        trainingRowsQueue = new Queue<IReadOnlyList<int>>(original.trainingRowsQueue.Select(x => (IReadOnlyList<int>)x.ToArray()));
240        pruningRowsQueue = new Queue<IReadOnlyList<int>>(original.pruningRowsQueue.Select(x => (IReadOnlyList<int>)x.ToArray()));
241        Code = original.Code;
242      }
243      public PruningState() { }
244      public override IDeepCloneable Clone(Cloner cloner) {
245        return new PruningState(this, cloner);
246      }
247      #endregion
248
249      public void FillTopDown(RegressionNodeTreeModel tree) {
250        var helperQueue = new Queue<RegressionNodeModel>();
251        nodeQueue.Clear();
252
253        helperQueue.Enqueue(tree.Root);
254        nodeQueue.Enqueue(tree.Root);
255
256        while (helperQueue.Count != 0) {
257          var n = helperQueue.Dequeue();
258          if (n.IsLeaf) continue;
259          helperQueue.Enqueue(n.Left);
260          helperQueue.Enqueue(n.Right);
261          nodeQueue.Enqueue(n.Left);
262          nodeQueue.Enqueue(n.Right);
263        }
264      }
265
266      public void FillTopDown(RegressionNodeTreeModel tree, IReadOnlyList<int> pruningRows, IReadOnlyList<int> trainingRows, IDataset data) {
267        var helperQueue = new Queue<RegressionNodeModel>();
268        var trainingHelperQueue = new Queue<IReadOnlyList<int>>();
269        var pruningHelperQueue = new Queue<IReadOnlyList<int>>();
270        nodeQueue.Clear();
271        trainingRowsQueue.Clear();
272        pruningRowsQueue.Clear();
273
274        helperQueue.Enqueue(tree.Root);
275
276        trainingHelperQueue.Enqueue(trainingRows);
277        pruningHelperQueue.Enqueue(pruningRows);
278
279        nodeQueue.Enqueue(tree.Root);
280        trainingRowsQueue.Enqueue(trainingRows);
281        pruningRowsQueue.Enqueue(pruningRows);
282
283
284        while (helperQueue.Count != 0) {
285          var n = helperQueue.Dequeue();
286          var p = pruningHelperQueue.Dequeue();
287          var t = trainingHelperQueue.Dequeue();
288          if (n.IsLeaf) continue;
289
290          IReadOnlyList<int> leftPruning, rightPruning;
291          RegressionTreeUtilities.SplitRows(p, data, n.SplitAttribute, n.SplitValue, out leftPruning, out rightPruning);
292          IReadOnlyList<int> leftTraining, rightTraining;
293          RegressionTreeUtilities.SplitRows(t, data, n.SplitAttribute, n.SplitValue, out leftTraining, out rightTraining);
294
295          helperQueue.Enqueue(n.Left);
296          helperQueue.Enqueue(n.Right);
297          trainingHelperQueue.Enqueue(leftTraining);
298          trainingHelperQueue.Enqueue(rightTraining);
299          pruningHelperQueue.Enqueue(leftPruning);
300          pruningHelperQueue.Enqueue(rightPruning);
301
302          nodeQueue.Enqueue(n.Left);
303          nodeQueue.Enqueue(n.Right);
304          trainingRowsQueue.Enqueue(leftTraining);
305          trainingRowsQueue.Enqueue(rightTraining);
306          pruningRowsQueue.Enqueue(leftPruning);
307          pruningRowsQueue.Enqueue(rightPruning);
308        }
309      }
310
311      public void FillBottomUp(RegressionNodeTreeModel tree) {
312        FillTopDown(tree);
313        nodeQueue = new Queue<RegressionNodeModel>(nodeQueue.Reverse());
314      }
315
316      public void FillBottomUp(RegressionNodeTreeModel tree, IReadOnlyList<int> pruningRows, IReadOnlyList<int> trainingRows, IDataset data) {
317        FillTopDown(tree, pruningRows, trainingRows, data);
318        nodeQueue = new Queue<RegressionNodeModel>(nodeQueue.Reverse());
319        pruningRowsQueue = new Queue<IReadOnlyList<int>>(pruningRowsQueue.Reverse());
320        trainingRowsQueue = new Queue<IReadOnlyList<int>>(trainingRowsQueue.Reverse());
321      }
322    }
323  }
324}
Note: See TracBrowser for help on using the repository browser.