
source: stable/HeuristicLab.Algorithms.DataAnalysis.DecisionTrees/3.4/Pruning/ComplexityPruning.cs @ 17440

Last change on this file since 17440 was 17181, checked in by swagner, 5 years ago

#2875: Merged r17180 from trunk to stable

File size: 16.0 KB
#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Parameters;
using HeuristicLab.Problems.DataAnalysis;
using HEAL.Attic;

namespace HeuristicLab.Algorithms.DataAnalysis {
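  // ComplexityPruning implements a bottom-up, complexity-aware pruning strategy for
  // RegressionNodeTreeModel instances: every inner node gets a pruning threshold that trades the
  // error of its own pruning model against the error and parameter count of its subtree, and
  // nodes whose threshold falls below the configured PruningStrength are collapsed into leaves.
  //
  // A minimal configuration sketch (values are illustrative; wiring the instance into a
  // tree-learning algorithm is not shown in this file):
  //
  //   var pruning = new ComplexityPruning {
  //     PruningStrength = 3.0,  // prune more aggressively than the default of 2.0
  //     PruningDecay = 1.0,
  //     FastPruning = true      // use LinearLeaf models while pruning
  //   };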
  [StorableType("B643D965-D13F-415D-8589-3F3527460347")]
  public class ComplexityPruning : ParameterizedNamedItem, IPruning {
    public const string PruningStateVariableName = "PruningState";

    public const string PruningStrengthParameterName = "PruningStrength";
    public const string PruningDecayParameterName = "PruningDecay";
    public const string FastPruningParameterName = "FastPruning";

    public IFixedValueParameter<DoubleValue> PruningStrengthParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[PruningStrengthParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> PruningDecayParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[PruningDecayParameterName]; }
    }
    public IFixedValueParameter<BoolValue> FastPruningParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[FastPruningParameterName]; }
    }

    public double PruningStrength {
      get { return PruningStrengthParameter.Value.Value; }
      set { PruningStrengthParameter.Value.Value = value; }
    }
    public double PruningDecay {
      get { return PruningDecayParameter.Value.Value; }
      set { PruningDecayParameter.Value.Value = value; }
    }
    public bool FastPruning {
      get { return FastPruningParameter.Value.Value; }
      set { FastPruningParameter.Value.Value = value; }
    }

    #region Constructors & Cloning
    [StorableConstructor]
    protected ComplexityPruning(StorableConstructorFlag _) : base(_) { }
    protected ComplexityPruning(ComplexityPruning original, Cloner cloner) : base(original, cloner) { }
    public ComplexityPruning() {
      Parameters.Add(new FixedValueParameter<DoubleValue>(PruningStrengthParameterName, "The strength of the pruning. Higher values force the algorithm to create simpler models (default=2.0).", new DoubleValue(2.0)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(PruningDecayParameterName, "Pruning decay allows nodes higher up in the tree to be more stable (default=1.0).", new DoubleValue(1.0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(FastPruningParameterName, "Accelerate pruning by using linear models instead of leaf models (default=true).", new BoolValue(true)));
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new ComplexityPruning(this, cloner);
    }
    #endregion

    #region IPruning
    public int MinLeafSize(IRegressionProblemData pd, ILeafModel leafModel) {
      return (FastPruning ? new LinearLeaf() : leafModel).MinLeafSize(pd);
    }
    public void Initialize(IScope states) {
      states.Variables.Add(new Variable(PruningStateVariableName, new PruningState()));
    }

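    // Pruning runs as four resumable phases; PruningState.Code records how far the previous call
    // got so that a paused run can continue where it left off:
    //   1) InstallModels           - build a pruning model for every node, record error/complexity
    //   2) AssignPruningThresholds - compute a per-node threshold from error ratio and parameters
    //   3) UpdateThreshold         - cap each node's threshold by its parent's value (top-down)
    //   4) Prune                   - turn nodes whose threshold is below PruningStrength into leaves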
    public void Prune(RegressionNodeTreeModel treeModel, IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, IScope statescope, CancellationToken cancellationToken) {
      var regressionTreeParams = (RegressionTreeParameters)statescope.Variables[DecisionTreeRegression.RegressionTreeParameterVariableName].Value;
      var state = (PruningState)statescope.Variables[PruningStateVariableName].Value;

      var leaf = FastPruning ? new LinearLeaf() : regressionTreeParams.LeafModel;
      if (state.Code <= 1) {
        InstallModels(treeModel, state, trainingRows, pruningRows, leaf, regressionTreeParams, cancellationToken);
        cancellationToken.ThrowIfCancellationRequested();
      }
      if (state.Code <= 2) {
        AssignPruningThresholds(treeModel, state, PruningDecay);
        cancellationToken.ThrowIfCancellationRequested();
      }
      if (state.Code <= 3) {
        UpdateThreshold(treeModel, state);
        cancellationToken.ThrowIfCancellationRequested();
      }
      if (state.Code <= 4) {
        Prune(treeModel, state, PruningStrength);
        cancellationToken.ThrowIfCancellationRequested();
      }

      state.Code = 5;
    }
    #endregion

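    // Phase 1: enumerate the tree bottom-up and fit a pruning model for every node, recording its
    // error and its model/subtree complexities. The node and row queues live in PruningState, so
    // the loop can be interrupted by cancellation and resumed later without losing progress.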
    private static void InstallModels(RegressionNodeTreeModel tree, PruningState state, IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, ILeafModel leaf, RegressionTreeParameters regressionTreeParams, CancellationToken cancellationToken) {
      if (state.Code == 0) {
        state.FillBottomUp(tree, trainingRows, pruningRows, regressionTreeParams.Data);
        state.Code = 1;
      }
      while (state.nodeQueue.Count != 0) {
        cancellationToken.ThrowIfCancellationRequested();
        var n = state.nodeQueue.Peek();
        var training = state.trainingRowsQueue.Peek();
        var pruning = state.pruningRowsQueue.Peek();
        BuildPruningModel(n, leaf, training, pruning, state, regressionTreeParams, cancellationToken);
        state.nodeQueue.Dequeue();
        state.trainingRowsQueue.Dequeue();
        state.pruningRowsQueue.Dequeue();
      }
    }

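    // Phase 2: compute the pruning threshold for every inner node. The threshold compares the
    // node's own model error with the combined error of its subtree's leaves (SubtreeError) and
    // discounts it by the relative parameter counts, with PruningDecay as the exponent.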
    private static void AssignPruningThresholds(RegressionNodeTreeModel tree, PruningState state, double pruningDecay) {
      if (state.Code == 1) {
        state.FillBottomUp(tree);
        state.Code = 2;
      }
      while (state.nodeQueue.Count != 0) {
        var n = state.nodeQueue.Dequeue();
        if (n.IsLeaf) continue;
        n.PruningStrength = PruningThreshold(state.pruningSizes[n], state.modelComplexities[n], state.nodeComplexities[n], state.modelErrors[n], SubtreeError(n, state.pruningSizes, state.modelErrors), pruningDecay);
      }
    }

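    // Phase 3: traverse the tree top-down and cap each node's threshold by its parent's value,
    // so that a child node is never considered more stable than the node above it.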
    private static void UpdateThreshold(RegressionNodeTreeModel tree, PruningState state) {
      if (state.Code == 2) {
        state.FillTopDown(tree);
        state.Code = 3;
      }
      while (state.nodeQueue.Count != 0) {
        var n = state.nodeQueue.Dequeue();
        if (n.IsLeaf || n.Parent == null || double.IsNaN(n.Parent.PruningStrength)) continue;
        n.PruningStrength = Math.Min(n.PruningStrength, n.Parent.PruningStrength);
      }
    }

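    // Phase 4: convert every inner node whose threshold falls below the configured PruningStrength
    // into a leaf (ToLeaf()).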
    private static void Prune(RegressionNodeTreeModel tree, PruningState state, double pruningStrength) {
      if (state.Code == 3) {
        state.FillTopDown(tree);
        state.Code = 4;
      }
      while (state.nodeQueue.Count != 0) {
        var n = state.nodeQueue.Dequeue();
        if (n.IsLeaf || pruningStrength <= n.PruningStrength) continue;
        n.ToLeaf();
      }
    }

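    // Fits the pruning model for a single node: the model is built from the rows passed as
    // trainingRows, its RMS error is measured on a reduced dataset containing only the
    // pruningRows, and the error plus the model/subtree parameter counts are stored in the state.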
    private static void BuildPruningModel(RegressionNodeModel regressionNode, ILeafModel leaf, IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, PruningState state, RegressionTreeParameters regressionTreeParams, CancellationToken cancellationToken) {
      // create a RegressionProblemData from the pruning rows
      var vars = regressionTreeParams.AllowedInputVariables.Concat(new[] { regressionTreeParams.TargetVariable }).ToArray();
      var reducedData = new Dataset(vars, vars.Select(x => regressionTreeParams.Data.GetDoubleValues(x, pruningRows).ToList()));
      var pd = new RegressionProblemData(reducedData, regressionTreeParams.AllowedInputVariables, regressionTreeParams.TargetVariable);
      pd.TrainingPartition.Start = pd.TrainingPartition.End = pd.TestPartition.Start = 0;
      pd.TestPartition.End = reducedData.Rows;

      // build the pruning model
      int numModelParams;
      var model = leaf.BuildModel(trainingRows, regressionTreeParams, cancellationToken, out numModelParams);

      // record error and complexities
      var rmsModel = model.CreateRegressionSolution(pd).TestRootMeanSquaredError;
      state.pruningSizes.Add(regressionNode, pruningRows.Count);
      state.modelErrors.Add(regressionNode, rmsModel);
      state.modelComplexities.Add(regressionNode, numModelParams);
      if (regressionNode.IsLeaf) {
        state.nodeComplexities[regressionNode] = state.modelComplexities[regressionNode];
      } else {
        state.nodeComplexities.Add(regressionNode, state.nodeComplexities[regressionNode.Left] + state.nodeComplexities[regressionNode.Right] + 1);
      }
    }

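    // Per-node pruning threshold:
    //   threshold = (modelError / nodeError) / ((nodeParams + n) / (2 * (modelParams + n)))^w
    // where n is the number of pruning instances and w is the PruningDecay weight (the ratio is
    // forced to 1.0 when the two errors are almost equal). A low threshold - the node's model is
    // hardly worse than its whole subtree, or the subtree needs many more parameters - makes the
    // node a pruning candidate.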
    private static double PruningThreshold(double noInstances, double modelParams, double nodeParams, double modelError, double nodeError, double w) {
      var res = modelError / nodeError;
      if (modelError.IsAlmost(nodeError)) res = 1.0;
      res /= Math.Pow((nodeParams + noInstances) / (2 * (modelParams + noInstances)), w);
      return res;
    }

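    // Aggregates the leaf-model errors of a subtree into a single RMS value: the squared child
    // errors are weighted by the number of pruning rows that reach each child and averaged over
    // the pruning rows of the subtree's root.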
193    private static double SubtreeError(RegressionNodeModel regressionNode, IDictionary<RegressionNodeModel, int> pruningSizes,
194      IDictionary<RegressionNodeModel, double> modelErrors) {
195      if (regressionNode.IsLeaf) return modelErrors[regressionNode];
196      var errorL = SubtreeError(regressionNode.Left, pruningSizes, modelErrors);
197      var errorR = SubtreeError(regressionNode.Right, pruningSizes, modelErrors);
198      errorL = errorL * errorL * pruningSizes[regressionNode.Left];
199      errorR = errorR * errorR * pruningSizes[regressionNode.Right];
200      return Math.Sqrt((errorR + errorL) / pruningSizes[regressionNode]);
201    }
202
[16847]203    [StorableType("EAD60C7E-2C58-45C4-9697-6F735F518CFD")]
[15830]204    public class PruningState : Item {
205      [Storable]
[17159]206      public IDictionary<RegressionNodeModel, int> modelComplexities;
[15830]207      [Storable]
[17159]208      public IDictionary<RegressionNodeModel, int> nodeComplexities;
[15830]209      [Storable]
[17159]210      public IDictionary<RegressionNodeModel, int> pruningSizes;
[15830]211      [Storable]
[17159]212      public IDictionary<RegressionNodeModel, double> modelErrors;
[15830]213
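      // The work queues are exposed to persistence through array-valued wrapper properties,
      // presumably because the queue types are not serialized directly; this is what allows a
      // paused pruning run to be stored and resumed later.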
      [Storable]
      private RegressionNodeModel[] storableNodeQueue { get { return nodeQueue.ToArray(); } set { nodeQueue = new Queue<RegressionNodeModel>(value); } }
      public Queue<RegressionNodeModel> nodeQueue;
      [Storable]
      private IReadOnlyList<int>[] storabletrainingRowsQueue { get { return trainingRowsQueue.ToArray(); } set { trainingRowsQueue = new Queue<IReadOnlyList<int>>(value); } }
      public Queue<IReadOnlyList<int>> trainingRowsQueue;
      [Storable]
      private IReadOnlyList<int>[] storablepruningRowsQueue { get { return pruningRowsQueue.ToArray(); } set { pruningRowsQueue = new Queue<IReadOnlyList<int>>(value); } }
      public Queue<IReadOnlyList<int>> pruningRowsQueue;

      // State.Code values denote the current action (for pausing):
      // 0...nothing has been done
      // 1...building models
      // 2...assigning thresholds
      // 3...adjusting thresholds
      // 4...pruning
      // 5...finished
      [Storable]
      public int Code = 0;

      #region HLConstructors & Cloning
      [StorableConstructor]
      protected PruningState(StorableConstructorFlag _) : base(_) { }
      protected PruningState(PruningState original, Cloner cloner) : base(original, cloner) {
        modelComplexities = original.modelComplexities.ToDictionary(x => cloner.Clone(x.Key), x => x.Value);
        nodeComplexities = original.nodeComplexities.ToDictionary(x => cloner.Clone(x.Key), x => x.Value);
        pruningSizes = original.pruningSizes.ToDictionary(x => cloner.Clone(x.Key), x => x.Value);
        modelErrors = original.modelErrors.ToDictionary(x => cloner.Clone(x.Key), x => x.Value);
        nodeQueue = new Queue<RegressionNodeModel>(original.nodeQueue.Select(cloner.Clone));
        trainingRowsQueue = new Queue<IReadOnlyList<int>>(original.trainingRowsQueue.Select(x => (IReadOnlyList<int>)x.ToArray()));
        pruningRowsQueue = new Queue<IReadOnlyList<int>>(original.pruningRowsQueue.Select(x => (IReadOnlyList<int>)x.ToArray()));
        Code = original.Code;
      }
      public PruningState() {
        modelComplexities = new Dictionary<RegressionNodeModel, int>();
        nodeComplexities = new Dictionary<RegressionNodeModel, int>();
        pruningSizes = new Dictionary<RegressionNodeModel, int>();
        modelErrors = new Dictionary<RegressionNodeModel, double>();
        nodeQueue = new Queue<RegressionNodeModel>();
        trainingRowsQueue = new Queue<IReadOnlyList<int>>();
        pruningRowsQueue = new Queue<IReadOnlyList<int>>();
      }
      public override IDeepCloneable Clone(Cloner cloner) {
        return new PruningState(this, cloner);
      }
      #endregion

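      // Fills nodeQueue with all tree nodes in breadth-first (top-down) order, starting at the root.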
      public void FillTopDown(RegressionNodeTreeModel tree) {
        var helperQueue = new Queue<RegressionNodeModel>();
        nodeQueue.Clear();

        helperQueue.Enqueue(tree.Root);
        nodeQueue.Enqueue(tree.Root);

        while (helperQueue.Count != 0) {
          var n = helperQueue.Dequeue();
          if (n.IsLeaf) continue;
          helperQueue.Enqueue(n.Left);
          helperQueue.Enqueue(n.Right);
          nodeQueue.Enqueue(n.Left);
          nodeQueue.Enqueue(n.Right);
        }
      }

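      // Same breadth-first enumeration, but additionally partitions the training and pruning rows
      // at every split node (via RegressionTreeUtilities.SplitRows) so that each enqueued node is
      // paired with exactly the rows that reach it.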
      public void FillTopDown(RegressionNodeTreeModel tree, IReadOnlyList<int> pruningRows, IReadOnlyList<int> trainingRows, IDataset data) {
        var helperQueue = new Queue<RegressionNodeModel>();
        var trainingHelperQueue = new Queue<IReadOnlyList<int>>();
        var pruningHelperQueue = new Queue<IReadOnlyList<int>>();
        nodeQueue.Clear();
        trainingRowsQueue.Clear();
        pruningRowsQueue.Clear();

        helperQueue.Enqueue(tree.Root);

        trainingHelperQueue.Enqueue(trainingRows);
        pruningHelperQueue.Enqueue(pruningRows);

        nodeQueue.Enqueue(tree.Root);
        trainingRowsQueue.Enqueue(trainingRows);
        pruningRowsQueue.Enqueue(pruningRows);

        while (helperQueue.Count != 0) {
          var n = helperQueue.Dequeue();
          var p = pruningHelperQueue.Dequeue();
          var t = trainingHelperQueue.Dequeue();
          if (n.IsLeaf) continue;

          IReadOnlyList<int> leftPruning, rightPruning;
          RegressionTreeUtilities.SplitRows(p, data, n.SplitAttribute, n.SplitValue, out leftPruning, out rightPruning);
          IReadOnlyList<int> leftTraining, rightTraining;
          RegressionTreeUtilities.SplitRows(t, data, n.SplitAttribute, n.SplitValue, out leftTraining, out rightTraining);

          helperQueue.Enqueue(n.Left);
          helperQueue.Enqueue(n.Right);
          trainingHelperQueue.Enqueue(leftTraining);
          trainingHelperQueue.Enqueue(rightTraining);
          pruningHelperQueue.Enqueue(leftPruning);
          pruningHelperQueue.Enqueue(rightPruning);

          nodeQueue.Enqueue(n.Left);
          nodeQueue.Enqueue(n.Right);
          trainingRowsQueue.Enqueue(leftTraining);
          trainingRowsQueue.Enqueue(rightTraining);
          pruningRowsQueue.Enqueue(leftPruning);
          pruningRowsQueue.Enqueue(rightPruning);
        }
      }

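      // The bottom-up variants reuse the top-down traversal and simply reverse the resulting
      // queues, so children are always dequeued before their parents (as required when the
      // subtree errors and complexities of the children must already be known).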
      public void FillBottomUp(RegressionNodeTreeModel tree) {
        FillTopDown(tree);
        nodeQueue = new Queue<RegressionNodeModel>(nodeQueue.Reverse());
      }

      public void FillBottomUp(RegressionNodeTreeModel tree, IReadOnlyList<int> pruningRows, IReadOnlyList<int> trainingRows, IDataset data) {
        FillTopDown(tree, pruningRows, trainingRows, data);
        nodeQueue = new Queue<RegressionNodeModel>(nodeQueue.Reverse());
        pruningRowsQueue = new Queue<IReadOnlyList<int>>(pruningRowsQueue.Reverse());
        trainingRowsQueue = new Queue<IReadOnlyList<int>>(trainingRowsQueue.Reverse());
      }
    }
  }
}